Skip to content

Conversation

cjluo-nv
Copy link
Collaborator

@cjluo-nv cjluo-nv commented Sep 24, 2025

What does this PR do?

Bug fix

Overview: ?

  • For some QAT usecase, we may mix frozen compressed gemms with fake quant gemms. Though all gemms will be converted to RealQuantLinear after compression, the fake quant gemm's q_tensor_state will be an empty dict
  • Move torch.compile ops on the module level to avoid runtime CPU overhead
  • Fix realquant gradient compute when the input tensor is 3D.

Testing

Unittests

Summary by CodeRabbit

  • New Features

    • FP8 per-tensor GEMM enabled with automatic scale caching and hardware availability checks.
    • Real-quant GEMM routing expanded to apply for more eligible inputs.
  • Refactor

    • Streamlined gradient and bias-gradient handling to reduce saved context and memory.
    • Performance optimizations via compilation-ready paths, FP8 conversions, and improved matrix ops.
    • More explicit availability checks for quantized GEMM implementations.
  • Chores

    • Tightened public API and docs to hide the FP8 backend from autosummaries.

@cjluo-nv cjluo-nv requested a review from a team as a code owner September 24, 2025 23:24
@cjluo-nv cjluo-nv requested a review from realAsma September 24, 2025 23:24
Copy link

coderabbitai bot commented Sep 24, 2025

Walkthrough

Refactors FP8 per-tensor GEMM with compiled helpers and per-tensor scale caching, updates autograd state for FP8 and NVFP4 GEMM paths, renames a RealQuantLinear method to an availability check, tightens Megatron routing, and removes the FP8 wildcard export from backends init.

Changes

Cohort / File(s) Summary of Changes
FP8 per-tensor GEMM path
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py
Added torch.compile helpers _to_fp8 and _fp8_gemm_impl. Rewrote fp8_per_tensor_gemm to cache per-tensor scales on the quant module, convert inputs/weights to FP8 as needed, call torch._scaled_mm, and add bias for FP32. Introduced availability_check. Updated Fp8PerTensorLinear autograd to save fewer tensors, store ctx.compute_bias_grad flag, adjust backward to conditionally compute bias grad and perform optional all-reduce, and added an apply wrapper supporting kwargs.
NVFP4 GEMM path
modelopt/torch/quantization/backends/nvfp4_gemm.py
In Nvfp4Linear, set ctx.compute_bias_grad (bool) during forward instead of saving it as a tensor. Backward uses the ctx boolean and computes grad_weight by reshaping grad_output/input_tensor for matmul rather than transpose-based logic.
Core quantized Linear routing
modelopt/torch/quantization/nn/modules/quant_linear.py
Renamed get_real_quant_gemm_implhas_real_quant_gemm_impl to return a boolean availability check; forward now uses this check and asserts the cached _real_quant_gemm_impl before invoking it.
Megatron parallel plugin
modelopt/torch/quantization/plugins/megatron.py
real_quant_module_set_extra_state now checks q_tensor_state truthiness. _forward_impl routes to real-quant GEMM only if input.numel() > 1 and has_real_quant_gemm_impl(...) is true; otherwise falls back to superclass flow.
Backend package exports
modelopt/torch/quantization/backends/__init__.py
Removed wildcard export: from .fp8_per_tensor_gemm import * (reducing package namespace exposure); other imports left intact.
Docs autosummary template
docs/source/_templates/autosummary/module.rst
Excludes the FP8 per-tensor gemm backend (modelopt.torch.quantization.backends.fp8_per_tensor_gemm) from autosummary listings via updated condition.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant U as Caller
  participant MQ as _RealQuantMegatronParallelLinear
  participant RQ as RealQuantLinear
  participant GI as GEMM Impl (cached)

  U->>MQ: forward(input, *args, **kwargs)
  MQ->>MQ: if input.numel() > 1 and has_real_quant_gemm_impl(input,…)
  alt Real-quant GEMM path
    MQ->>RQ: forward routed to quant GEMM
    RQ->>RQ: has_real_quant_gemm_impl(input,…) -> True
    RQ->>GI: _real_quant_gemm_impl(input, weight, bias,…)
    GI-->>RQ: output
    RQ-->>MQ: output
  else Fallback
    MQ->>MQ: call super()._forward_impl(...)
  end
  MQ-->>U: output
