Improve realquant gemm impl #368
Conversation
Walkthrough
Refactors FP8 per-tensor GEMM with compiled helpers and per-tensor scale caching, updates autograd state for FP8 and NVFP4 GEMM paths, renames a RealQuantLinear method to an availability check, tightens Megatron routing, and removes the FP8 wildcard export from the backends __init__.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant U as Caller
    participant MQ as _RealQuantMegatronParallelLinear
    participant RQ as RealQuantLinear
    participant GI as GEMM Impl (cached)
    U->>MQ: forward(input, *args, **kwargs)
    MQ->>MQ: if input.numel() > 1 and has_real_quant_gemm_impl(input,…)
    alt Real-quant GEMM path
        MQ->>RQ: forward routed to quant GEMM
        RQ->>RQ: has_real_quant_gemm_impl(input,…) -> True
        RQ->>GI: _real_quant_gemm_impl(input, weight, bias,…)
        GI-->>RQ: output
        RQ-->>MQ: output
    else Fallback
        MQ->>MQ: call super()._forward_impl(...)
    end
    MQ-->>U: output
```

```mermaid
sequenceDiagram
    autonumber
    participant CF as Fp8PerTensorLinear (Autograd)
    participant Q as Quant Module
    participant K as Kernels
    Note over CF: Forward
    CF->>Q: fetch/cache _scale_a, _scale_b (use amax / reduce_amax)
    CF->>K: _to_fp8(input, scale_a)
    CF->>K: _to_fp8(weight, scale_b) if needed
    CF->>K: torch._scaled_mm(fp8_a, fp8_b, scales)
    K-->>CF: matmul result (reshaped)
    CF->>CF: add bias if input is float32
    CF->>CF: ctx.save_for_backward(minimal tensors)
    CF->>CF: ctx.compute_bias_grad = bias requires_grad
    Note over CF: Backward
    CF->>K: compute grad_input, grad_weight (reshape-based mm)
    alt ctx.compute_bias_grad
        CF->>CF: compute/reduce grad_bias
    else
        CF->>CF: skip bias grad
    end
    CF->>CF: optional all-reduce on grads (if configured)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Force-pushed from cbdd3e0 to 2dfac4d.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (5 hunks)
- modelopt/torch/quantization/backends/nvfp4_gemm.py (3 hunks)
- modelopt/torch/quantization/nn/modules/quant_linear.py (2 hunks)
- modelopt/torch/quantization/plugins/megatron.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-16T21:46:46.344Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-212
Timestamp: 2025-09-16T21:46:46.344Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the mtq.quantize function calls the forward_loop under a no_grad context, so wrapping the forward_loop in inference_mode or no_grad is not needed.
Applied to files:
modelopt/torch/quantization/plugins/megatron.py
🧬 Code graph analysis (3)
modelopt/torch/quantization/backends/nvfp4_gemm.py (2)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): bias (297-301)
- modelopt/torch/export/model_config.py (2): bias (152-162), weight (144-149)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (2)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3): bias (297-301), amax (231-236), amax (239-250)
- modelopt/torch/quantization/utils.py (1): reduce_amax (148-180)
modelopt/torch/quantization/plugins/megatron.py (1)
- modelopt/torch/quantization/nn/modules/quant_linear.py (2): _should_run_real_quant_gemm (143-149), has_real_quant_gemm_impl (151-160)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (3)
modelopt/torch/quantization/plugins/megatron.py (1)
421-444: Real-quant GEMM gating LGTM
Thanks for tightening the guard with the multi-element check and the new has_real_quant_gemm_impl helper; this keeps the fast path off the degenerate shapes while still routing through the cached implementation when available.
modelopt/torch/quantization/nn/modules/quant_linear.py (1)
151-182: Forward-path refactor reads clean
Renaming to has_real_quant_gemm_impl and asserting the cached handle before use makes the control flow clearer and avoids redundant registry lookups. The additional numel() guard mirrors the Megatron path nicely.
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)
35-97: Module-level torch.compile helpers look good
Hoisting the FP8 conversion/GEMM into compiled helpers should eliminate the per-call compile expense while keeping the runtime path tidy. Bias handling for float32 inputs is also handled correctly with the post-add.
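For reference, a minimal sketch of the module-level pattern being praised here. The helper names _to_fp8 and _fp8_gemm_impl come from the review context, but the bodies below are illustrative assumptions rather than the repository implementation:

```python
import torch


# Compiled once at import time; every forward call then reuses the same
# compiled artifact instead of paying wrapper/compile setup cost per call.
@torch.compile(dynamic=True)
def _to_fp8(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Per-tensor scale into the e4m3 range, then cast to FP8.
    return (x / scale).clamp(min=-448.0, max=448.0).to(torch.float8_e4m3fn)


@torch.compile(dynamic=True)
def _fp8_gemm_impl(fp8_a, fp8_b, scale_a, scale_b, out_dtype):
    # torch._scaled_mm expects 2D operands with B laid out column-major (K x N);
    # the caller is assumed to have reshaped/transposed accordingly.
    return torch._scaled_mm(
        fp8_a,
        fp8_b,
        scale_a=scale_a,
        scale_b=scale_b,
        out_dtype=out_dtype,
        use_fast_accum=True,
    )
```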
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #368      +/-   ##
==========================================
+ Coverage   73.46%   73.69%   +0.22%
==========================================
  Files         172      171       -1
  Lines       17640    17557      -83
==========================================
- Hits        12959    12938      -21
+ Misses       4681     4619      -62
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Chenjie Luo <[email protected]>
Force-pushed from 2dfac4d to 05a2e78.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)
174-194: Safer all_reduce and solid 3D-grad handling.
- The reshape-based grad_weight computation fixes the 3D input case. Good.
- Guard all_reduce when grad_input is None.
Apply this diff:

```diff
-        if ctx.allreduce_dgrad:
+        if ctx.allreduce_dgrad and grad_input is not None:
             # All-reduce. Note: here async and sync are effectively the same.
             torch.distributed.all_reduce(grad_input, group=ctx.tp_group)
```
🧹 Nitpick comments (1)
modelopt/torch/quantization/nn/modules/quant_linear.py (1)
151-160: Consider re-probing when implementation was previously missing.
Caching None prevents future matches (e.g., backend becomes available later, device moves to CUDA). If acceptable, keep as-is; otherwise, optionally re-run match when cached value is None.
Apply this minimal change to re-probe only when previously unresolved:
- if not hasattr(self, "_real_quant_gemm_impl"): + if not hasattr(self, "_real_quant_gemm_impl") or self._real_quant_gemm_impl is None: self._real_quant_gemm_impl = backends.gemm_registry.find_match( self, input, *args, **kwargs )
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (5 hunks)
- modelopt/torch/quantization/backends/nvfp4_gemm.py (3 hunks)
- modelopt/torch/quantization/nn/modules/quant_linear.py (2 hunks)
- modelopt/torch/quantization/plugins/megatron.py (2 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-16T21:46:46.344Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#318
File: modelopt/torch/quantization/plugins/transformers_trainer.py:206-212
Timestamp: 2025-09-16T21:46:46.344Z
Learning: In modelopt/torch/quantization/plugins/transformers_trainer.py, the mtq.quantize function calls the forward_loop under a no_grad context, so wrapping the forward_loop in inference_mode or no_grad is not needed.
Applied to files:
modelopt/torch/quantization/plugins/megatron.py
🧬 Code graph analysis (3)
modelopt/torch/quantization/plugins/megatron.py (1)
- modelopt/torch/quantization/nn/modules/quant_linear.py (2): _should_run_real_quant_gemm (143-149), has_real_quant_gemm_impl (151-160)
modelopt/torch/quantization/backends/nvfp4_gemm.py (2)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): bias (297-301)
- modelopt/torch/export/model_config.py (2): bias (152-162), weight (144-149)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (3)
- modelopt/torch/quantization/qtensor/base_qtensor.py (1): to (115-123)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (3): bias (297-301), amax (231-236), amax (239-250)
- modelopt/torch/quantization/utils.py (1): reduce_amax (148-180)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (8)
modelopt/torch/quantization/backends/nvfp4_gemm.py (3)
176-178: Good fix for 3D input grad_weight.
Reshaping to 2D for matmul is correct and fixes the 3D input gradient bug.
179-186: Bias-grad guard bug and unsafe all_reduce on None.
- Gate on the boolean, not “is not None”. This bug will compute/return bias grad when bias is absent/frozen. Duplicate of prior review.
- Only all_reduce grad_input when it exists.
Apply this diff:

```diff
-        if ctx.compute_bias_grad is not None:
+        if ctx.compute_bias_grad:
             # Sum all dimensions except the last one
             grad_bias = grad_outputs.sum(dim=list(range(grad_outputs.dim() - 1)))
-        if ctx.allreduce_dgrad:
+        if ctx.allreduce_dgrad and grad_input is not None:
             # All-reduce. Note: here async and sync are effectively the same.
             torch.distributed.all_reduce(grad_input, group=ctx.tp_group)
```

146-146: Correct use of boolean for bias-grad flag.
Setting a boolean ctx.compute_bias_grad is the right approach and aligns with the FP8 path.
modelopt/torch/quantization/nn/modules/quant_linear.py (1)
169-182: Forward gating looks good and avoids GEMM fallback mismatch.
The additional input.numel() > 1 and has-impl checks are appropriate. The NVTX range wrapping is fine.
modelopt/torch/quantization/plugins/megatron.py (2)
118-137: Truthiness check for q_tensor_state is correct.
This properly skips restoration for empty dicts in mixed QAT scenarios.
421-447: Real-quant GEMM gating matches core and prevents arg mismatch.
The added checks and forwarding into RealQuantLinear.forward with needed kwargs look good.
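The guard these comments and the sequence diagram describe boils down to something like the following sketch. It assumes the fallback is the parent class's _forward_impl, as shown in the diagram, and is not the exact plugin code:

```python
def forward(self, input, *args, **kwargs):
    # Skip the real-quant fast path for degenerate single-element inputs and
    # for modules that have no registered real-quant GEMM implementation.
    if input.numel() > 1 and self.has_real_quant_gemm_impl(input, *args, **kwargs):
        # RealQuantLinear.forward dispatches to the cached _real_quant_gemm_impl.
        return RealQuantLinear.forward(self, input, *args, **kwargs)
    # Fallback: the regular Megatron parallel-linear forward path.
    return super()._forward_impl(input, *args, **kwargs)
```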
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (2)
35-55: Module-level compiled kernels: good structure and use of _scaled_mm.
- _to_fp8 and _fp8_gemm_impl are clean and shaped for batched inputs.
- Transpose/reshape on weight to KxN aligns dimensions correctly.
57-96: Scale caching and FP32 bias path look right.
- Caching _scale_a/_scale_b only when amax is present avoids stale reuse.
- FP32 bias handled outside _scaled_mm is correct given its limitation.
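As a rough illustration of the behavior these two bullets describe: the quantizer attribute names and the scale convention below are assumptions, while reduce_amax is the utility referenced in the code graph above.

```python
import torch

from modelopt.torch.quantization.utils import reduce_amax

FP8_MAX = 448.0  # largest magnitude representable in float8_e4m3fn


def _cached_per_tensor_scales(module, device):
    # Cache scales on the module, but only when an amax is actually available,
    # so a later calibration update is not shadowed by a stale cached scale.
    if getattr(module, "_scale_a", None) is None and module.input_quantizer.amax is not None:
        module._scale_a = (reduce_amax(module.input_quantizer.amax).float() / FP8_MAX).to(device)
    if getattr(module, "_scale_b", None) is None and module.weight_quantizer.amax is not None:
        module._scale_b = (reduce_amax(module.weight_quantizer.amax).float() / FP8_MAX).to(device)
    return getattr(module, "_scale_a", None), getattr(module, "_scale_b", None)


def _add_bias_after_matmul(output, bias):
    # torch._scaled_mm does not take a float32 bias, so for FP32 inputs the
    # bias is added after the scaled matmul instead of being fused into it.
    if bias is not None and bias.dtype == torch.float32:
        output = output + bias
    return output
```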
Signed-off-by: Chenjie Luo <[email protected]>
Signed-off-by: Chenjie Luo <[email protected]>
Signed-off-by: Chenjie Luo <[email protected]>
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- docs/source/_templates/autosummary/module.rst (1 hunks)
- modelopt/torch/quantization/backends/__init__.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/backends/__init__.py (1)
- modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1): fp8_per_tensor_gemm (57-96)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/quantization/backends/__init__.py (1)
18-21: LGTM: Skip FP8 backend on Windows
Conditionally skipping the FP8 per-tensor GEMM import on Windows avoids the import-time failure while leaving other backends intact. The names stay exposed on supported platforms, so this trade-off looks good.
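A minimal sketch of what such a guard can look like in backends/__init__.py; the exact import form and any __all__ handling in the real file are assumptions:

```python
import sys

if sys.platform != "win32":
    # fp8_per_tensor_gemm fails at import time on Windows, so expose it only
    # on supported platforms; other backends remain unconditional.
    from .fp8_per_tensor_gemm import fp8_per_tensor_gemm  # noqa: F401
```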
Signed-off-by: Chenjie Luo <[email protected]>
Signed-off-by: Keval Morabia <[email protected]>
I reviewed the megatron plugins part only
LGTM, only a general question.
```python
        out_dtype=input.dtype,
        use_fast_accum=True,
    )
    return output.reshape(*input_shape[:-1], output.shape[-1])
```
qq, could you share why moving these functions to the module level reduces the CPU overheads?
Sure. My suspicion is that if it's not at the module level, torch.compile will be called every time, as shown in the nsys trace.
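A toy illustration of that suspicion (not the repository code): a module-level torch.compile wrapper is created once, whereas calling torch.compile inside the hot function builds a fresh wrapper and re-evaluates guards on every call, which shows up as CPU overhead in a trace.

```python
import torch


def _kernel(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) * 2.0


# Created once at import time; subsequent calls go straight to the cached
# compiled artifact.
_compiled_kernel = torch.compile(_kernel)


def module_level_path(x: torch.Tensor) -> torch.Tensor:
    return _compiled_kernel(x)


def per_call_path(x: torch.Tensor) -> torch.Tensor:
    # A new compiled wrapper is constructed on every call; even when the
    # underlying code cache prevents a full recompilation, the wrapper setup
    # and guard handling add per-call CPU time.
    return torch.compile(_kernel)(x)
```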
Signed-off-by: Chenjie Luo <[email protected]>
Signed-off-by: Keval Morabia <[email protected]>
Co-authored-by: Keval Morabia <[email protected]>

Signed-off-by: Chenjie Luo <[email protected]>
Signed-off-by: Keval Morabia <[email protected]>
Co-authored-by: Keval Morabia <[email protected]>
Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Bug fix
Overview: ?
Testing
Unittests
Summary by CodeRabbit
New Features
Refactor
Chores