2 changes: 1 addition & 1 deletion docs/source/guides/_compress_quantized_models.rst
@@ -32,7 +32,7 @@ After PTQ, the model can be compressed with the following code:
Initialize HF models with compressed weights for lower memory usage
===================================================================

When working with large language models, memory constraints can be a significant challenge. ModelOpt provides a workflow for initaializing HF models with compressed weights across multiple GPUs to dramatically reduce memory usage.
When working with large language models, memory constraints can be a significant challenge. ModelOpt provides a workflow for initializing HF models with compressed weights across multiple GPUs to dramatically reduce memory usage.

For quantized formats like NVFP4, you can reduce memory usage by up to 4x compared to FP16/BF16 models. One limitation is that this workflow only works with max calibration algorithm.

4 changes: 4 additions & 0 deletions examples/llm_ptq/hf_ptq.py
@@ -61,6 +61,7 @@
QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
"int8": mtq.INT8_DEFAULT_CFG,
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
"int8_wo": mtq.INT8_WEIGHT_ONLY_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
@@ -94,6 +95,7 @@ def auto_quantize(
qformat
in [
"fp8",
"int8_wo",
"int4_awq",
"nvfp4",
"nvfp4_awq",
@@ -229,6 +231,8 @@ def main(args):
assert (
args.qformat
in [
"int8",
"int8_wo",
"int4_awq",
Comment on lines +234 to 236

💡 Verification agent

🧩 Analysis chain

Mismatch with huggingface_example.sh allowlists.

The HF path here allows “int8” and “int8_wo”, but the script rejects “int8”, so the CLI fails before Python runs. Update the script allowlists accordingly (see the suggested patch in the script file comment).


🏁 Script executed:

#!/bin/bash
set -e

# Locate the huggingface example script
script_path=$(find . -type f -name huggingface_example.sh | head -n1)
echo "Found script at: $script_path"

# Show lines around allowlist/quantization options
echo "=== Context around allowlisted quant types ==="
grep -nE 'allowlist|quant' -C3 "$script_path"

# Show any explicit int8/int8_wo mentions
echo "=== int8 / int8_wo occurrences ==="
grep -nE 'int8|int8_wo' -n "$script_path"

# Show validation logic (case statements) for quant types
echo "=== Case statements for quant type validation ==="
grep -nR -n 'case .* in' -C2 "$script_path"

Length of output: 3127


Add “int8” to quantization allowlist in huggingface_example.sh
In examples/llm_ptq/scripts/huggingface_example.sh, update both quant-format case statements (around lines 48 and 77) to include “int8” alongside “int8_wo” and adjust the corresponding error-message lists so that “int8” is accepted.

🤖 Prompt for AI Agents
In examples/llm_ptq/scripts/huggingface_example.sh around the quant-format case
statements near lines ~48 and ~77, the script currently accepts "int8_wo" (and
other formats) but does not include "int8" in the allowed patterns or the
corresponding error-message lists; update both case branches to include "int8"
alongside "int8_wo" (e.g., add "int8" as an accepted pattern in each case
pattern list) and update the error/usage messages so that they list "int8" as an
accepted format where applicable.

"fp8",
"nvfp4",
8 changes: 4 additions & 4 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -45,9 +45,9 @@ if [ "$EXPORT_FORMAT" = "hf" ]; then
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
*)
echo "Unsupported quant argument: Expected one of: [fp16, bf16, fp8, fp8_pc_pt, fp8_pb_wo, int4_awq, nvfp4, nvfp4_awq, w4a8_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
echo "Unsupported quant argument: Expected one of: [fp16, bf16, fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int4_awq, nvfp4, nvfp4_awq, w4a8_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
exit 1
Comment on lines +48 to 51

⚠️ Potential issue

HF allowlist should include ‘int8’ to match hf_ptq.py.

Without this, “--export_fmt hf --qformat int8” is blocked by the script. Add “int8” here.

Apply:

-        fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+        fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int8 | int8_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
📝 Committable suggestion


Suggested change
-        fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+        fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int8 | int8_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
         *)
             echo "Unsupported quant argument: Expected one of: [fp16, bf16, fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int4_awq, nvfp4, nvfp4_awq, w4a8_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
             exit 1
🤖 Prompt for AI Agents
In examples/llm_ptq/scripts/huggingface_example.sh around lines 48 to 51, the
quantization allowlist is missing the "int8" option which prevents "--export_fmt
hf --qformat int8" from being accepted; add "int8" to the case pattern list (and
update the echo message to include "int8" among the expected values) so the
script accepts that quant argument and exits successfully for valid inputs.

;;
esac
Expand All @@ -74,9 +74,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
exit 1
Comment on lines +77 to 80

🛠️ Refactor suggestion

General quant allowlist also needs ‘int8’ or guard by export_fmt.

Currently this block runs for HF too and rejects “int8”. Minimal fix: add “int8”.

Apply:

-    fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+    fp8 | fp8_pc_pt | fp8_pb_wo | int8 | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;

Alternative (cleaner): wrap this whole validation in if [ "$EXPORT_FORMAT" != "hf" ]; then ... fi so the HF path is validated only once.

🤖 Prompt for AI Agents
In examples/llm_ptq/scripts/huggingface_example.sh around lines 77-80, the quant
argument validation excludes "int8" and incorrectly rejects it for the HF path;
either add "int8" to the list of allowed quant values in the case pattern or,
preferably, wrap this entire validation block in a guard so it only runs when
EXPORT_FORMAT is not "hf" (i.e., surround the case...esac with if [
"$EXPORT_FORMAT" != "hf" ]; then ... fi), ensuring HF exports are not blocked
and that "int8" remains accepted where appropriate.

;;
esac
3 changes: 2 additions & 1 deletion modelopt/torch/export/model_config.py
@@ -29,6 +29,7 @@
QUANTIZATION_NONE = None
QUANTIZATION_FP8 = "fp8"
QUANTIZATION_INT8_SQ = "int8_sq"
QUANTIZATION_INT8_WO = "int8_wo"
QUANTIZATION_INT4_AWQ = "int4_awq"
QUANTIZATION_W4A8_AWQ = "w4a8_awq"
QUANTIZATION_NVFP4 = "nvfp4"
@@ -201,7 +202,7 @@ def weights_scaling_factor(self):
return None

if self.q.weights_scaling_factor.numel() != 1:
# for NVFP4, Int4 AWQ and Int8 SQ case, we concatenate the
# for NVFP4, Int4 AWQ, Int8 SQ and Int8 WO case, we concatenate the
# q_weight_scaling_factor, k_weight_scaling_factor, v_weight_scaling_factor
if self.q.weights_scaling_factor.dtype == torch.float8_e4m3fn:
# For NVFP4, we recompute the weights_scaling_factor using max weights_scaling_factor2
12 changes: 9 additions & 3 deletions modelopt/torch/export/quant_utils.py
@@ -50,6 +50,7 @@
QUANTIZATION_FP8_PC_PT,
QUANTIZATION_INT4_AWQ,
QUANTIZATION_INT8_SQ,
QUANTIZATION_INT8_WO,
QUANTIZATION_MXFP4,
QUANTIZATION_NONE,
QUANTIZATION_NVFP4,
@@ -454,7 +455,10 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
return QUANTIZATION_INT4_AWQ

if weight_quantizer.num_bits == 8:
return QUANTIZATION_INT8_SQ
if input_quantizer is not None and input_quantizer.is_enabled:
return QUANTIZATION_INT8_SQ
else:
return QUANTIZATION_INT8_WO

if weight_quantizer.num_bits == (4, 3):
if weight_quantizer.block_sizes:
@@ -626,6 +630,8 @@ def process_layer_quant_config(layer_config_dict):
"has_zero_point": False,
"pre_quant_scale": True,
}
elif v == "int8_wo":
layer_config = {"quant_algo": "W8A16"}
elif v == "int8_sq":
layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"}
elif v == "nvfp4":
@@ -751,7 +757,7 @@ def to_quantized_weight(
return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
return (weight / weights_scaling_factor).to(torch.float8_e4m3fn)

if quantization == QUANTIZATION_INT8_SQ:
if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)

Comment on lines +760 to 762

⚠️ Potential issue

INT8 pack path doesn’t handle 3D (MoE) weights.

For stacked expert weights (E, out, in), broadcasting with [:, None] is wrong. Add 3D handling to avoid shape/broadcast errors and incorrect scaling.

Apply:

-    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
-        return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
+        if weight.dim() == 3:
+            # (experts, out, in) ÷ (experts, out, 1)
+            return (
+                (weight / weights_scaling_factor.unsqueeze(-1))
+                .round()
+                .clamp(-128, 127)
+                .to(torch.int8)
+            )
+        elif weight.dim() == 2:
+            return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)
+        else:
+            raise NotImplementedError("INT8 quantization expects 2D or 3D weight tensors")
📝 Committable suggestion


Suggested change
-    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
-        return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
+        if weight.dim() == 3:
+            # (experts, out, in) ÷ (experts, out, 1)
+            return (
+                (weight / weights_scaling_factor.unsqueeze(-1))
+                .round()
+                .clamp(-128, 127)
+                .to(torch.int8)
+            )
+        elif weight.dim() == 2:
+            return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)
+        else:
+            raise NotImplementedError("INT8 quantization expects 2D or 3D weight tensors")

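To make the broadcast problem concrete, here is a standalone shape check in plain PyTorch (illustrative only; names and sizes are not from ModelOpt):

import torch

# Stacked expert weights (E, out, in) with per-output-channel scales (E, out).
E, out_dim, in_dim = 4, 8, 16
weight = torch.randn(E, out_dim, in_dim)
scale = weight.abs().amax(dim=-1) / 127.0

# scale[:, None] has shape (E, 1, out): it lines up against the in_dim axis,
# so it either raises a broadcast error (when in != out) or scales the wrong axis.
print(scale[:, None].shape)            # torch.Size([4, 1, 8])

# scale.unsqueeze(-1) has shape (E, out, 1) and scales each output row as intended.
q = (weight / scale.unsqueeze(-1)).round().clamp(-128, 127).to(torch.int8)
print(q.shape)                         # torch.Size([4, 8, 16])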
if quantization == QUANTIZATION_FP8_PB_WO:
@@ -803,7 +809,7 @@ def from_quantized_weight(
torch_dtype
)

if quantization == QUANTIZATION_INT8_SQ:
if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)

Comment on lines +812 to 814

⚠️ Potential issue

INT8 unpack path also misses 3D (MoE) weights.

Mirror the 3D case to restore full-precision weights correctly.

Apply:

-    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
-        return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
+        if weight.dim() == 3:
+            return weight.to(torch_dtype) * weights_scaling_factor.unsqueeze(-1).to(torch_dtype)
+        elif weight.dim() == 2:
+            return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)
+        else:
+            raise NotImplementedError("INT8 dequantization expects 2D or 3D weight tensors")
📝 Committable suggestion


Suggested change
-    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
-        return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
+        if weight.dim() == 3:
+            return weight.to(torch_dtype) * weights_scaling_factor.unsqueeze(-1).to(torch_dtype)
+        elif weight.dim() == 2:
+            return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)
+        else:
+            raise NotImplementedError("INT8 dequantization expects 2D or 3D weight tensors")

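The mirrored dequantization should use the same (E, out, 1) broadcast; a quick self-contained round-trip check (again illustrative, not ModelOpt code):

import torch

E, out_dim, in_dim = 4, 8, 16
weight = torch.randn(E, out_dim, in_dim)
scale = weight.abs().amax(dim=-1) / 127.0                      # (E, out)

q = (weight / scale.unsqueeze(-1)).round().clamp(-128, 127).to(torch.int8)
deq = q.to(torch.float32) * scale.unsqueeze(-1)                # mirrors the pack path

# Rounding error stays within half a quantization step per channel.
assert (weight - deq).abs().max() <= scale.max() / 2 + 1e-6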
raise NotImplementedError(f"quantization format {quantization} not supported")
10 changes: 10 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -178,6 +178,15 @@
"algorithm": "smoothquant",
}

INT8_WEIGHT_ONLY_CFG = {
"quant_cfg": {
"*weight_quantizer": {"num_bits": 8, "axis": 0},
"*input_quantizer": {"enable": False},
**_default_disabled_quantizer_cfg,
},
"algorithm": "max",
}

FP8_DEFAULT_CFG = {
"quant_cfg": {
"*weight_quantizer": {"num_bits": (4, 3), "axis": None},
@@ -622,6 +631,7 @@
"INT4_BLOCKWISE_WEIGHT_ONLY_CFG",
"INT8_DEFAULT_CFG",
"INT8_SMOOTHQUANT_CFG",
"INT8_WEIGHT_ONLY_CFG",
"MXFP4_DEFAULT_CFG",
"MXFP8_DEFAULT_CFG",
"MXINT8_DEFAULT_CFG",
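For context, the new config plugs into the standard mtq.quantize flow. A minimal usage sketch (the model name, calibration text, and dtype are placeholders, not part of this PR):

import torch
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

def forward_loop(m):
    # Weight-only "max" calibration mainly needs weight amax, so a tiny
    # calibration pass is enough here.
    batch = tokenizer("Sample calibration text.", return_tensors="pt")
    with torch.no_grad():
        m(**batch)

model = mtq.quantize(model, mtq.INT8_WEIGHT_ONLY_CFG, forward_loop)

Export then follows the same HF checkpoint path that the int8_wo plumbing above enables.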
1 change: 1 addition & 0 deletions tests/examples/llm_ptq/test_llm_ptq.py
@@ -95,6 +95,7 @@ def llama_path(tiny_llama_path):
PTQCommand(quant="bf16"),
PTQCommand(quant="int8_sq"),
# ("int8_sq", "tensorrt_llm", "sparsegpt"),
PTQCommand(quant="int8_wo", export_fmt="hf"),

💡 Verification agent

🧩 Analysis chain

Verify end-to-end plumbing for "int8_wo" and HF export args

Sanity-check that the CLI and scripts accept "int8_wo", pass through export_fmt="hf", and that the KV-cache flag name matches what the scripts expect.


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "1) Confirm repo recognizes INT8 weight-only:"
rg -nP -C2 '(int8_wo|INT8_WEIGHT_ONLY|WEIGHT[-_ ]ONLY.*INT8)'

echo
echo "2) Check script argument wiring for HF PTQ path:"
rg -n -C2 'huggingface_example.sh|hf_ptq\.py' scripts examples | sed -n '1,120p' || true
rg -n -C2 -- '(^|\s)--(quant|export_fmt)\b' scripts/huggingface_example.sh examples/**/hf_ptq.py || true

echo
echo "3) Confirm KV-cache flag spelling expected by code (quant vs qformat):"
rg -nP -C2 'kv_cache_(quant|qformat)' scripts/huggingface_example.sh examples/**/hf_ptq.py

echo
echo "Expected:"
echo "- At least one config/enum or decision branch for int8_wo"
echo "- huggingface_example.sh forwards --quant and --export_fmt to hf_ptq.py"
echo "- hf_ptq.py consumes kv_cache_quant (or adjust test if it uses kv_cache_qformat)"

Length of output: 19435


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "1) Parser args in hf_ptq.py:"
grep -R "add_argument" -n examples/llm_ptq/hf_ptq.py | grep -E "export_fmt|qformat|quant"

echo
echo "2) hf_ptq.py export_fmt flag usage:"
grep -R -- "--export_fmt" -n examples/llm_ptq/hf_ptq.py

echo
echo "3) Script forwarding in examples/llm_ptq/scripts/huggingface_example.sh:"
grep -n "hf_ptq.py" examples/llm_ptq/scripts/huggingface_example.sh | sed -n '180,230p'
grep -n -- "--export_fmt" examples/llm_ptq/scripts/huggingface_example.sh
grep -n -- "--quant" examples/llm_ptq/scripts/huggingface_example.sh

Length of output: 28


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "1) Parser args in hf_ptq.py:"
grep -R "add_argument" -n examples/llm_ptq/hf_ptq.py | grep -E "quant|qformat|export_fmt|kv_cache" || true

echo
echo "2) Script invocation in huggingface_example.sh:"
sed -n '180,260p' examples/llm_ptq/scripts/huggingface_example.sh || true

echo
echo "3) export_fmt usage in tests:"
grep -n "export_fmt" tests/examples/llm_ptq/test_llm_ptq.py || true

echo
echo "4) kv_cache usage in tests:"
grep -n "kv_cache" tests/examples/llm_ptq/test_llm_ptq.py || true

Length of output: 4192


Align CLI flags with tests and scripts

  • In examples/llm_ptq/hf_ptq.py (parser around line 707):
    • Add a --export_fmt argument to consume the export_fmt="hf" passed by tests and huggingface_example.sh.
    • Alias --quant to --qformat (or accept both) so PTQCommand(quant=…) maps to args.qformat.
    • Alias --kv_cache_quant to --kv_cache_qformat (or accept both) so the tests’ kv_cache_quant maps to args.kv_cache_qformat (see the argparse sketch below).

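A minimal argparse sketch of the suggested aliasing (hypothetical wiring and defaults; the real hf_ptq.py parser may differ):

import argparse

parser = argparse.ArgumentParser()
# Accept both spellings for each flag and map them onto one destination.
parser.add_argument("--qformat", "--quant", dest="qformat", default="fp8")
parser.add_argument("--kv_cache_qformat", "--kv_cache_quant", dest="kv_cache_qformat", default="fp8")
parser.add_argument("--export_fmt", choices=["tensorrt_llm", "hf"], default="tensorrt_llm")

args = parser.parse_args(["--quant", "int8_wo", "--export_fmt", "hf"])
assert args.qformat == "int8_wo" and args.export_fmt == "hf"

argparse allows multiple option strings per argument, so both spellings resolve to the same dest without extra plumbing.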
PTQCommand(quant="int4_awq"),
PTQCommand(quant="int4_awq", export_fmt="hf"),
PTQCommand(quant="nvfp4"),
2 changes: 2 additions & 0 deletions tests/gpu/torch/export/test_export.py
@@ -57,6 +57,7 @@
FP8_DEFAULT_CFG,
INT4_AWQ_CFG,
INT8_SMOOTHQUANT_CFG,
INT8_WEIGHT_ONLY_CFG,
NVFP4_AWQ_LITE_CFG,
NVFP4_DEFAULT_CFG,
W4A8_AWQ_BETA_CFG,
@@ -323,6 +324,7 @@ def test_adjust_attn_amax_values(
("config", "expected_block_size"),
[
(FP8_DEFAULT_CFG, 0),
(INT8_WEIGHT_ONLY_CFG, 0),
(INT8_SMOOTHQUANT_CFG, 0),
(NVFP4_DEFAULT_CFG, 16),
(NVFP4_AWQ_LITE_CFG, 16),