Commit 12d955e

support int8 weight-only quantization
Signed-off-by: Yuening Li <[email protected]>
1 parent: b233ad1

File tree: 8 files changed (+33, -9 lines)

docs/source/guides/_compress_quantized_models.rst (1 addition & 1 deletion)

```diff
@@ -32,7 +32,7 @@ After PTQ, the model can be compressed with the following code:
 Initialize HF models with compressed weights for lower memory usage
 ===================================================================

-When working with large language models, memory constraints can be a significant challenge. ModelOpt provides a workflow for initaializing HF models with compressed weights across multiple GPUs to dramatically reduce memory usage.
+When working with large language models, memory constraints can be a significant challenge. ModelOpt provides a workflow for initializing HF models with compressed weights across multiple GPUs to dramatically reduce memory usage.

 For quantized formats like NVFP4, you can reduce memory usage by up to 4x compared to FP16/BF16 models. One limitation is that this workflow only works with max calibration algorithm.

```
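
The hunk above is a one-word typo fix in the compression guide, but for context, here is a minimal sketch of the post-PTQ compression flow that guide describes, using the INT8 weight-only config added by this commit. The `mtq.compress` entry point and the TinyLlama checkpoint name are assumptions and placeholders, not taken from this diff.

```python
# Sketch only: assumes modelopt.torch.quantization exposes `compress` for the
# post-PTQ compression step referenced in the guide; the model name is a placeholder.
import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.quantization as mtq

model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.bfloat16
)

# Weight-only INT8 PTQ with max calibration (the new config from this commit).
# Weight-only max calibration reads only the weights, so no calibration
# dataloader is passed here; other formats need a forward_loop.
model = mtq.quantize(model, mtq.INT8_WEIGHT_ONLY_CFG, forward_loop=None)

# Compress the quantized weights in place to reduce memory footprint.
mtq.compress(model)
```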

examples/llm_ptq/hf_ptq.py (4 additions & 0 deletions)

```diff
@@ -61,6 +61,7 @@
 QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
     "int8": mtq.INT8_DEFAULT_CFG,
     "int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
+    "int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
     "fp8": mtq.FP8_DEFAULT_CFG,
     "int4_awq": mtq.INT4_AWQ_CFG,
     "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
@@ -94,6 +95,7 @@ def auto_quantize(
         qformat
         in [
             "fp8",
+            "int8_wo",
             "int4_awq",
             "nvfp4",
             "nvfp4_awq",
@@ -229,6 +231,8 @@ def main(args):
     assert (
         args.qformat
         in [
+            "int8",
+            "int8_wo",
             "int4_awq",
             "fp8",
             "nvfp4",
```

examples/llm_ptq/scripts/huggingface_example.sh (4 additions & 4 deletions)

```diff
@@ -45,9 +45,9 @@ if [ "$EXPORT_FORMAT" = "hf" ]; then
     IFS=","
     for qformat in $QFORMAT; do
         case $qformat in
-            fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+            fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
             *)
-                echo "Unsupported quant argument: Expected one of: [fp16, bf16, fp8, fp8_pc_pt, fp8_pb_wo, int4_awq, nvfp4, nvfp4_awq, w4a8_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
+                echo "Unsupported quant argument: Expected one of: [fp16, bf16, fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int4_awq, nvfp4, nvfp4_awq, w4a8_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
                 exit 1
                 ;;
         esac
@@ -74,9 +74,9 @@ esac
     IFS=","
     for qformat in $QFORMAT; do
         case $qformat in
-            fp8 | fp8_pc_pt | fp8_pb_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+            fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
             *)
-                echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
+                echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
                 exit 1
                 ;;
         esac
```

modelopt/torch/export/model_config.py (2 additions & 1 deletion)

```diff
@@ -29,6 +29,7 @@
 QUANTIZATION_NONE = None
 QUANTIZATION_FP8 = "fp8"
 QUANTIZATION_INT8_SQ = "int8_sq"
+QUANTIZATION_INT8_WO = "int8_wo"
 QUANTIZATION_INT4_AWQ = "int4_awq"
 QUANTIZATION_W4A8_AWQ = "w4a8_awq"
 QUANTIZATION_NVFP4 = "nvfp4"
@@ -201,7 +202,7 @@ def weights_scaling_factor(self):
             return None

         if self.q.weights_scaling_factor.numel() != 1:
-            # for NVFP4, Int4 AWQ and Int8 SQ case, we concatenate the
+            # for NVFP4, Int4 AWQ, Int8 SQ and Int8 WO case, we concatenate the
             # q_weight_scaling_factor, k_weight_scaling_factor, v_weight_scaling_factor
             if self.q.weights_scaling_factor.dtype == torch.float8_e4m3fn:
                 # For NVFP4, we recompute the weights_scaling_factor using max weights_scaling_factor2
```

modelopt/torch/export/quant_utils.py (9 additions & 3 deletions)

```diff
@@ -50,6 +50,7 @@
     QUANTIZATION_FP8_PC_PT,
     QUANTIZATION_INT4_AWQ,
     QUANTIZATION_INT8_SQ,
+    QUANTIZATION_INT8_WO,
     QUANTIZATION_MXFP4,
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
@@ -454,7 +455,10 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
         return QUANTIZATION_INT4_AWQ

     if weight_quantizer.num_bits == 8:
-        return QUANTIZATION_INT8_SQ
+        if input_quantizer is not None and input_quantizer.is_enabled:
+            return QUANTIZATION_INT8_SQ
+        else:
+            return QUANTIZATION_INT8_WO

     if weight_quantizer.num_bits == (4, 3):
         if weight_quantizer.block_sizes:
@@ -626,6 +630,8 @@ def process_layer_quant_config(layer_config_dict):
             "has_zero_point": False,
             "pre_quant_scale": True,
         }
+    elif v == "int8_wo":
+        layer_config = {"quant_algo": "W8A16"}
    elif v == "int8_sq":
        layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"}
    elif v == "nvfp4":
@@ -751,7 +757,7 @@ def to_quantized_weight(
             return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
         return (weight / weights_scaling_factor).to(torch.float8_e4m3fn)

-    if quantization == QUANTIZATION_INT8_SQ:
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
         return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)

     if quantization == QUANTIZATION_FP8_PB_WO:
@@ -803,7 +809,7 @@ def from_quantized_weight(
             torch_dtype
         )

-    if quantization == QUANTIZATION_INT8_SQ:
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
         return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)

     raise NotImplementedError(f"quantization format {quantization} not supported")
```
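
A standalone sketch of the per-channel INT8 round trip that `to_quantized_weight` and `from_quantized_weight` now apply to both `int8_sq` and `int8_wo` weights, written in plain PyTorch rather than through the ModelOpt helpers themselves:

```python
# Plain-PyTorch sketch of the per-channel INT8 weight round trip mirrored by
# to_quantized_weight / from_quantized_weight for int8_sq and int8_wo.
import torch

weight = torch.randn(128, 64)  # [out_features, in_features], example shape

# Per-output-channel scale from max calibration: amax over each output row,
# mapped onto the signed INT8 range.
weights_scaling_factor = weight.abs().amax(dim=1) / 127.0

# Quantize: divide by the per-row scale, round, clamp to [-128, 127].
qweight = (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)

# Dequantize: multiply back by the per-row scale in the original dtype.
deq = qweight.to(weight.dtype) * weights_scaling_factor[:, None]

# The reconstruction error is bounded by half a quantization step per element.
assert torch.allclose(weight, deq, atol=weights_scaling_factor.max().item() / 2 + 1e-6)
```

This is also why `process_layer_quant_config` can report the format as `W8A16`: the weights are stored as INT8 with one scale per output channel, while activations stay in the original 16-bit dtype.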

modelopt/torch/quantization/config.py (10 additions & 0 deletions)

```diff
@@ -178,6 +178,15 @@
     "algorithm": "smoothquant",
 }

+INT8_WEIGHT_ONLY_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {"num_bits": 8, "axis": 0},
+        "*input_quantizer": {"enable": False},
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": "max",
+}
+
 FP8_DEFAULT_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
@@ -622,6 +631,7 @@
     "INT4_BLOCKWISE_WEIGHT_ONLY_CFG",
     "INT8_DEFAULT_CFG",
     "INT8_SMOOTHQUANT_CFG",
+    "INT8_WEIGHT_ONLY_CFG",
     "MXFP4_DEFAULT_CFG",
     "MXFP8_DEFAULT_CFG",
     "MXINT8_DEFAULT_CFG",
```

tests/examples/llm_ptq/test_llm_ptq.py (1 addition & 0 deletions)

```diff
@@ -95,6 +95,7 @@ def llama_path(tiny_llama_path):
     PTQCommand(quant="bf16"),
     PTQCommand(quant="int8_sq"),
     # ("int8_sq", "tensorrt_llm", "sparsegpt"),
+    PTQCommand(quant="int8_wo", export_fmt="hf"),
     PTQCommand(quant="int4_awq"),
     PTQCommand(quant="int4_awq", export_fmt="hf"),
     PTQCommand(quant="nvfp4"),
```

tests/gpu/torch/export/test_export.py (2 additions & 0 deletions)

```diff
@@ -57,6 +57,7 @@
     FP8_DEFAULT_CFG,
     INT4_AWQ_CFG,
     INT8_SMOOTHQUANT_CFG,
+    INT8_WEIGHT_ONLY_CFG,
     NVFP4_AWQ_LITE_CFG,
     NVFP4_DEFAULT_CFG,
     W4A8_AWQ_BETA_CFG,
@@ -323,6 +324,7 @@ def test_adjust_attn_amax_values(
     ("config", "expected_block_size"),
     [
         (FP8_DEFAULT_CFG, 0),
+        (INT8_WEIGHT_ONLY_CFG, 0),
         (INT8_SMOOTHQUANT_CFG, 0),
         (NVFP4_DEFAULT_CFG, 16),
         (NVFP4_AWQ_LITE_CFG, 16),
```
