
Conversation

@yeyu-nvidia (Contributor) commented Sep 8, 2025

What does this PR do?

Type of change: Bug fix

Overview:
RealQuantLinear will fall back to the regular forward if input.numel() < 1. However, this case is not caught in _RealQuantMegatronParallelLinear, which leads to extra inputs being passed to the regular linear forward and a failure.
This PR also fixes a bug in fp8_per_tensor_gemm where the weight-gradient computation has a shape mismatch.

Usage

No usage change is required; the fix is backward compatible and does not alter any public API.
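The sketch below only illustrates the two situations the fix addresses, written against plain torch as a stand-in for the quantized linear layers (all names and shapes here are illustrative, not the PR's code):

import torch

# 1) Degenerate input: the real-quant GEMM path should not be taken when the
#    input has at most one element; such calls fall back to the regular forward.
x_tiny = torch.randn(1, 1)

# 2) Batched (>2D) input: the weight gradient must contract over all leading
#    dims, which the flattened matmul below does.
x = torch.randn(2, 3, 16)   # (batch, seq, in_features)
dy = torch.randn(2, 3, 8)   # (batch, seq, out_features)
dw = dy.reshape(-1, dy.shape[-1]).T @ x.reshape(-1, x.shape[-1])
assert dw.shape == (8, 16)  # (out_features, in_features), matching the weight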

Testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Additional Information

Summary by CodeRabbit

  • Bug Fixes
    • Single-element inputs to quantized linear layers now fall back to a safe path, avoiding errors and ensuring stable behavior.
    • Gradient computation for FP8 per-tensor linear ops was generalized to handle arbitrary leading dimensions, fixing incorrect grads for batched/flattened inputs.
    • Ensures consistent behavior across model/sequence-parallel configurations.
  • Documentation
    • Clarified the CompressConfig docstring to state that compress defaults to True for all weights.
  • Chores
    • No changes to public APIs.

@yeyu-nvidia requested a review from a team as a code owner September 8, 2025 21:51
@yeyu-nvidia requested a review from RalphMao September 8, 2025 21:51

coderabbitai bot commented Sep 8, 2025

Walkthrough

Adds a guard so _RealQuantMegatronParallelLinear._forward_impl uses the real-quant GEMM path only when input.numel() > 1; generalizes backward grad_weight in Fp8PerTensorLinear to flatten leading dims before the matmul; updates CompressConfig.compress docstring to state default True. No public API changes.

Changes

  • Megatron real-quant forward guard (modelopt/torch/quantization/plugins/megatron.py):
    Adds input.numel() > 1 to the conditional selecting the real-quant GEMM path and reformats the boolean across multiple lines; single-element inputs now fall back to the superclass _forward_impl.
  • FP8 per-tensor backward grad reshape (modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py):
    Changes grad_weight computation to grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(-1, input_tensor.shape[-1]), flattening leading dims before the matmul to support arbitrary leading dimensions; other backward logic unchanged.
  • Config docstring update (modelopt/torch/quantization/config.py):
    Updates the CompressConfig.compress docstring to state the default is True for all weights (text change only; no behavioral or signature changes).

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Caller
  participant Layer as _RealQuantMegatronParallelLinear
  participant Super as Superclass (_forward_impl)
  participant RQL as RealQuantLinear.forward

  Caller->>Layer: _forward_impl(input)
  alt Conditions met AND input.numel() > 1
    Note over Layer: Real-quant GEMM path chosen
    Layer->>RQL: forward(input, ...)
    RQL-->>Layer: output
  else Fallback
    Note over Layer,Super: Single-element input or guard not satisfied
    Layer->>Super: _forward_impl(input)
    Super-->>Layer: output
  end
  Layer-->>Caller: output
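In code, the dispatch shown in the diagram amounts to something like the following; this is an illustrative sketch of the guard's shape, not the exact upstream implementation (argument plumbing is omitted, and _should_run_real_quant_gemm is a property on RealQuantLinear):

def _forward_impl(self, input, *args, **kwargs):
    # Real-quant GEMM only when the backend supports it and the input is
    # non-degenerate; otherwise defer to the regular Megatron linear path.
    if self._should_run_real_quant_gemm and input.numel() > 1:
        return RealQuantLinear.forward(self, input)
    return super()._forward_impl(input, *args, **kwargs)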

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Pre-merge checks (3 passed)

✅ Passed checks (3 passed)
  • Title Check: ✅ Passed. The title “Fix the bug in realquant” accurately refers to the RealQuant bug fix but does not mention the fp8_per_tensor_gemm gradient shape correction or the updated documentation, making it overly narrow relative to the full changeset.
  • Description Check: ✅ Passed. The PR description clearly outlines both bug fixes in RealQuantMegatronParallelLinear and fp8_per_tensor_gemm, labels the change as a bug fix, and provides context for the issue, making it directly relevant and informative.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.

Poem

A rabbit in code with a hop and a twitch,
One-element tensors now take a small switch.
Bigger batches dance through real-quant gates,
FP8 grads flatten their lined-up plates.
I nibble the docs and hop — all set! 🐇



📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b4cb2d5 and fc44595.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/config.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/quantization/config.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs

@coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/megatron.py (1)

421-425: Threshold check: align with the described bug condition (< 1) or confirm > 1 is intended

PR says fallback happens when input.numel() < 1. The code uses > 1, which also routes numel()==1 to the fallback path. If 1-element inputs are valid for real-quant GEMM, consider > 0 instead.

-            and input.numel() > 1
+            and input.numel() > 0

Please confirm backend support for numel()==1; if unsupported, keep > 1 and add a short comment to document it.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cf6f1d4 and 760c583.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/plugins/megatron.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/quantization/plugins/megatron.py (1)
modelopt/torch/quantization/nn/modules/quant_linear.py (2)
  • _should_run_real_quant_gemm (143-149)
  • get_real_quant_gemm_impl (151-160)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (1)
modelopt/torch/quantization/plugins/megatron.py (1)

421-425: Ignore incorrect invocation suggestion: _should_run_real_quant_gemm is decorated with @property (see modelopt/torch/quantization/nn/modules/quant_linear.py:143), so referencing it without () is the correct usage.

Likely an incorrect or invalid review comment.


codecov bot commented Sep 8, 2025

Codecov Report

❌ Patch coverage is 0% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 73.85%. Comparing base (cf6f1d4) to head (fc44595).
⚠️ Report is 2 commits behind head on main.

Files with missing lines:
  • ...torch/quantization/backends/fp8_per_tensor_gemm.py: patch 0.00%, 1 line missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #301      +/-   ##
==========================================
- Coverage   73.93%   73.85%   -0.09%     
==========================================
  Files         172      172              
  Lines       17408    17408              
==========================================
- Hits        12871    12857      -14     
- Misses       4537     4551      +14     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Signed-off-by: Ye Yu <[email protected]>
@coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)

185-189: Guard all_reduce when grad_input is None to avoid crashes.

If input_tensor.requires_grad is False, grad_input stays None and all_reduce will throw. Guard it.

-        if ctx.allreduce_dgrad:
-            # All-reduce. Note: here async and sync are effectively the same.
-            torch.distributed.all_reduce(grad_input, group=ctx.tp_group)
+        if ctx.allreduce_dgrad and grad_input is not None:
+            # All-reduce. Note: here async and sync are effectively the same.
+            torch.distributed.all_reduce(grad_input, group=ctx.tp_group)
🧹 Nitpick comments (4)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (4)

178-180: Minor readability nit: prefer flatten()/einsum over reshape()+T.

Either of the below is equivalent and a bit clearer:

- grad_weight = grad_outputs.reshape(-1, grad_outputs.shape[-1]).T @ input_tensor.reshape(
-     -1, input_tensor.shape[-1]
- )
+ # option A: flatten leading dims
+ grad_weight = grad_outputs.flatten(0, -2).T @ input_tensor.flatten(0, -2)
+ # option B: einsum
+ # grad_weight = torch.einsum('...o,...i->oi', grad_outputs, input_tensor)

76-80: Avoid .data; use .detach() for safety with autograd.

Accessing .data can lead to silent autograd issues. Detach is safer.

-    else:
-        weight_fp8 = quant_module.weight.data
+    else:
+        weight_fp8 = quant_module.weight.detach()

62-66: Zero amax assertions: confirm behavior on all-zero tensors.

Asserting nonzero amax will hard-crash on all-zero inputs/weights. If that scenario can occur (e.g., padded steps, masked tokens), consider a small epsilon clamp or a graceful fallback to non-FP8 path.

Example:

-        assert input_amax != 0
+        if input_amax == 0:
+            return quant_module.linear(input, bias)  # or set scale to a safe epsilon
-        assert weight_amax != 0
+        if weight_amax == 0:
+            return quant_module.linear(input, bias)

If fallback here is undesirable, clamp:

input_amax = torch.clamp(input_amax, min=torch.finfo(torch.float32).eps)
weight_amax = torch.clamp(weight_amax, min=torch.finfo(torch.float32).eps)

Also applies to: 71-75


160-189: Add a unit test for the new grad_weight path with >2D inputs.

Recommend a test comparing to einsum:

  • Input shape [B, T, Ii], weight [Oo, Ii]; check dW ≈ einsum('bto,bti->oi', dY, X).

I can draft a minimal pytest covering float16/float32, channels-last, and tensor-parallel shard if helpful.
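A minimal sketch of such a comparison, using plain torch tensors and illustrative shapes rather than the project's fixtures:

import torch

def test_grad_weight_matches_einsum_reference():
    B, T, I, O = 2, 3, 16, 8                       # illustrative sizes
    x = torch.randn(B, T, I)                       # input
    dy = torch.randn(B, T, O)                      # upstream gradient
    dw = dy.reshape(-1, O).T @ x.reshape(-1, I)    # flattened-matmul form from the fix
    dw_ref = torch.einsum("bto,bti->oi", dy, x)    # reference contraction over leading dims
    torch.testing.assert_close(dw, dw_ref)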

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 760c583 and b4cb2d5.

📒 Files selected for processing (1)
  • modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/torch/quantization/backends/fp8_per_tensor_gemm.py (1)

178-180: Gradient generalization for arbitrary leading dims is correct.

Shape math checks out: (B, Oo)ᵀ @ (B, Ii) → (Oo, Ii). This fixes multi-dim batches.
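A quick standalone sanity check of that shape claim (plain torch, illustrative sizes):

import torch
B, Oo, Ii = 6, 8, 16
assert (torch.randn(B, Oo).T @ torch.randn(B, Ii)).shape == (Oo, Ii)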

Signed-off-by: Ye Yu <[email protected]>