[OMNIML-2336] make w4a8_nvfp4_fp8's scale factor in range of 448/6 #406
Conversation
Walkthrough

Adds a special-case branch for the W4A8_NVFP4_FP8 quantization format: computes weight_scaling_factor_2 as amax/448.0 in both get_weight_scaling_factor and get_weight_scaling_factor_2, passes that value (moved to weight.device) into NVFP4QTensor.get_weights_scaling_factor, and preserves the existing NVFP4/NVFP4_AWQ and SequentialQuantizer fallback behavior.
Sequence Diagram(s)

sequenceDiagram
autonumber
actor Caller
participant QUtils as quant_utils.get_weight_scaling_factor
participant QUtils2 as quant_utils.get_weight_scaling_factor_2
participant QTensor as NVFP4QTensor
participant Quant as weight_quantizer
Caller->>QUtils: get_weight_scaling_factor(Quant, quant_format, weight)
QUtils->>QUtils2: determine weight_scaling_factor_2
alt quant_format == W4A8_NVFP4_FP8
QUtils2->>Quant: read _amax
QUtils2->>QUtils2: compute wsf2 = amax / 448.0
else NVFP4 or NVFP4_AWQ
QUtils2->>QTensor: get_weights_scaling_factor_2_from_quantizer(Quant)
QTensor-->>QUtils2: wsf2
else fallback
QUtils2->>QUtils2: use SequentialQuantizer path
end
QUtils->>QTensor: get_weights_scaling_factor(wsf2.to(weight.device))
QTensor-->>QUtils: weight scaling factor (wsf)
QUtils-->>Caller: return wsf
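
To complement the diagram, here is a minimal, self-contained sketch of the W4A8_NVFP4_FP8 branch it describes. The 448.0 and 6.0 constants and the block-wise scaling follow the walkthrough and review comments; the function signatures, block size, and the _FakeQuantizer stand-in are simplifications for illustration, not the exact modelopt API.

```python
# Minimal sketch of the W4A8_NVFP4_FP8 branch described above.
# The 448.0 / 6.0 constants mirror the PR; the signatures, block size,
# and _FakeQuantizer are illustrative stand-ins, not the modelopt API.
import torch

QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8"  # placeholder constant


def get_weight_scaling_factor_2(weight_quantizer, quantization_format):
    if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
        # wsf2 = amax / 448 so that the per-block scale lands in range 448/6,
        # matching the FP8 range the kernel dequantizes the weight into.
        return weight_quantizer._amax.float() / 448.0
    raise NotImplementedError("NVFP4/NVFP4_AWQ and SequentialQuantizer paths omitted")


def get_weight_scaling_factor(weight_quantizer, quantization_format, weight, block_size=16):
    wsf2 = get_weight_scaling_factor_2(weight_quantizer, quantization_format)
    # Per-block scale = per_block_amax / (6.0 * wsf2), as in
    # NVFP4QTensor.get_weights_scaling_factor; assumes each row length is a
    # multiple of block_size.
    blocks = weight.reshape(weight.shape[0], -1, block_size)
    per_block_amax = blocks.abs().amax(dim=-1).float()
    return per_block_amax / (6.0 * wsf2.to(weight.device))


if __name__ == "__main__":
    w = torch.randn(8, 64)

    class _FakeQuantizer:  # stand-in for the real weight quantizer
        _amax = w.abs().amax()

    wsf = get_weight_scaling_factor(_FakeQuantizer(), QUANTIZATION_W4A8_NVFP4_FP8, w)
    assert wsf.max() <= 448.0 / 6.0 + 1e-6  # scale stays within 448/6
```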
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #406      +/-   ##
==========================================
- Coverage   73.38%   73.37%   -0.01%
==========================================
  Files         180      180
  Lines       17934    17934
==========================================
- Hits        13160    13159       -1
- Misses       4774     4775       +1
```

☔ View full report in Codecov by Sentry.
5d3b2e8 to ec18006 (Compare)
Actionable comments posted: 0
🧹 Nitpick comments (2)
modelopt/torch/export/quant_utils.py (2)
273-283: Consider adding validation for the _amax attribute.

The code directly accesses weight_quantizer._amax without verifying the attribute exists. While the quantization setup should guarantee this, adding an assertion would improve robustness and consistency with NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer, which includes such validation. Apply this diff to add validation:
```diff
 if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
+    assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute _amax"
     # wsf2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
     # This is because the kernel dequantizes weight to fp8, which is in range 448.
     wsf2 = weight_quantizer._amax.float() / 448.0
```
304-307: Consider adding validation for the _amax attribute.

Similar to the change in get_weight_scaling_factor, this code directly accesses weight_quantizer._amax without validation. Adding an assertion would improve consistency and robustness. Apply this diff to add validation:
```diff
 elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
+    assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute _amax"
     # wsf2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
     # This is because the kernel dequantizes weight to fp8, which is in range 448.
     return weight_quantizer._amax.float() / 448.0
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/export/quant_utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/export/quant_utils.py (2)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (3)
  NVFP4QTensor (31-295)
  get_weights_scaling_factor_2_from_quantizer (56-60)
  get_weights_scaling_factor (63-90)
modelopt/onnx/quantization/quant_utils.py (1)
  get_weights_scaling_factor (135-157)
Signed-off-by: Shiyang Chen <[email protected]>

fix based on comments

Signed-off-by: Shiyang Chen <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (2)
modelopt/torch/export/quant_utils.py (2)
273-280: Optional: Consider adding a defensive assertion for consistency.

For consistency with NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer (line 58 in the relevant snippet), you could add an assertion to verify the _amax attribute exists:

```diff
 if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
+    assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute _amax"
     # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
```

However, since the quantization format is derived from the quantizer's properties, this assertion may be redundant. The current implementation is acceptable.
276-276: Nitpick: Minor code duplication.

The calculation weight_quantizer._amax.float() / 448.0 appears in both functions. While this is a very minor duplication, you could optionally extract it into a small helper method if similar logic needs to be added elsewhere in the future. Given that it only appears twice in related contexts, the current approach is acceptable.

Also applies to: 309-309
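
If the helper the reviewer alludes to were ever extracted, a minimal sketch might look like the following; the name _w4a8_weight_scaling_factor_2 and its placement are hypothetical, not part of this PR.

```python
import torch


def _w4a8_weight_scaling_factor_2(weight_quantizer) -> torch.Tensor:
    """Hypothetical helper: shared amax / 448.0 calculation for W4A8_NVFP4_FP8."""
    assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute _amax"
    # Dividing amax by 448 (the FP8 range) keeps the per-block scale within 448/6.
    return weight_quantizer._amax.float() / 448.0
```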
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/export/quant_utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/export/quant_utils.py (1)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (3)
  NVFP4QTensor (31-295)
  get_weights_scaling_factor_2_from_quantizer (56-60)
  get_weights_scaling_factor (63-90)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (2)
modelopt/torch/export/quant_utils.py (2)
273-285: LGTM! Scale factor calculation correctly implements the 448/6 range.

The special-case branch for QUANTIZATION_W4A8_NVFP4_FP8 correctly computes weight_scaling_factor_2 as amax/448.0, which, when combined with the 6.0 factor in NVFP4QTensor.get_weights_scaling_factor (line 81 in the snippet: per_block_scale = per_block_amax / (6.0 * weights_scaling_factor_2)), produces the intended scale range of 448/6.

The device placement via .to(weight.device) on line 284 is appropriate and prevents device mismatch errors.
306-309: LGTM! Consistent implementation of scale factor calculation.

The calculation correctly mirrors the logic in get_weight_scaling_factor (lines 273-276), ensuring consistent behavior across both functions.
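
To make the 448/6 bound in these comments concrete, here is a tiny numeric check; the amax value is hypothetical and chosen only for illustration.

```python
# Illustrative check that weight_scaling_factor_2 = amax / 448 keeps the
# per-block scale within 448/6 (worst case: a block's amax equals the global amax).
amax = 3.2                      # hypothetical global weight amax
wsf2 = amax / 448.0             # weight_scaling_factor_2 as computed in this PR
per_block_amax = amax           # worst-case block
per_block_scale = per_block_amax / (6.0 * wsf2)
assert abs(per_block_scale - 448.0 / 6.0) < 1e-9   # equals 448/6, about 74.67
```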
/ok to test decb05e
What does this PR do?
This PR goes together with: NVIDIA/TensorRT-LLM#8180
Type of change: Bug fix
Overview:
For w4a8 nvfp4 fp8, export the weight scale factor in the range [-448/6, 448/6].
Usage
export in huggingface_example.sh
Testing
Tested model export together with import in trtllm.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Bug Fixes