Loading
sequenceDiagram
  autonumber
  participant CF as Fp8PerTensorLinear (Autograd)
  participant Q as Quant Module
  participant K as Kernels

  Note over CF: Forward
  CF->>Q: fetch/cache _scale_a, _scale_b (use amax / reduce_amax)
  CF->>K: _to_fp8(input, scale_a)
  CF->>K: _to_fp8(weight, scale_b) if needed
  CF->>K: torch._scaled_mm(fp8_a, fp8_b, scales)
  K-->>CF: matmul result (reshaped)
  CF->>CF: add bias if input is float32
  CF->>CF: ctx.save_for_backward(minimal tensors)
  CF->>CF: ctx.compute_bias_grad = bias requires_grad

  Note over CF: Backward
  CF->>K: compute grad_input, grad_weight (reshape-based mm)
  alt ctx.compute_bias_grad
    CF->>CF: compute/reduce grad_bias
  else
    CF->>CF: skip bias grad
  end
  CF->>CF: optional all-reduce on grads (if configured)
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

A rabbit hops through compiled code,
scales snug in a tiny load.
FP8 whispers, matmuls sing,
gradients dance on nimble spring.
I twitch my nose — tensors aligned, carrots for speed! 🥕

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title succinctly and accurately reflects the PR’s primary objective of improving the realquant GEMM implementation by focusing on bug fixes and performance adjustments. It clearly communicates the main change to teammates scanning history. Although it uses the common abbreviation “impl”, it remains understandable in context.
Docstring Coverage ✅ Passed No functions found in the changes. Docstring coverage check skipped.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch chenjiel/improve_realquant

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4b7cdd6 and b8433ef.

📒 Files selected for processing (2)
  • docs/source/_templates/autosummary/module.rst (1 hunks)
  • modelopt/torch/quantization/backends/__init__.py (0 hunks)
💤 Files with no reviewable changes (1)
  • modelopt/torch/quantization/backends/init.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (1)
docs/source/_templates/autosummary/module.rst (1)

14-16: Filter condition looks correct.

Using the and restores the original plugin gate while cleanly excluding the specific FP8 backend module. Looks good to me.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 26c203a and 2dfac4d.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (5 hunks)
  • modelopt/torch/quantization/backends/nvfp4_gemm.py (3 hunks)
  • modelopt/torch/quantization/nn/modules/quant_linear.py (2 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-16T21:46:46.344Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-212
Timestamp: 2025-09-16T21:46:46.344Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the mtq.quantize function calls the forward_loop under a no_grad context, so wrapping the forward_loop in inference_mode or no_grad is not needed.

Applied to files:

  • modelopt/torch/quantization/plugins/megatron.py
🧬 Code graph analysis (3)
modelopt/torch/quantization/backends/nvfp4_gemm.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • bias (297-301)
modelopt/torch/export/model_config.py (2)
  • bias (152-162)
  • weight (144-149)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • bias (297-301)
  • amax (231-236)
  • amax (239-250)
modelopt/torch/quantization/utils.py (1)
  • reduce_amax (148-180)
modelopt/torch/quantization/plugins/megatron.py (1)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
  • _should_run_real_quant_gemm (143-149)
  • has_real_quant_gemm_impl (151-160)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (3)
modelopt/torch/quantization/plugins/megatron.py (1)

421-444: Real-quant GEMM gating LGTM

Thanks for tightening the guard with the multi-element check and the new has_real_quant_gemm_impl helper—this keeps the fast path off the degenerate shapes while still routing through the cached implementation when available.

modelopt/torch/quantization/nn/modules/quant_linear.py (1)

151-182: Forward-path refactor reads clean

Renaming to has_real_quant_gemm_impl and asserting the cached handle before use makes the control flow clearer and avoids redundant registry lookups. The additional numel() guard mirrors the Megatron path nicely.

modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)

