
Conversation

@sychen52
Contributor

@sychen52 sychen52 commented Oct 7, 2025

What does this PR do?

This PR goes together with: NVIDIA/TensorRT-LLM#8180
Type of change: Bug fix
Overview:
For W4A8 NVFP4 FP8, export the weight scale factor in the range [-448/6, 448/6].
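
A minimal sketch of the scale math, assuming a per-tensor amax, a block size of 16 that divides the weight evenly, and the hypothetical helper name below (this is not the ModelOpt implementation):

```python
import torch


def w4a8_nvfp4_fp8_block_scales(weight: torch.Tensor, block_size: int = 16):
    """Sketch: derive per-block weight scales bounded by 448/6 for W4A8 NVFP4 FP8 export."""
    amax = weight.abs().amax().float()
    # Global scale: amax / 448, so the per-block scale below never exceeds 448/6,
    # the usable range once the kernel dequantizes the weight to FP8.
    wsf2 = amax / 448.0
    per_block_amax = weight.abs().reshape(-1, block_size).amax(dim=-1).float()
    per_block_scale = per_block_amax / (6.0 * wsf2)  # at most 448/6
    return wsf2, per_block_scale
```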

Usage

Export via huggingface_example.sh.

Testing

Tested the export from ModelOpt together with the import in TensorRT-LLM (NVIDIA/TensorRT-LLM#8180).

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added explicit support for the W4A8_NVFP4_FP8 quantization format so exported models use the correct weight scaling for that format.
  • Bug Fixes

    • Fixed incorrect fallback behavior for W4A8_NVFP4_FP8 scaling retrieval, improving quantization accuracy and stability during model export.

@sychen52 sychen52 self-assigned this Oct 7, 2025
@copy-pr-bot

copy-pr-bot bot commented Oct 7, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Oct 7, 2025

Walkthrough

Adds a special-case branch for the W4A8_NVFP4_FP8 quantization format: computes weight_scaling_factor_2 as amax/448.0 in both get_weight_scaling_factor and get_weight_scaling_factor_2, passes that value (moved to weight.device) into NVFP4QTensor.get_weights_scaling_factor, and preserves existing NVFP4/NVFP4_AWQ and SequentialQuantizer fallback behavior.

Changes

  • Quantization scaling logic (modelopt/torch/export/quant_utils.py): Added conditional handling for W4A8_NVFP4_FP8 in get_weight_scaling_factor and get_weight_scaling_factor_2; computes wsf2 = amax / 448.0 for that format; passes wsf2.to(weight.device) into NVFP4QTensor.get_weights_scaling_factor; retains the existing NVFP4/NVFP4_AWQ delegation and SequentialQuantizer fallback.
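
A rough sketch of this dispatch is shown below. It is a simplified, hypothetical illustration rather than the actual quant_utils.py code: the constant value, function name, and omission of the SequentialQuantizer path are assumptions, while QUANTIZATION_W4A8_NVFP4_FP8, NVFP4QTensor, and the amax / 448.0 computation come from this PR.

```python
# Simplified, hypothetical sketch of the weight_scaling_factor_2 dispatch;
# not the actual modelopt/torch/export/quant_utils.py implementation.
from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor

QUANTIZATION_W4A8_NVFP4_FP8 = "w4a8_nvfp4_fp8"  # placeholder value for this sketch


def weight_scaling_factor_2_sketch(weight_quantizer, quantization_format):
    """Return weight_scaling_factor_2 for the given export format."""
    if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
        # amax / 448 keeps the derived per-block weight scale within 448/6,
        # matching the FP8 range the kernel dequantizes the weights into.
        return weight_quantizer._amax.float() / 448.0
    # NVFP4 / NVFP4_AWQ keep the existing delegation to the quantizer helper;
    # the SequentialQuantizer fallback is omitted from this sketch.
    return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
```

In get_weight_scaling_factor, the returned value is then moved to weight.device before being passed into NVFP4QTensor.get_weights_scaling_factor, as described above.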

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Caller
  participant QUtils as quant_utils.get_weight_scaling_factor
  participant QUtils2 as quant_utils.get_weight_scaling_factor_2
  participant QTensor as NVFP4QTensor
  participant Quant as weight_quantizer

  Caller->>QUtils: get_weight_scaling_factor(Quant, quant_format, weight)
  QUtils->>QUtils2: determine weight_scaling_factor_2
  alt quant_format == W4A8_NVFP4_FP8
    QUtils2->>Quant: read _amax
    QUtils2->>QUtils2: compute wsf2 = amax / 448.0
  else NVFP4 or NVFP4_AWQ
    QUtils2->>QTensor: get_weights_scaling_factor_2_from_quantizer(Quant)
    QTensor-->>QUtils2: wsf2
  else fallback
    QUtils2->>QUtils2: use SequentialQuantizer path
  end
  QUtils->>QTensor: get_weights_scaling_factor(wsf2.to(weight.device))
  QTensor-->>QUtils: weight scaling factor (wsf)
  QUtils-->>Caller: return wsf

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

A nibble of four, a byte of eight,
I twitch my nose and compute the rate.
Amax over 448, tidy and neat,
I hop through scaling with nimble feet.
Quant beds aligned, I burrow with glee. 🐇✨

Pre-merge checks

✅ Passed checks (3 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title succinctly and accurately describes the primary change, indicating that the w4a8_nvfp4_fp8 scale factor is constrained to the specified range, matching the PR's bug-fix objective without extraneous details.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which meets the required threshold of 80.00%.


@codecov

codecov bot commented Oct 7, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.37%. Comparing base (c511477) to head (decb05e).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #406      +/-   ##
==========================================
- Coverage   73.38%   73.37%   -0.01%     
==========================================
  Files         180      180              
  Lines       17934    17934              
==========================================
- Hits        13160    13159       -1     
- Misses       4774     4775       +1     

☔ View full report in Codecov by Sentry.

@sychen52 sychen52 force-pushed the 448/6 branch 2 times, most recently from 5d3b2e8 to ec18006 on October 8, 2025 at 16:31
@sychen52 sychen52 marked this pull request as ready for review October 13, 2025 17:14
@sychen52 sychen52 requested a review from a team as a code owner October 13, 2025 17:14
@sychen52 sychen52 requested review from RalphMao and meenchen October 13, 2025 17:14
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
modelopt/torch/export/quant_utils.py (2)

273-283: Consider adding validation for _amax attribute.

The code directly accesses weight_quantizer._amax without verifying the attribute exists. While the quantization setup should guarantee this, adding an assertion would improve robustness and consistency with NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer, which includes such validation.

Apply this diff to add validation:

     if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
+        assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute _amax"
         # wsf2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
         # This is because the kernel dequantizes weight to fp8, which is in range 448.
         wsf2 = weight_quantizer._amax.float() / 448.0

304-307: Consider adding validation for _amax attribute.

Similar to the change in get_weight_scaling_factor, this code directly accesses weight_quantizer._amax without validation. Adding an assertion would improve consistency and robustness.

Apply this diff to add validation:

     elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
+        assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute _amax"
         # wsf2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
         # This is because the kernel dequantizes weight to fp8, which is in range 448.
         return weight_quantizer._amax.float() / 448.0
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 512e96f and ec18006.

📒 Files selected for processing (1)
  • modelopt/torch/export/quant_utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/export/quant_utils.py (2)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (3)
  • NVFP4QTensor (31-295)
  • get_weights_scaling_factor_2_from_quantizer (56-60)
  • get_weights_scaling_factor (63-90)
modelopt/onnx/quantization/quant_utils.py (1)
  • get_weights_scaling_factor (135-157)

@sychen52 sychen52 enabled auto-merge (squash) October 15, 2025 17:27
Signed-off-by: Shiyang Chen <[email protected]>

fix based on comments

Signed-off-by: Shiyang Chen <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
modelopt/torch/export/quant_utils.py (2)

273-280: Optional: Consider adding defensive assertion for consistency.

For consistency with NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer (line 58 in the relevant snippet), you could add an assertion to verify the _amax attribute exists:

     if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
+        assert hasattr(weight_quantizer, "_amax"), "Weight quantizer does not have attribute _amax"
         # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.

However, since the quantization format is derived from the quantizer's properties, this assertion may be redundant. The current implementation is acceptable.


276-276: Nitpick: Minor code duplication.

The calculation weight_quantizer._amax.float() / 448.0 appears in both functions. While this is a very minor duplication, you could optionally extract it into a small helper method if similar logic needs to be added elsewhere in the future. Given that it only appears twice in related contexts, the current approach is acceptable.

Also applies to: 309-309

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 906ce5c and decb05e.

📒 Files selected for processing (1)
  • modelopt/torch/export/quant_utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/export/quant_utils.py (1)
modelopt/torch/quantization/qtensor/nvfp4_tensor.py (3)
  • NVFP4QTensor (31-295)
  • get_weights_scaling_factor_2_from_quantizer (56-60)
  • get_weights_scaling_factor (63-90)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (2)
modelopt/torch/export/quant_utils.py (2)

273-285: LGTM! Scale factor calculation correctly implements the 448/6 range.

The special-case branch for QUANTIZATION_W4A8_NVFP4_FP8 correctly computes weight_scaling_factor_2 as amax/448.0, which when combined with the 6.0 factor in NVFP4QTensor.get_weights_scaling_factor (line 81 in the snippet: per_block_scale = per_block_amax / (6.0 * weights_scaling_factor_2)), produces the intended scale range of 448/6.

The device placement via .to(weight.device) on line 284 is appropriate and prevents device mismatch errors.


306-309: LGTM! Consistent implementation of scale factor calculation.

The calculation correctly mirrors the logic in get_weight_scaling_factor (lines 273-276), ensuring consistent behavior across both functions.
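
As a quick numeric check of the bound discussed above, using the quoted per_block_scale formula and a made-up amax value:

```python
import torch

amax = torch.tensor(3.2)            # made-up per-tensor weight amax
wsf2 = amax / 448.0                 # the new W4A8_NVFP4_FP8 branch: amax / 448
per_block_amax = amax               # worst case: a block whose amax equals the tensor amax
per_block_scale = per_block_amax / (6.0 * wsf2)
print(per_block_scale.item())       # 74.666..., i.e. 448/6, the intended upper bound
```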

@sychen52
Contributor Author

/ok to test decb05e

@sychen52 sychen52 merged commit 718fd9e into NVIDIA:main Oct 16, 2025
47 of 48 checks passed
