Fix the bug in realquant #301
Conversation
Signed-off-by: Ye Yu <[email protected]>
Walkthrough
Adds a guard so `_RealQuantMegatronParallelLinear._forward_impl` uses the real-quant GEMM path only when `input.numel() > 1`; generalizes the backward `grad_weight` in `Fp8PerTensorLinear` to flatten leading dims before the matmul; updates the `CompressConfig.compress` docstring to state that the default is `True`. No public API changes.
Sequence Diagram(s)
```mermaid
sequenceDiagram
autonumber
actor Caller
participant Layer as _RealQuantMegatronParallelLinear
participant Super as Superclass (_forward_impl)
participant RQL as RealQuantLinear.forward
Caller->>Layer: _forward_impl(input)
alt Conditions met AND input.numel() > 1
Note over Layer: Real-quant GEMM path chosen
Layer->>RQL: forward(input, ...)
RQL-->>Layer: output
else Fallback
Note over Layer,Super: Single-element input or guard not satisfied
Layer->>Super: _forward_impl(input)
Super-->>Layer: output
end
    Layer-->>Caller: output
```
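To make the guarded dispatch above concrete, here is a minimal, self-contained sketch of the logic; the class body, the stand-in GEMM/fallback methods, and the hard-coded property value are illustrative, not the actual modelopt/Megatron implementation:

```python
import torch


class RealQuantMegatronLinearSketch:
    """Illustrative stand-in for _RealQuantMegatronParallelLinear's dispatch."""

    def __init__(self, weight: torch.Tensor):
        self.weight = weight

    @property
    def _should_run_real_quant_gemm(self) -> bool:
        # In modelopt this is a property on RealQuantLinear that checks quantizer
        # state; hard-coded here purely for illustration.
        return True

    def _real_quant_gemm(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the real-quant GEMM path (RealQuantLinear.forward).
        return x @ self.weight.t()

    def _fallback_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the superclass Megatron _forward_impl.
        return x @ self.weight.t()

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        # The PR's guard: take the real-quant GEMM path only when the input has
        # more than one element; otherwise fall back to the regular path.
        if self._should_run_real_quant_gemm and x.numel() > 1:
            return self._real_quant_gemm(x)
        return self._fallback_forward(x)


# Usage: a multi-element input goes through the real-quant path.
layer = RealQuantMegatronLinearSketch(torch.randn(4, 8))
out = layer._forward_impl(torch.randn(2, 8))
```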
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks: ✅ 3 passed
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
✅ Files skipped from review due to trivial changes (1)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
Actionable comments posted: 0
🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/megatron.py (1)
421-425: Threshold check: align with the described bug condition (< 1) or confirm > 1 is intended

The PR says the fallback happens when input.numel() < 1. The code uses > 1, which also routes numel() == 1 to the fallback path. If 1-element inputs are valid for real-quant GEMM, consider > 0 instead.

```diff
-            and input.numel() > 1
+            and input.numel() > 0
```

Please confirm backend support for numel() == 1; if unsupported, keep > 1 and add a short comment to document it.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/quantization/plugins/megatron.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/megatron.py (1)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
_should_run_real_quant_gemm (143-149)
get_real_quant_gemm_impl (151-160)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (1)
modelopt/torch/quantization/plugins/megatron.py (1)
421-425: Ignore incorrect invocation suggestion: `_should_run_real_quant_gemm` is decorated with `@property` (see modelopt/torch/quantization/nn/modules/quant_linear.py:143), so referencing it without `()` is the correct usage.

Likely an incorrect or invalid review comment.
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #301      +/-   ##
==========================================
- Coverage   73.93%   73.85%   -0.09%
==========================================
  Files         172      172
  Lines       17408    17408
==========================================
- Hits        12871    12857      -14
- Misses       4537     4551      +14
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Ye Yu <[email protected]>
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)
185-189: Guard all_reduce when grad_input is None to avoid crashes.

If input_tensor.requires_grad is False, grad_input stays None and all_reduce will throw. Guard it.

```diff
-        if ctx.allreduce_dgrad:
-            # All-reduce. Note: here async and sync are effectively the same.
-            torch.distributed.all_reduce(grad_input, group=ctx.tp_group)
+        if ctx.allreduce_dgrad and grad_input is not None:
+            # All-reduce. Note: here async and sync are effectively the same.
+            torch.distributed.all_reduce(grad_input, group=ctx.tp_group)
```
🧹 Nitpick comments (4)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (4)
178-180: Minor readability nit: prefer flatten()/einsum over reshape() + T.

Either of the below is equivalent and a bit clearer:

```diff
-        grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
-            -1, input_tensor.shape[-1]
-        )
+        # option A: flatten leading dims
+        grad_weight = grad_outputs.flatten(0, -2).T @ input_tensor.flatten(0, -2)
+        # option B: einsum
+        # grad_weight = torch.einsum('...o,...i->oi', grad_outputs, input_tensor)
```
76-80: Avoid .data; use .detach() for safety with autograd.

Accessing .data can lead to silent autograd issues. Detach is safer.

```diff
-    else:
-        weight_fp8 = quant_module.weight.data
+    else:
+        weight_fp8 = quant_module.weight.detach()
```
62-66: Zero amax assertions: confirm behavior on all-zero tensors.

Asserting nonzero amax will hard-crash on all-zero inputs/weights. If that scenario can occur (e.g., padded steps, masked tokens), consider a small epsilon clamp or a graceful fallback to the non-FP8 path.

Example:

```diff
-    assert input_amax != 0
+    if input_amax == 0:
+        return quant_module.linear(input, bias)  # or set scale to a safe epsilon
-    assert weight_amax != 0
+    if weight_amax == 0:
+        return quant_module.linear(input, bias)
```

If fallback here is undesirable, clamp:

```python
input_amax = torch.clamp(input_amax, min=torch.finfo(torch.float32).eps)
weight_amax = torch.clamp(weight_amax, min=torch.finfo(torch.float32).eps)
```

Also applies to: 71-75
160-189: Add a unit test for the new grad_weight path with >2-D inputs.

Recommend a test comparing to einsum:
- Input shape [B, T, Ii], weight [Oo, Ii]; check dW ≈ einsum('bto,bti->oi', dY, X).
I can draft a minimal pytest covering float16/float32, channels-last, and tensor-parallel shard if helpful.
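For reference, a minimal sketch of such a test against the plain tensor math (not the actual FP8 autograd function; shapes are made up, and float16 coverage would additionally need a CUDA device):

```python
import torch


def test_grad_weight_matches_einsum():
    # batch B, sequence T, in-features I, out-features O
    B, T, I, O = 2, 3, 8, 4
    x = torch.randn(B, T, I)
    dy = torch.randn(B, T, O)

    # grad_weight as computed by the patched backward: flatten leading dims,
    # then a single (O, B*T) @ (B*T, I) matmul.
    grad_w = dy.reshape(-1, dy.shape[-1]).T @ x.reshape(-1, x.shape[-1])

    # Reference: accumulate over all leading dims with einsum.
    ref = torch.einsum("bto,bti->oi", dy, x)

    torch.testing.assert_close(grad_w, ref, rtol=1e-5, atol=1e-5)
```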
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)
178-180: Gradient generalization for arbitrary leading dims is correct.

Shape math checks out: (B, Oo)ᵀ @ (B, Ii) → (Oo, Ii). This fixes multi-dim batches.
Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Type of change: Bug fix
Overview:
RealQuantLinear falls back to the regular forward when input.numel() < 1. However, this case is not caught in _RealQuantMegatronParallelLinear, which leads to extra inputs being passed to the regular linear forward, causing the call to fail.
It also fixes a bug in fp8_per_tensor_gemm where the weight gradient computation has a shape mismatch for inputs with more than two dimensions.
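As a standalone illustration of the second fix (plain PyTorch, independent of the FP8 kernel; the shapes below are made up for the example):

```python
import torch

B, T, I, O = 2, 3, 8, 4             # batch, sequence, in-features, out-features
x = torch.randn(B, T, I)            # 3-D activation
dy = torch.randn(B, T, O)           # upstream gradient w.r.t. the output

# A 2-D-only computation such as `dy.T @ x` (illustrative) breaks here:
# on 3-D tensors the transposed matmul has mismatched shapes and raises.

# Fixed computation: flatten all leading dims before the matmul, so the
# weight gradient comes out as (O, I) regardless of how many batch dims exist.
grad_w = dy.reshape(-1, dy.shape[-1]).T @ x.reshape(-1, x.shape[-1])
assert grad_w.shape == (O, I)
```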
Usage
# Add a code snippet demonstrating how to use this
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit