
Commit aa960ea

add int8 weight-only quant support
Signed-off-by: Yuening Li <[email protected]>
1 parent 5f190eb commit aa960ea

7 files changed: +41 −10 lines

docs/source/guides/_compress_quantized_models.rst

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ After PTQ, the model can be compressed with the following code:
 Initialize HF models with compressed weights for lower memory usage
 ===================================================================

-When working with large language models, memory constraints can be a significant challenge. ModelOpt provides a workflow for initaializing HF models with compressed weights across multiple GPUs to dramatically reduce memory usage.
+When working with large language models, memory constraints can be a significant challenge. ModelOpt provides a workflow for initializing HF models with compressed weights across multiple GPUs to dramatically reduce memory usage.

 For quantized formats like NVFP4, you can reduce memory usage by up to 4x compared to FP16/BF16 models. One limitation is that this workflow only works with max calibration algorithm.

examples/llm_ptq/hf_ptq.py

Lines changed: 13 additions & 1 deletion
@@ -58,6 +58,7 @@
 QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
     "int8": mtq.INT8_DEFAULT_CFG,
     "int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
+    "int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
     "fp8": mtq.FP8_DEFAULT_CFG,
     "int4_awq": mtq.INT4_AWQ_CFG,
     "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
@@ -87,7 +88,16 @@ def auto_quantize(
     if args.export_fmt == "hf":
         assert all(
             qformat
-            in ["fp8", "int4_awq", "nvfp4", "nvfp4_awq", "w4a8_awq", "fp8_pb_wo", "w4a8_mxfp4_fp8"]
+            in [
+                "fp8",
+                "int8_wo",
+                "int4_awq",
+                "nvfp4",
+                "nvfp4_awq",
+                "w4a8_awq",
+                "fp8_pb_wo",
+                "w4a8_mxfp4_fp8",
+            ]
             for qformat in qformat_list
         ), (
             "One or more quantization formats provided are not supported for unified checkpoint export"
@@ -215,6 +225,8 @@ def main(args):
     assert (
         args.qformat
         in [
+            "int8",
+            "int8_wo",
             "int4_awq",
             "fp8",
             "nvfp4",

examples/llm_ptq/scripts/huggingface_example.sh

Lines changed: 4 additions & 4 deletions
@@ -45,9 +45,9 @@ if [ "$EXPORT_FORMAT" = "hf" ]; then
     IFS=","
     for qformat in $QFORMAT; do
         case $qformat in
-            fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_mxfp4_fp8) ;;
+            fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_mxfp4_fp8) ;;
             *)
-                echo "Unsupported quant argument: Expected one of: [fp16, bf16, fp8, fp8_pc_pt, fp8_pb_wo, int4_awq, nvfp4, nvfp4_awq, w4a8_awq, w4a8_mxfp4_fp8]" >&2
+                echo "Unsupported quant argument: Expected one of: [fp16, bf16, fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int4_awq, nvfp4, nvfp4_awq, w4a8_awq, w4a8_mxfp4_fp8]" >&2
                 exit 1
                 ;;
         esac
@@ -74,9 +74,9 @@ esac
 IFS=","
 for qformat in $QFORMAT; do
     case $qformat in
-        fp8 | fp8_pc_pt | fp8_pb_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_mxfp4_fp8) ;;
+        fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_mxfp4_fp8) ;;
         *)
-            echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_mxfp4_fp8]" >&2
+            echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_mxfp4_fp8]" >&2
            exit 1
            ;;
     esac

modelopt/torch/export/model_config.py

Lines changed: 2 additions & 1 deletion
@@ -29,6 +29,7 @@
 QUANTIZATION_NONE = None
 QUANTIZATION_FP8 = "fp8"
 QUANTIZATION_INT8_SQ = "int8_sq"
+QUANTIZATION_INT8_WO = "int8_wo"
 QUANTIZATION_INT4_AWQ = "int4_awq"
 QUANTIZATION_W4A8_AWQ = "w4a8_awq"
 QUANTIZATION_NVFP4 = "nvfp4"
@@ -200,7 +201,7 @@ def weights_scaling_factor(self):
             return None

         if self.q.weights_scaling_factor.numel() != 1:
-            # for NVFP4, Int4 AWQ and Int8 SQ case, we concatenate the
+            # for NVFP4, Int4 AWQ, Int8 SQ and Int8_WO case, we concatenate the
             # q_weight_scaling_factor, k_weight_scaling_factor, v_weight_scaling_factor
             if self.q.weights_scaling_factor.dtype == torch.float8_e4m3fn:
                 # For NVFP4, we recompute the weights_scaling_factor using max weights_scaling_factor2
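
As the updated comment states, fused QKV export concatenates per-output-channel weight scaling factors rather than collapsing them to one scalar. A short illustrative sketch follows (hypothetical channel counts, plain PyTorch, not code from model_config.py):

# Illustrative only (hypothetical channel counts): per-channel scales for Q, K and V
# are concatenated along the output-channel dimension when the projections are fused.
import torch

q_scale = torch.rand(4096)  # one scale per Q output channel
k_scale = torch.rand(1024)  # one scale per K output channel (e.g. GQA)
v_scale = torch.rand(1024)  # one scale per V output channel
qkv_scale = torch.cat([q_scale, k_scale, v_scale], dim=0)
assert qkv_scale.numel() == q_scale.numel() + k_scale.numel() + v_scale.numel()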

modelopt/torch/export/quant_utils.py

Lines changed: 9 additions & 3 deletions
@@ -50,6 +50,7 @@
     QUANTIZATION_FP8_PC_PT,
     QUANTIZATION_INT4_AWQ,
     QUANTIZATION_INT8_SQ,
+    QUANTIZATION_INT8_WO,
     QUANTIZATION_MXFP4,
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
@@ -419,7 +420,10 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
         return QUANTIZATION_INT4_AWQ

     if weight_quantizer.num_bits == 8:
-        return QUANTIZATION_INT8_SQ
+        if input_quantizer is not None and input_quantizer.is_enabled:
+            return QUANTIZATION_INT8_SQ
+        else:
+            return QUANTIZATION_INT8_WO

     if weight_quantizer.num_bits == (4, 3):
         if weight_quantizer.block_sizes:
@@ -578,6 +582,8 @@ def process_layer_quant_config(layer_config_dict):
                 "has_zero_point": False,
                 "pre_quant_scale": True,
             }
+        elif v == "int8_wo":
+            layer_config = {"quant_algo": "W8A16"}
         elif v == "int8_sq":
             layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"}
         elif v == "nvfp4":
@@ -696,7 +702,7 @@ def to_quantized_weight(
             return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
         return (weight / weights_scaling_factor).to(torch.float8_e4m3fn)

-    if quantization == QUANTIZATION_INT8_SQ:
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
         return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)

     if quantization == QUANTIZATION_FP8_PB_WO:
@@ -748,7 +754,7 @@ def from_quantized_weight(
             torch_dtype
         )

-    if quantization == QUANTIZATION_INT8_SQ:
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
         return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)

     raise NotImplementedError(f"quantization format {quantization} not supported")
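
For reference, int8_wo now shares the per-channel INT8 round-trip that to_quantized_weight / from_quantized_weight already used for int8_sq. Below is a standalone sketch of that arithmetic in plain PyTorch (not the ModelOpt helpers themselves):

# Standalone sketch of the shared per-channel INT8 weight round-trip. Assumes a
# [out_channels, in_channels] weight and one scale per output channel, as produced
# by "max" calibration with axis=0.
import torch

def quantize_int8_per_channel(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # divide each row by its scale, round to nearest, clamp to the int8 range
    return (weight / scale[:, None]).round().clamp(-128, 127).to(torch.int8)

def dequantize_int8_per_channel(qweight: torch.Tensor, scale: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # multiply back by the per-row scale to recover an approximate weight
    return qweight.to(dtype) * scale[:, None].to(dtype)

weight = torch.randn(8, 16)
scale = weight.abs().amax(dim=1) / 127.0  # per-output-channel amax over int8 max
qweight = quantize_int8_per_channel(weight, scale)
recovered = dequantize_int8_per_channel(qweight, scale)
print((weight - recovered).abs().max())  # small quantization error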

modelopt/torch/quantization/config.py

Lines changed: 10 additions & 0 deletions
@@ -178,6 +178,15 @@
     "algorithm": "smoothquant",
 }

+INT8_WEIGHT_ONLY_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {"num_bits": 8, "axis": 0},
+        "*input_quantizer": {"enable": False},
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": "max",
+}
+
 FP8_DEFAULT_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
@@ -574,6 +583,7 @@
     "INT4_BLOCKWISE_WEIGHT_ONLY_CFG",
     "INT8_DEFAULT_CFG",
     "INT8_SMOOTHQUANT_CFG",
+    "INT8_WEIGHT_ONLY_CFG",
     "MXFP4_DEFAULT_CFG",
     "MXFP8_DEFAULT_CFG",
     "MXINT8_DEFAULT_CFG",

tests/gpu/torch/export/test_export.py

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,7 @@
     FP8_DEFAULT_CFG,
     INT4_AWQ_CFG,
     INT8_SMOOTHQUANT_CFG,
+    INT8_WEIGHT_ONLY_CFG,
     NVFP4_AWQ_LITE_CFG,
     NVFP4_DEFAULT_CFG,
     W4A8_AWQ_BETA_CFG,
@@ -323,6 +324,7 @@ def test_adjust_attn_amax_values(
     ("config", "expected_block_size"),
     [
         (FP8_DEFAULT_CFG, 0),
+        (INT8_WEIGHT_ONLY_CFG, 0),
         (INT8_SMOOTHQUANT_CFG, 0),
         (NVFP4_DEFAULT_CFG, 16),
         (NVFP4_AWQ_LITE_CFG, 16),