35-97: Module-level torch.compile helpers look good

Hoisting the FP8 conversion/GEMM into compiled helpers should eliminate the per-call compile expense while keeping the runtime path tidy. Bias handling for float32 inputs is also handled correctly with the post-add.

Copy link

codecov bot commented Sep 24, 2025

Codecov Report

❌ Patch coverage is 20.00000% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.69%. Comparing base (0178562) to head (b8433ef).
⚠️ Report is 7 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/quantization/backends/nvfp4_gemm.py 0.00% 4 Missing ⚠️
...lopt/torch/quantization/nn/modules/quant_linear.py 33.33% 4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #368      +/-   ##
==========================================
+ Coverage   73.46%   73.69%   +0.22%     
==========================================
  Files         172      171       -1     
  Lines       17640    17557      -83     
==========================================
- Hits        12959    12938      -21     
+ Misses       4681     4619      -62     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Signed-off-by: Chenjie Luo <[email protected]>
@cjluo-nv cjluo-nv force-pushed the chenjiel/improve_realquant branch from 2dfac4d to 05a2e78 Compare September 25, 2025 19:09
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)

174-194: Safer all_reduce and solid 3D-grad handling.

  • The reshape-based grad_weight computation fixes the 3D input case. Good.
  • Guard all_reduce when grad_input is None.

Apply this diff:

-        if ctx.allreduce_dgrad:
+        if ctx.allreduce_dgrad and grad_input is not None:
             # All-reduce. Note: here async and sync are effectively the same.
             torch.distributed.all_reduce(grad_input, group=ctx.tp_group)
🧹 Nitpick comments (1)
modelopt/torch/quantization/nn/modules/quant_linear.py (1)

151-160: Consider re-probing when implementation was previously missing.

Caching None prevents future matches (e.g., backend becomes available later, device moves to CUDA). If acceptable, keep as-is; otherwise, optionally re-run match when cached value is None.

Apply this minimal change to re-probe only when previously unresolved:

-        if not hasattr(self, "_real_quant_gemm_impl"):
+        if not hasattr(self, "_real_quant_gemm_impl") or self._real_quant_gemm_impl is None:
             self._real_quant_gemm_impl = backends.gemm_registry.find_match(
                 self, input, *args, **kwargs
             )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2dfac4d and 05a2e78.

📒 Files selected for processing (4)
  • modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (5 hunks)
  • modelopt/torch/quantization/backends/nvfp4_gemm.py (3 hunks)
  • modelopt/torch/quantization/nn/modules/quant_linear.py (2 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-16T21:46:46.344Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-212
Timestamp: 2025-09-16T21:46:46.344Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the mtq.quantize function calls the forward_loop under a no_grad context, so wrapping the forward_loop in inference_mode or no_grad is not needed.

Applied to files:

  • modelopt/torch/quantization/plugins/megatron.py
🧬 Code graph analysis (3)
modelopt/torch/quantization/plugins/megatron.py (1)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
  • _should_run_real_quant_gemm (143-149)
  • has_real_quant_gemm_impl (151-160)
modelopt/torch/quantization/backends/nvfp4_gemm.py (2)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • bias (297-301)
modelopt/torch/export/model_config.py (2)
  • bias (152-162)
  • weight (144-149)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (3)
modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  • to (115-123)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3)
  • bias (297-301)
  • amax (231-236)
  • amax (239-250)
modelopt/torch/quantization/utils.py (1)
  • reduce_amax (148-180)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (8)
modelopt/torch/quantization/backends/nvfp4_gemm.py (3)

176-178: Good fix for 3D input grad_weight.

Reshaping to 2D for matmul is correct and fixes the 3D input gradient bug.


