Support INT8 Weight-Only Quantization #263
examples/llm_ptq/scripts/huggingface_example.sh

```diff
@@ -45,9 +45,9 @@ if [ "$EXPORT_FORMAT" = "hf" ]; then
     IFS=","
     for qformat in $QFORMAT; do
         case $qformat in
-        fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+        fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
         *)
-            echo "Unsupported quant argument: Expected one of: [fp16, bf16, fp8, fp8_pc_pt, fp8_pb_wo, int4_awq, nvfp4, nvfp4_awq, w4a8_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
+            echo "Unsupported quant argument: Expected one of: [fp16, bf16, fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int4_awq, nvfp4, nvfp4_awq, w4a8_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
             exit 1
             ;;
         esac
```
Comment on lines +48 to 51 (review):

**HF allowlist should include `int8` to match hf_ptq.py.** Without this, `--export_fmt hf --qformat int8` is blocked by the script. Add `int8` here:

```diff
-        fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+        fp16 | bf16 | fp8 | fp8_pc_pt | fp8_pb_wo | int8 | int8_wo | int4_awq | nvfp4 | nvfp4_awq | w4a8_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
```
```diff
@@ -74,9 +74,9 @@ esac
     IFS=","
     for qformat in $QFORMAT; do
         case $qformat in
-        fp8 | fp8_pc_pt | fp8_pb_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+        fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
         *)
-            echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
+            echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
             exit 1
             ;;
         esac
```
Comment on lines +77 to 80 (🛠️ Refactor suggestion):

**The general quant allowlist also needs `int8`, or should be guarded by export_fmt.** This block currently runs for the HF path too and rejects `int8`. Minimal fix: add `int8`:

```diff
-        fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+        fp8 | fp8_pc_pt | fp8_pb_wo | int8 | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
```

Alternative (cleaner): wrap this whole validation in an export-format check so it only applies to the non-HF path.
The second changed file (the export quantization helpers):
```diff
@@ -50,6 +50,7 @@
     QUANTIZATION_FP8_PC_PT,
     QUANTIZATION_INT4_AWQ,
     QUANTIZATION_INT8_SQ,
+    QUANTIZATION_INT8_WO,
     QUANTIZATION_MXFP4,
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
```
```diff
@@ -454,7 +455,10 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
         return QUANTIZATION_INT4_AWQ

     if weight_quantizer.num_bits == 8:
-        return QUANTIZATION_INT8_SQ
+        if input_quantizer is not None and input_quantizer.is_enabled:
+            return QUANTIZATION_INT8_SQ
+        else:
+            return QUANTIZATION_INT8_WO

     if weight_quantizer.num_bits == (4, 3):
         if weight_quantizer.block_sizes:
```
```diff
@@ -626,6 +630,8 @@ def process_layer_quant_config(layer_config_dict):
                 "has_zero_point": False,
                 "pre_quant_scale": True,
             }
+        elif v == "int8_wo":
+            layer_config = {"quant_algo": "W8A16"}
         elif v == "int8_sq":
             layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"}
         elif v == "nvfp4":
```
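The visible effect of this new `elif` is in the exported per-layer config: `int8_wo` maps to the `W8A16` algorithm tag (INT8 weights, 16-bit activations) while `int8_sq` keeps `W8A8_SQ_PER_CHANNEL`. A toy sketch of just this mapping (the function name and error handling are illustrative, not the repo's code):

```python
def layer_quant_algo(fmt: str) -> dict:
    # Per-layer format label -> exported quant_algo entry, as added in this PR.
    if fmt == "int8_wo":
        return {"quant_algo": "W8A16"}                # weight-only INT8
    if fmt == "int8_sq":
        return {"quant_algo": "W8A8_SQ_PER_CHANNEL"}  # SmoothQuant INT8
    raise ValueError(f"unhandled quant format: {fmt}")

assert layer_quant_algo("int8_wo") == {"quant_algo": "W8A16"}
```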
```diff
@@ -751,7 +757,7 @@ def to_quantized_weight(
             return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
         return (weight / weights_scaling_factor).to(torch.float8_e4m3fn)

-    if quantization == QUANTIZATION_INT8_SQ:
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
         return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)

     if quantization == QUANTIZATION_FP8_PB_WO:
```
Comment on lines +760 to 762 (review):

**INT8 pack path doesn't handle 3D (MoE) weights.** For stacked expert weights (E, out, in), broadcasting with `weights_scaling_factor[:, None]` does not align the per-channel scales with the output dimension. Apply:

```diff
-    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
-        return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
+        if weight.dim() == 3:
+            # (experts, out, in) ÷ (experts, out, 1)
+            return (
+                (weight / weights_scaling_factor.unsqueeze(-1))
+                .round()
+                .clamp(-128, 127)
+                .to(torch.int8)
+            )
+        elif weight.dim() == 2:
+            return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)
+        else:
+            raise NotImplementedError("INT8 quantization expects 2D or 3D weight tensors")
```
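To make the reviewer's broadcasting point concrete, here is a runnable round-trip sketch of the same per-channel INT8 packing arithmetic in plain PyTorch. The amax-based scale computation is an assumption for the demo, not the repo's calibration code:

```python
import torch

def pack_int8(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Per-output-channel scales: shape (out,) for 2D weights,
    # (experts, out) for stacked 3D MoE weights.
    if weight.dim() == 3:
        scaled = weight / scale.unsqueeze(-1)  # (E, out, in) / (E, out, 1)
    elif weight.dim() == 2:
        scaled = weight / scale[:, None]       # (out, in) / (out, 1)
    else:
        raise NotImplementedError("expected 2D or 3D weight tensors")
    return scaled.round().clamp(-128, 127).to(torch.int8)

# 2D linear weight: (out=4, in=8), scale = per-row amax / 127.
w2 = torch.randn(4, 8)
q2 = pack_int8(w2, w2.abs().amax(dim=-1) / 127.0)

# 3D MoE stack: (experts=2, out=4, in=8). Note that scale[:, None] here would
# have shape (2, 1, 4) and fail to broadcast against (2, 4, 8) -- the bug above.
w3 = torch.randn(2, 4, 8)
q3 = pack_int8(w3, w3.abs().amax(dim=-1) / 127.0)

assert q2.dtype == torch.int8 and q3.shape == w3.shape
```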
```diff
@@ -803,7 +809,7 @@ def from_quantized_weight(
             torch_dtype
         )

-    if quantization == QUANTIZATION_INT8_SQ:
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
         return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)

     raise NotImplementedError(f"quantization format {quantization} not supported")
```
Comment on lines +812 to 814 (review):

**INT8 unpack path also misses 3D (MoE) weights.** Mirror the 3D case to restore full-precision weights correctly. Apply:

```diff
-    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
-        return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
+        if weight.dim() == 3:
+            return weight.to(torch_dtype) * weights_scaling_factor.unsqueeze(-1).to(torch_dtype)
+        elif weight.dim() == 2:
+            return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)
+        else:
+            raise NotImplementedError("INT8 dequantization expects 2D or 3D weight tensors")
```
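And the matching unpack direction, with a quick reconstruction-error check that exercises the suggested 2D and 3D branches together (again a sketch under the same assumed amax scaling, not the library's function):

```python
import torch

def unpack_int8(qweight, scale, dtype=torch.float32):
    # Inverse of the pack path: multiply by the same per-channel scale.
    if qweight.dim() == 3:
        return qweight.to(dtype) * scale.unsqueeze(-1).to(dtype)
    if qweight.dim() == 2:
        return qweight.to(dtype) * scale[:, None].to(dtype)
    raise NotImplementedError("expected 2D or 3D weight tensors")

w = torch.randn(2, 4, 8)                 # stacked expert weights
s = w.abs().amax(dim=-1) / 127.0         # (experts, out) per-channel scales
q = (w / s.unsqueeze(-1)).round().clamp(-128, 127).to(torch.int8)
w_hat = unpack_int8(q, s)

# Rounding to INT8 bounds each element's error by half a quantization step.
assert (w - w_hat).abs().max() <= s.max() / 2 + 1e-6
```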
tests/examples/llm_ptq/test_llm_ptq.py
```diff
@@ -95,6 +95,7 @@ def llama_path(tiny_llama_path):
     PTQCommand(quant="bf16"),
     PTQCommand(quant="int8_sq"),
     # ("int8_sq", "tensorrt_llm", "sparsegpt"),
+    PTQCommand(quant="int8_wo", export_fmt="hf"),
```
Comment on the added `int8_wo` entry (💡 Verification agent, 🧩 Analysis chain):

**Verify end-to-end plumbing for "int8_wo" and HF export args.** Sanity-check that the CLI and scripts accept "int8_wo", pass `export_fmt="hf"` through, and that the KV-cache flag name matches what the scripts expect.

🏁 Script executed:

```bash
#!/bin/bash
set -euo pipefail

echo "1) Confirm repo recognizes INT8 weight-only:"
rg -nP -C2 '(int8_wo|INT8_WEIGHT_ONLY|WEIGHT[-_ ]ONLY.*INT8)'
echo
echo "2) Check script argument wiring for HF PTQ path:"
rg -n -C2 'huggingface_example.sh|hf_ptq\.py' scripts examples | sed -n '1,120p' || true
rg -n -C2 -- '(^|\s)--(quant|export_fmt)\b' scripts/huggingface_example.sh examples/**/hf_ptq.py || true
echo
echo "3) Confirm KV-cache flag spelling expected by code (quant vs qformat):"
rg -nP -C2 'kv_cache_(quant|qformat)' scripts/huggingface_example.sh examples/**/hf_ptq.py
echo
echo "Expected:"
echo "- At least one config/enum or decision branch for int8_wo"
echo "- huggingface_example.sh forwards --quant and --export_fmt to hf_ptq.py"
echo "- hf_ptq.py consumes kv_cache_quant (or adjust test if it uses kv_cache_qformat)"
```

🏁 Script executed:

```bash
#!/bin/bash
set -euo pipefail

echo "1) Parser args in hf_ptq.py:"
grep -R "add_argument" -n examples/llm_ptq/hf_ptq.py | grep -E "export_fmt|qformat|quant"
echo
echo "2) hf_ptq.py export_fmt flag usage:"
grep -R -- "--export_fmt" -n examples/llm_ptq/hf_ptq.py
echo
echo "3) Script forwarding in examples/llm_ptq/scripts/huggingface_example.sh:"
grep -n "hf_ptq.py" examples/llm_ptq/scripts/huggingface_example.sh | sed -n '180,230p'
grep -n -- "--export_fmt" examples/llm_ptq/scripts/huggingface_example.sh
grep -n -- "--quant" examples/llm_ptq/scripts/huggingface_example.sh
```

🏁 Script executed:

```bash
#!/bin/bash
set -euo pipefail

echo "1) Parser args in hf_ptq.py:"
grep -R "add_argument" -n examples/llm_ptq/hf_ptq.py | grep -E "quant|qformat|export_fmt|kv_cache" || true
echo
echo "2) Script invocation in huggingface_example.sh:"
sed -n '180,260p' examples/llm_ptq/scripts/huggingface_example.sh || true
echo
echo "3) export_fmt usage in tests:"
grep -n "export_fmt" tests/examples/llm_ptq/test_llm_ptq.py || true
echo
echo "4) kv_cache usage in tests:"
grep -n "kv_cache" tests/examples/llm_ptq/test_llm_ptq.py || true
```

**Align CLI flags with tests and scripts.**
```diff
     PTQCommand(quant="int4_awq"),
     PTQCommand(quant="int4_awq", export_fmt="hf"),
     PTQCommand(quant="nvfp4"),
```
Follow-up review comment (💡 Verification agent, 🧩 Analysis chain):

**Mismatch with huggingface_example.sh allowlists.** The HF path here allows "int8" and "int8_wo", but the script denies "int8", causing CLI rejection before Python runs. Update the script allowlists accordingly (see the suggested patch in the script file comment).

**Add "int8" to the quantization allowlist in huggingface_example.sh.** In examples/llm_ptq/scripts/huggingface_example.sh, update both quant-format case statements (around lines 48 and 77) to include "int8" alongside "int8_wo", and adjust the corresponding error-message lists so that "int8" is accepted.