From 4ea627f7c17cdc61b43f14f330ee4b6bf2a285af Mon Sep 17 00:00:00 2001
From: Yuening Li <62227368+Yuening-wa@users.noreply.github.com>
Date: Fri, 26 Sep 2025 00:40:13 -0700
Subject: [PATCH] support int8 weight-only quantization

Signed-off-by: Yuening Li <62227368+Yuening-wa@users.noreply.github.com>
---
 docs/source/guides/_compress_quantized_models.rst    |  2 +-
 examples/llm_ptq/hf_ptq.py                           |  3 +++
 examples/llm_ptq/scripts/huggingface_example.sh      |  4 ++--
 modelopt/torch/export/model_config.py                |  3 ++-
 modelopt/torch/export/quant_utils.py                 | 12 +++++++++---
 modelopt/torch/quantization/config.py                | 10 ++++++++++
 tests/examples/llm_ptq/test_llm_ptq.py               |  1 +
 tests/gpu/torch/export/test_export.py                |  2 ++
 .../test_unified_hf_export_and_check_safetensors.py  |  1 +
 9 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/docs/source/guides/_compress_quantized_models.rst b/docs/source/guides/_compress_quantized_models.rst
index fd0f7bc72..fd044556e 100644
--- a/docs/source/guides/_compress_quantized_models.rst
+++ b/docs/source/guides/_compress_quantized_models.rst
@@ -32,7 +32,7 @@ After PTQ, the model can be compressed with the following code:
 Initialize HF models with compressed weights for lower memory usage
 ===================================================================
 
-When working with large language models, memory constraints can be a significant challenge. ModelOpt provides a workflow for initaializing HF models with compressed weights across multiple GPUs to dramatically reduce memory usage.
+When working with large language models, memory constraints can be a significant challenge. ModelOpt provides a workflow for initializing HF models with compressed weights across multiple GPUs to dramatically reduce memory usage.
 For quantized formats like NVFP4, you can reduce memory usage by up to 4x compared to FP16/BF16 models. One limitation is that this workflow only works with max calibration algorithm.
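The hunks below wire a new "int8_wo" choice through the PTQ example and add mtq.INT8_WEIGHT_ONLY_CFG to modelopt.torch.quantization. A minimal usage sketch, assuming a Hugging Face causal LM (the TinyLlama checkpoint is only a stand-in) and that no activation calibration data is passed, since the config calibrates weights only with the "max" algorithm:

    # Sketch only; the checkpoint name and the choice to skip a forward_loop
    # are assumptions, not part of this patch.
    import modelopt.torch.quantization as mtq
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    # INT8_WEIGHT_ONLY_CFG (added in config.py below): per-channel (axis=0)
    # int8 weight quantizers, input quantizers disabled, "max" calibration.
    model = mtq.quantize(model, mtq.INT8_WEIGHT_ONLY_CFG, forward_loop=None)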
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index 3ef40d4b3..0ac11f2f5 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -68,6 +68,7 @@
 QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
     "int8": mtq.INT8_DEFAULT_CFG,
     "int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
+    "int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
     "fp8": mtq.FP8_DEFAULT_CFG,
     "int4_awq": mtq.INT4_AWQ_CFG,
     "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
@@ -101,6 +102,7 @@ def auto_quantize(
         in [
             "fp8",
             "int8_sq",
+            "int8_wo",
             "int4_awq",
             "nvfp4",
             "nvfp4_awq",
@@ -223,6 +225,7 @@ def main(args):
         assert (
             args.qformat
             in [
+                "int8_wo",
                 "int4_awq",
                 "fp8",
                 "nvfp4",
diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh
index 97d14ea03..9c741c7df 100755
--- a/examples/llm_ptq/scripts/huggingface_example.sh
+++ b/examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
 IFS=","
 for qformat in $QFORMAT; do
     case $qformat in
-    fp8 | fp8_pc_pt | fp8_pb_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+    fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
     *)
-        echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
+        echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
         exit 1
         ;;
 esac
diff --git a/modelopt/torch/export/model_config.py b/modelopt/torch/export/model_config.py
index 49465b00a..306348f2c 100755
--- a/modelopt/torch/export/model_config.py
+++ b/modelopt/torch/export/model_config.py
@@ -29,6 +29,7 @@
 QUANTIZATION_NONE = None
 QUANTIZATION_FP8 = "fp8"
 QUANTIZATION_INT8_SQ = "int8_sq"
+QUANTIZATION_INT8_WO = "int8_wo"
 QUANTIZATION_INT4_AWQ = "int4_awq"
 QUANTIZATION_W4A8_AWQ = "w4a8_awq"
 QUANTIZATION_NVFP4 = "nvfp4"
@@ -201,7 +202,7 @@ def weights_scaling_factor(self):
             return None
 
         if self.q.weights_scaling_factor.numel() != 1:
-            # for NVFP4, Int4 AWQ and Int8 SQ case, we concatenate the
+            # for NVFP4, Int4 AWQ, Int8 SQ and Int8 WO case, we concatenate the
             # q_weight_scaling_factor, k_weight_scaling_factor, v_weight_scaling_factor
             if self.q.weights_scaling_factor.dtype == torch.float8_e4m3fn:
                 # For NVFP4, we recompute the weights_scaling_factor using max weights_scaling_factor2
diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 17a110de1..885a12582 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -50,6 +50,7 @@
     QUANTIZATION_FP8_PC_PT,
     QUANTIZATION_INT4_AWQ,
     QUANTIZATION_INT8_SQ,
+    QUANTIZATION_INT8_WO,
     QUANTIZATION_MXFP4,
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
@@ -454,7 +455,10 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
         return QUANTIZATION_INT4_AWQ
 
     if weight_quantizer.num_bits == 8:
-        return QUANTIZATION_INT8_SQ
+        if input_quantizer is not None and input_quantizer.is_enabled:
+            return QUANTIZATION_INT8_SQ
+        else:
+            return QUANTIZATION_INT8_WO
 
     if weight_quantizer.num_bits == (4, 3):
         if weight_quantizer.block_sizes:
@@ -626,6 +630,8 @@ def process_layer_quant_config(layer_config_dict):
             "has_zero_point": False,
             "pre_quant_scale": True,
         }
+    elif v == "int8_wo":
+        layer_config = {"quant_algo": "W8A16"}
     elif v == "int8_sq":
         layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"}
     elif v == "nvfp4":
@@ -751,7 +757,7 @@ def to_quantized_weight(
             return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
         return (weight / weights_scaling_factor).to(torch.float8_e4m3fn)
 
-    if quantization == QUANTIZATION_INT8_SQ:
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
         return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)
 
     if quantization == QUANTIZATION_FP8_PB_WO:
@@ -803,7 +809,7 @@ def from_quantized_weight(
             torch_dtype
         )
 
-    if quantization == QUANTIZATION_INT8_SQ:
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
         return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)
 
     raise NotImplementedError(f"quantization format {quantization} not supported")
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 6d492ffbf..0642ed910 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -178,6 +178,15 @@
     "algorithm": "smoothquant",
 }
 
+INT8_WEIGHT_ONLY_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {"num_bits": 8, "axis": 0},
+        "*input_quantizer": {"enable": False},
+        **_default_disabled_quantizer_cfg,
+    },
+    "algorithm": "max",
+}
+
 FP8_DEFAULT_CFG = {
     "quant_cfg": {
         "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
@@ -622,6 +631,7 @@
     "INT4_BLOCKWISE_WEIGHT_ONLY_CFG",
     "INT8_DEFAULT_CFG",
     "INT8_SMOOTHQUANT_CFG",
+    "INT8_WEIGHT_ONLY_CFG",
     "MXFP4_DEFAULT_CFG",
     "MXFP8_DEFAULT_CFG",
     "MXINT8_DEFAULT_CFG",
diff --git a/tests/examples/llm_ptq/test_llm_ptq.py b/tests/examples/llm_ptq/test_llm_ptq.py
index 1da11a8c7..ce11dbd46 100644
--- a/tests/examples/llm_ptq/test_llm_ptq.py
+++ b/tests/examples/llm_ptq/test_llm_ptq.py
@@ -78,6 +78,7 @@ def test_ptq_whisper(self, command):
     [
         PTQCommand(quant="int8_sq", kv_cache_quant="none"),
         PTQCommand(quant="int8_sq", kv_cache_quant="none", tp=2, pp=2),
+        PTQCommand(quant="int8_wo", kv_cache_quant="none"),
         PTQCommand(quant="int4_awq", kv_cache_quant="none"),
         PTQCommand(quant="w4a8_awq", kv_cache_quant="none"),
         PTQCommand(quant="nvfp4"),
diff --git a/tests/gpu/torch/export/test_export.py b/tests/gpu/torch/export/test_export.py
index f9cf04586..36d155155 100644
--- a/tests/gpu/torch/export/test_export.py
+++ b/tests/gpu/torch/export/test_export.py
@@ -57,6 +57,7 @@
     FP8_DEFAULT_CFG,
     INT4_AWQ_CFG,
     INT8_SMOOTHQUANT_CFG,
+    INT8_WEIGHT_ONLY_CFG,
    NVFP4_AWQ_LITE_CFG,
     NVFP4_DEFAULT_CFG,
     W4A8_AWQ_BETA_CFG,
@@ -323,6 +324,7 @@ def test_adjust_attn_amax_values(
     ("config", "expected_block_size"),
     [
         (FP8_DEFAULT_CFG, 0),
+        (INT8_WEIGHT_ONLY_CFG, 0),
         (INT8_SMOOTHQUANT_CFG, 0),
         (NVFP4_DEFAULT_CFG, 16),
         (NVFP4_AWQ_LITE_CFG, 16),
diff --git a/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py b/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py
index c3d5653c5..bad5dbdac 100644
--- a/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py
+++ b/tests/gpu/torch/export/test_unified_hf_export_and_check_safetensors.py
@@ -38,6 +38,7 @@
         ("nvfp4_awq", "tiny_llama-nvfp4-awq", True, False, True, True),
         ("int4_awq", "tiny_llama-int4-awq", True, False, True, True),
         ("w4a8_awq", "tiny_llama-w4a8-awq", True, False, True, True),
+        ("int8_wo", "tiny_llama-int8-wo", False, False, False, False),
     ],
 )
 def test_unified_hf_export_and_check_safetensors(