Merged
2 changes: 1 addition & 1 deletion docs/source/guides/_compress_quantized_models.rst
@@ -32,7 +32,7 @@ After PTQ, the model can be compressed with the following code:
Initialize HF models with compressed weights for lower memory usage
===================================================================

When working with large language models, memory constraints can be a significant challenge. ModelOpt provides a workflow for initaializing HF models with compressed weights across multiple GPUs to dramatically reduce memory usage.
When working with large language models, memory constraints can be a significant challenge. ModelOpt provides a workflow for initializing HF models with compressed weights across multiple GPUs to dramatically reduce memory usage.

For quantized formats like NVFP4, you can reduce memory usage by up to 4x compared to FP16/BF16 models. One limitation is that this workflow only works with the max calibration algorithm.

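For orientation, a minimal NVFP4 PTQ sketch using the max calibration algorithm noted above; the checkpoint name and calibration batch are illustrative assumptions, and the multi-GPU compressed-weight initialization API itself is covered by the full guide rather than reproduced here:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import modelopt.torch.quantization as mtq

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

def forward_loop(m):
    # Max calibration needs only a few representative batches; a single toy prompt here.
    batch = tokenizer("Hello from the calibration loop.", return_tensors="pt").to(m.device)
    m(**batch)

# NVFP4 PTQ with max calibration (the only algorithm this workflow supports).
model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)
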
3 changes: 3 additions & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -68,6 +68,7 @@
QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
"int8": mtq.INT8_DEFAULT_CFG,
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
"int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
@@ -101,6 +102,7 @@ def auto_quantize(
in [
"fp8",
"int8_sq",
"int8_wo",
"int4_awq",
"nvfp4",
"nvfp4_awq",
@@ -223,6 +225,7 @@ def main(args):
assert (
args.qformat
in [
"int8_wo",
"int4_awq",
"fp8",
"nvfp4",
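For orientation, a hedged sketch of how a qformat string such as "int8_wo" resolves through the QUANT_CFG_CHOICES mapping above and ends up in mtq.quantize; the helper below is illustrative, not the script's actual control flow:

import modelopt.torch.quantization as mtq

# Subset of the mapping above; "int8_wo" now resolves to the new weight-only config.
QUANT_CFG_CHOICES = {
    "int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
    "int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
    "fp8": mtq.FP8_DEFAULT_CFG,
}

def quantize_for_qformat(model, qformat, forward_loop):
    # Hypothetical helper: validate the requested format, then quantize in place.
    if qformat not in QUANT_CFG_CHOICES:
        raise ValueError(f"Unsupported qformat: {qformat}")
    return mtq.quantize(model, QUANT_CFG_CHOICES[qformat], forward_loop)
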
4 changes: 2 additions & 2 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
exit 1
Comment on lines +56 to 59
🛠️ Refactor suggestion

General quant allowlist also needs "int8", or should be guarded by export_fmt.

Currently this block runs for HF too and rejects “int8”. Minimal fix: add “int8”.

Apply:

-    fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+    fp8 | fp8_pc_pt | fp8_pb_wo | int8 | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;

Alternative (cleaner): wrap this whole validation in if [ "$EXPORT_FORMAT" != "hf" ]; then ... fi so the HF path is validated only once.

🤖 Prompt for AI Agents
In examples/llm_ptq/scripts/huggingface_example.sh around lines 77-80, the quant
argument validation excludes "int8" and incorrectly rejects it for the HF path;
either add "int8" to the list of allowed quant values in the case pattern or,
preferably, wrap this entire validation block in a guard so it only runs when
EXPORT_FORMAT is not "hf" (i.e., surround the case...esac with if [
"$EXPORT_FORMAT" != "hf" ]; then ... fi), ensuring HF exports are not blocked
and that "int8" remains accepted where appropriate.

;;
esac
3 changes: 2 additions & 1 deletion modelopt/torch/export/model_config.py
@@ -29,6 +29,7 @@
QUANTIZATION_NONE = None
QUANTIZATION_FP8 = "fp8"
QUANTIZATION_INT8_SQ = "int8_sq"
QUANTIZATION_INT8_WO = "int8_wo"
QUANTIZATION_INT4_AWQ = "int4_awq"
QUANTIZATION_W4A8_AWQ = "w4a8_awq"
QUANTIZATION_NVFP4 = "nvfp4"
@@ -201,7 +202,7 @@ def weights_scaling_factor(self):
return None

if self.q.weights_scaling_factor.numel() != 1:
# for NVFP4, Int4 AWQ and Int8 SQ case, we concatenate the
# for NVFP4, Int4 AWQ, Int8 SQ and Int8 WO case, we concatenate the
# q_weight_scaling_factor, k_weight_scaling_factor, v_weight_scaling_factor
if self.q.weights_scaling_factor.dtype == torch.float8_e4m3fn:
# For NVFP4, we recompute the weights_scaling_factor using max weights_scaling_factor2
12 changes: 9 additions & 3 deletions modelopt/torch/export/quant_utils.py
@@ -50,6 +50,7 @@
QUANTIZATION_FP8_PC_PT,
QUANTIZATION_INT4_AWQ,
QUANTIZATION_INT8_SQ,
QUANTIZATION_INT8_WO,
QUANTIZATION_MXFP4,
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
@@ -454,7 +455,10 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
return QUANTIZATION_INT4_AWQ

if weight_quantizer.num_bits == 8:
return QUANTIZATION_INT8_SQ
if input_quantizer is not None and input_quantizer.is_enabled:
return QUANTIZATION_INT8_SQ
else:
return QUANTIZATION_INT8_WO

if weight_quantizer.num_bits == (4, 3):
if weight_quantizer.block_sizes:
@@ -626,6 +630,8 @@ def process_layer_quant_config(layer_config_dict):
"has_zero_point": False,
"pre_quant_scale": True,
}
elif v == "int8_wo":
layer_config = {"quant_algo": "W8A16"}
elif v == "int8_sq":
layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"}
elif v == "nvfp4":
@@ -751,7 +757,7 @@ def to_quantized_weight(
return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
return (weight / weights_scaling_factor).to(torch.float8_e4m3fn)

if quantization == QUANTIZATION_INT8_SQ:
if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)

if quantization == QUANTIZATION_FP8_PB_WO:
@@ -803,7 +809,7 @@ def from_quantized_weight(
torch_dtype
)

if quantization == QUANTIZATION_INT8_SQ:
if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)

Comment on lines +812 to 814
⚠️ Potential issue

INT8 unpack path also misses 3D (MoE) weights.

Mirror the 3D case to restore full-precision weights correctly.

Apply:

-    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
-        return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
+        if weight.dim() == 3:
+            return weight.to(torch_dtype) * weights_scaling_factor.unsqueeze(-1).to(torch_dtype)
+        elif weight.dim() == 2:
+            return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)
+        else:
+            raise NotImplementedError("INT8 dequantization expects 2D or 3D weight tensors")

raise NotImplementedError(f"quantization format {quantization} not supported")
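For intuition, a standalone sketch of the per-output-channel INT8 round trip that the 2D branches above implement (plain PyTorch with an arbitrary weight shape, not a call into ModelOpt itself):

import torch

weight = torch.randn(128, 256, dtype=torch.bfloat16)
scale = weight.float().abs().amax(dim=1) / 127.0                 # one scale per output channel
q = (weight.float() / scale[:, None]).round().clamp(-128, 127).to(torch.int8)
deq = q.to(torch.bfloat16) * scale[:, None].to(torch.bfloat16)   # dequantized approximation
print((weight.float() - deq.float()).abs().max())                # small round-trip error
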
10 changes: 10 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -178,6 +178,15 @@
"algorithm": "smoothquant",
}

INT8_WEIGHT_ONLY_CFG = {
"quant_cfg": {
"*weight_quantizer": {"num_bits": 8, "axis": 0},
"*input_quantizer": {"enable": False},
**_default_disabled_quantizer_cfg,
},
"algorithm": "max",
}

FP8_DEFAULT_CFG = {
"quant_cfg": {
"*weight_quantizer": {"num_bits": (4, 3), "axis": None},
@@ -622,6 +631,7 @@
"INT4_BLOCKWISE_WEIGHT_ONLY_CFG",
"INT8_DEFAULT_CFG",
"INT8_SMOOTHQUANT_CFG",
"INT8_WEIGHT_ONLY_CFG",
"MXFP4_DEFAULT_CFG",
"MXFP8_DEFAULT_CFG",
"MXINT8_DEFAULT_CFG",
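The new config disables the input quantizer entirely, so only weights are quantized to INT8 (per-channel along axis 0) while activations stay in the model dtype, which is why the export path above maps it to W8A16. A minimal, hedged usage sketch (the toy module and dummy calibration batch are illustrative assumptions):

import torch
import modelopt.torch.quantization as mtq

# Toy module standing in for an HF model.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))

def forward_loop(m):
    # Weight-only max calibration; a single dummy batch keeps the sketch self-contained.
    m(torch.randn(4, 64))

model = mtq.quantize(model, mtq.INT8_WEIGHT_ONLY_CFG, forward_loop)

Compared with INT8_SMOOTHQUANT_CFG, no activation statistics are involved, which keeps calibration trivial.
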
1 change: 1 addition & 0 deletions tests/examples/llm_ptq/test_llm_ptq.py
@@ -78,6 +78,7 @@ def test_ptq_whisper(self, command):
[
PTQCommand(quant="int8_sq", kv_cache_quant="none"),
PTQCommand(quant="int8_sq", kv_cache_quant="none", tp=2, pp=2),
PTQCommand(quant="int8_wo", kv_cache_quant="none"),
PTQCommand(quant="int4_awq", kv_cache_quant="none"),
PTQCommand(quant="w4a8_awq", kv_cache_quant="none"),
PTQCommand(quant="nvfp4"),
2 changes: 2 additions & 0 deletions tests/gpu/torch/export/test_export.py
@@ -57,6 +57,7 @@
FP8_DEFAULT_CFG,
INT4_AWQ_CFG,
INT8_SMOOTHQUANT_CFG,
INT8_WEIGHT_ONLY_CFG,
NVFP4_AWQ_LITE_CFG,
NVFP4_DEFAULT_CFG,
W4A8_AWQ_BETA_CFG,
@@ -323,6 +324,7 @@ def test_adjust_attn_amax_values(
("config", "expected_block_size"),
[
(FP8_DEFAULT_CFG, 0),
(INT8_WEIGHT_ONLY_CFG, 0),
(INT8_SMOOTHQUANT_CFG, 0),
(NVFP4_DEFAULT_CFG, 16),
(NVFP4_AWQ_LITE_CFG, 16),
@@ -38,6 +38,7 @@
("nvfp4_awq", "tiny_llama-nvfp4-awq", True, False, True, True),
("int4_awq", "tiny_llama-int4-awq", True, False, True, True),
("w4a8_awq", "tiny_llama-w4a8-awq", True, False, True, True),
("int8_wo", "tiny_llama-int8-wo", False, False, False, False),
],
)
def test_unified_hf_export_and_check_safetensors(