179-186: Bias-grad guard bug and unsafe all_reduce on None.

  • Gate on the boolean, not “is not None”. This bug will compute/return bias grad when bias is absent/frozen. Duplicate of prior review.
  • Only all_reduce grad_input when it exists.

Apply this diff:

-        if ctx.compute_bias_grad is not None:
+        if ctx.compute_bias_grad:
             # Sum all dimensions except the last one
             grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))

-        if ctx.allreduce_dgrad:
+        if ctx.allreduce_dgrad and grad_input is not None:
             # All-reduce. Note: here async and sync are effectively the same.
             torch.distributed.all_reduce(grad_input, group=ctx.tp_group)

146-146: Correct use of boolean for bias-grad flag.

Setting a boolean ctx.compute_bias_grad is the right approach and aligns with FP8 path.

modelopt/torch/quantization/nn/modules/quant_linear.py (1)

169-182: Forward gating looks good and avoids GEMM fallback mismatch.

The additional input.numel() > 1 and has-impl check are appropriate. The NVTX range wrapping is fine.

modelopt/torch/quantization/plugins/megatron.py (2)

118-137: Truthiness check for q_tensor_state is correct.

This properly skips restoration for empty dicts in mixed QAT scenarios.


421-447: Real-quant GEMM gating matches core and prevents arg mismatch.

The added checks and forwarding into RealQuantLinear.forward with needed kwargs look good.

modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (2)

35-55: Module-level compiled kernels: good structure and use of _scaled_mm.

  • _to_fp8 and _fp8_gemm_impl are clean and shaped for batched inputs.
  • Transpose/reshape on weight to KxN aligns dimensions correctly.

57-96: Scale caching and FP32 bias path look right.

  • Caching _scale_a/_scale_b only when amax is present avoids stale reuse.
  • FP32 bias handled outside _scaled_mm is correct given its limitation.

Signed-off-by: Chenjie Luo <[email protected]>
Signed-off-by: Chenjie Luo <[email protected]>
Signed-off-by: Chenjie Luo <[email protected]>
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 231f7bd and b5e6fc5.

📒 Files selected for processing (2)
  • docs/source/_templates/autosummary/module.rst (1 hunks)
  • modelopt/torch/quantization/backends/__init__.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/backends/__init__.py (1)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)
  • fp8_per_tensor_gemm (57-96)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/quantization/backends/__init__.py (1)

18-21: LGTM: Skip FP8 backend on Windows

Conditionally skipping the FP8 per-tensor GEMM import on Windows avoids the import-time failure while leaving other backends intact. The names stay exposed on supported platforms, so this trade-off looks good.

cjluo-nv and others added 2 commits September 25, 2025 20:38
Signed-off-by: Chenjie Luo <[email protected]>
Copy link
Collaborator

@ChenhanYu ChenhanYu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reviewed the megatron plugins part only

Copy link
Contributor

@meenchen meenchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, only a general question.

out_dtype=input.dtype,
use_fast_accum=True,
)
return output.reshape(*input_shape[:-1], output.shape[-1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq, could you share why moving these functions to the module level reduces the CPU overheads?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure. My suspicion is that if it's not in the module level, the torch.compile will be called everytime showed in the nsys trace.

@kevalmorabia97 kevalmorabia97 merged commit ad091e8 into main Sep 26, 2025
25 of 27 checks passed
@kevalmorabia97 kevalmorabia97 deleted the chenjiel/improve_realquant branch September 26, 2025 04:07
kevalmorabia97 added a commit that referenced this pull request Sep 26, 2025
Signed-off-by: Chenjie Luo <[email protected]>
Signed-off-by: Keval Morabia <[email protected]>
Co-authored-by: Keval Morabia <[email protected]>
yeyu-nvidia pushed a commit that referenced this pull request Oct 1, 2025
Signed-off-by: Chenjie Luo <[email protected]>
Signed-off-by: Keval Morabia <[email protected]>
Co-authored-by: Keval Morabia <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants