Support INT8 Weight-Only Quantization #263
@@ -50,6 +50,7 @@
     QUANTIZATION_FP8_PC_PT,
     QUANTIZATION_INT4_AWQ,
     QUANTIZATION_INT8_SQ,
+    QUANTIZATION_INT8_WO,
     QUANTIZATION_MXFP4,
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
@@ -454,7 +455,10 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
         return QUANTIZATION_INT4_AWQ

     if weight_quantizer.num_bits == 8:
-        return QUANTIZATION_INT8_SQ
+        if input_quantizer is not None and input_quantizer.is_enabled:
+            return QUANTIZATION_INT8_SQ
+        else:
+            return QUANTIZATION_INT8_WO

     if weight_quantizer.num_bits == (4, 3):
         if weight_quantizer.block_sizes:
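The key behavioral change in this hunk: an 8-bit weight quantizer now maps to SmoothQuant only when the layer's input (activation) quantizer is also enabled; a disabled or absent input quantizer means weight-only INT8. A minimal, self-contained sketch of that decision — the stand-in objects below are illustrative, not the real modelopt quantizer classes:

```python
from types import SimpleNamespace

QUANTIZATION_INT8_SQ = "int8_sq"  # INT8 weights + INT8 activations (SmoothQuant)
QUANTIZATION_INT8_WO = "int8_wo"  # INT8 weights only; activations stay in higher precision


def classify_int8(weight_quantizer, input_quantizer):
    """Mirror the detection above for 8-bit weight quantizers."""
    if weight_quantizer.num_bits == 8:
        if input_quantizer is not None and input_quantizer.is_enabled:
            return QUANTIZATION_INT8_SQ
        return QUANTIZATION_INT8_WO
    raise ValueError("not an 8-bit weight quantizer")


# Weight-only layer: no active input quantizer.
print(classify_int8(SimpleNamespace(num_bits=8), SimpleNamespace(is_enabled=False)))  # int8_wo
# SmoothQuant layer: activations are quantized too.
print(classify_int8(SimpleNamespace(num_bits=8), SimpleNamespace(is_enabled=True)))   # int8_sq
```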
@@ -626,6 +630,8 @@ def process_layer_quant_config(layer_config_dict):
                 "has_zero_point": False,
                 "pre_quant_scale": True,
             }
+        elif v == "int8_wo":
+            layer_config = {"quant_algo": "W8A16"}
         elif v == "int8_sq":
             layer_config = {"quant_algo": "W8A8_SQ_PER_CHANNEL"}
         elif v == "nvfp4":
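For reference, the per-layer string "int8_wo" now resolves to the W8A16 quant algorithm (8-bit weights, 16-bit activations) in the exported config, alongside the existing "int8_sq" → W8A8_SQ_PER_CHANNEL mapping. A reduced sketch of just this mapping, with a made-up layer name for illustration:

```python
def layer_quant_algo(fmt: str) -> dict:
    """Simplified stand-in for the mapping done in process_layer_quant_config."""
    if fmt == "int8_wo":
        return {"quant_algo": "W8A16"}                # weight-only INT8
    if fmt == "int8_sq":
        return {"quant_algo": "W8A8_SQ_PER_CHANNEL"}  # SmoothQuant INT8
    raise NotImplementedError(f"unhandled per-layer format: {fmt}")


# Hypothetical per-layer config, for illustration only.
layer_config_dict = {"model.layers.0.mlp.down_proj": "int8_wo"}
print({name: layer_quant_algo(fmt) for name, fmt in layer_config_dict.items()})
# {'model.layers.0.mlp.down_proj': {'quant_algo': 'W8A16'}}
```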
@@ -751,7 +757,7 @@ def to_quantized_weight(
             return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
         return (weight / weights_scaling_factor).to(torch.float8_e4m3fn)

-    if quantization == QUANTIZATION_INT8_SQ:
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
         return (weight / weights_scaling_factor[:, None]).round().clamp(-128, 127).to(torch.int8)

     if quantization == QUANTIZATION_FP8_PB_WO:
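Both INT8 variants share the same per-output-channel weight packing: divide by the scale, round, clamp to [-128, 127], and store as torch.int8; the matching unpack in from_quantized_weight below multiplies the scale back in. A hedged round-trip sketch for a 2D weight, assuming a per-row scale of shape [out_features] that matches the [:, None] broadcast above:

```python
import torch


def int8_pack(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Per-output-channel INT8 quantization, as in the to_quantized_weight path."""
    return (weight / scale[:, None]).round().clamp(-128, 127).to(torch.int8)


def int8_unpack(qweight: torch.Tensor, scale: torch.Tensor, dtype=torch.float16) -> torch.Tensor:
    """Dequantization back to dtype, as in the from_quantized_weight path."""
    return qweight.to(dtype) * scale[:, None].to(dtype)


w = torch.randn(4, 8)                    # [out_features, in_features]
scale = w.abs().amax(dim=1) / 127.0      # per-output-channel scaling factor
w_q = int8_pack(w, scale)
w_dq = int8_unpack(w_q, scale, dtype=torch.float32)
print((w - w_dq).abs().max())            # small quantization error
```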
@@ -803,7 +809,7 @@ def from_quantized_weight(
             torch_dtype
         )

-    if quantization == QUANTIZATION_INT8_SQ:
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
         return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)
Comment on lines +812 to 814

The INT8 unpack path also misses 3D (MoE) weights. Mirror the 3D case to restore full-precision weights correctly. Apply:

-    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
-        return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)
+    if quantization in [QUANTIZATION_INT8_SQ, QUANTIZATION_INT8_WO]:
+        if weight.dim() == 3:
+            return weight.to(torch_dtype) * weights_scaling_factor.unsqueeze(-1).to(torch_dtype)
+        elif weight.dim() == 2:
+            return weight.to(torch_dtype) * weights_scaling_factor[:, None].to(torch_dtype)
+        else:
+            raise NotImplementedError("INT8 dequantization expects 2D or 3D weight tensors")

     raise NotImplementedError(f"quantization format {quantization} not supported")
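To see why the review comment above asks for a separate 3D branch: for a stacked MoE weight of shape [num_experts, out_features, in_features] with a per-expert, per-output-channel scale of shape [num_experts, out_features], scale[:, None] broadcasts along the wrong axis (or fails outright), while scale.unsqueeze(-1) lines up with the in_features dimension. A small shape check illustrating this, under those assumed shapes:

```python
import torch

num_experts, out_features, in_features = 2, 4, 8
qweight = torch.randint(-128, 127, (num_experts, out_features, in_features), dtype=torch.int8)
scale = torch.rand(num_experts, out_features)   # per-expert, per-output-channel scale

# Correct 3D broadcast: [E, out, 1] expands along in_features.
ok = qweight.to(torch.float16) * scale.unsqueeze(-1).to(torch.float16)
print(ok.shape)  # torch.Size([2, 4, 8])

# The 2D-only form yields shape [E, 1, out], which does not line up with
# [E, out, in] unless out_features happens to equal in_features.
try:
    bad = qweight.to(torch.float16) * scale[:, None].to(torch.float16)
except RuntimeError as err:
    print("broadcast error:", err)
```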
🛠️ Refactor suggestion

The general quant allowlist also needs "int8", or it should be guarded by export_fmt. Currently this block runs for the HF path too and rejects "int8"; the minimal fix is to add "int8" to the allowlist.

Alternative (cleaner): wrap this whole validation in

if [ "$EXPORT_FORMAT" != "hf" ]; then ... fi

so the HF path is validated only once.