ONNX: Fix FP8 quantization for the second MLP in LayerNormMLP #2577
base: main
Conversation
Greptile Overview
Greptile Summary
Fixed a critical bug in the ONNX export path for LayerNormMLP where incorrect tuple unpacking prevented FP8 quantization from being applied to the second MLP layer (fc2). The _get_quantizers() method returns 12 quantizers (6 for fc1 and 6 for fc2), but the old code was only unpacking 5 values, skipping fc1_output_quantizer. This caused all subsequent quantizers to be misaligned, meaning fc2_input_quantizer and fc2_weight_quantizer received incorrect values, preventing the Q/DQ pairs from being created for fc2 in TensorRT.
Key changes:
- Correctly unpacks all 12 quantizers from _get_quantizers(), explicitly extracting fc1_output_quantizer and using proper placeholder unpacking with _ for unused gradient quantizers (see the sketch below)
- Renames the final variable from output_quantizer to fc2_output_quantizer for consistency with the codebase naming conventions
- The fix ensures both fc1 and fc2 layers can now run in FP8 precision during ONNX/TensorRT inference
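To make the position mapping concrete, here is a small self-contained sketch with the quantizers mocked as strings. It follows the 12-value ordering described in this summary and in the sequence diagram below, not the actual module source.

```python
# Illustration of the corrected unpacking; quantizers are mocked as strings,
# in the order described above (fc1 then fc2, each as input, weight, output,
# grad_input, grad_weight, grad_output).
quantizers = [
    "fc1_input", "fc1_weight", "fc1_output",
    "fc1_grad_input", "fc1_grad_weight", "fc1_grad_output",
    "fc2_input", "fc2_weight", "fc2_output",
    "fc2_grad_input", "fc2_grad_weight", "fc2_grad_output",
]

# Fixed unpacking: every position is accounted for, with `_` placeholders
# standing in for the gradient quantizers that ONNX export does not need.
(
    fc1_input_quantizer, fc1_weight_quantizer, fc1_output_quantizer, _, _, _,
    fc2_input_quantizer, fc2_weight_quantizer, fc2_output_quantizer, _, _, _,
) = quantizers

assert fc2_input_quantizer == "fc2_input"
assert fc2_weight_quantizer == "fc2_weight"
```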
Confidence Score: 5/5
- This PR is safe to merge - it fixes a clear bug with a straightforward, well-understood solution
- The fix addresses a definite bug in tuple unpacking that caused quantizers to be misaligned. The change is minimal, focused, and aligns the ONNX export code with the correct return signature of _get_quantizers(). The fix has been tested by the author on an L4 GPU with the DelayedScaling recipe, confirming both MLPs now run in FP8 precision.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/layernorm_mlp.py | 5/5 | Fixed critical tuple unpacking bug that prevented FP8 quantization for fc2 layer in ONNX export by correctly extracting all 12 quantizers from _get_quantizers() |
Sequence Diagram
sequenceDiagram
participant Client
participant LayerNormMLP
participant GetQuantizers as _get_quantizers()
participant FC1 as FC1 Layer (GEMM1)
participant Activation
participant FC2 as FC2 Layer (GEMM2)
Client->>LayerNormMLP: onnx_forward(inp, is_grad_enabled)
LayerNormMLP->>GetQuantizers: Request quantizers(False, is_grad_enabled)
Note over GetQuantizers: Returns 12 quantizers:<br/>fc1: input, weight, output, grad_input, grad_weight, grad_output<br/>fc2: input, weight, output, grad_input, grad_weight, grad_output
GetQuantizers-->>LayerNormMLP: (fc1_input_quantizer, fc1_weight_quantizer,<br/>fc1_output_quantizer, _, _, _,<br/>fc2_input_quantizer, fc2_weight_quantizer,<br/>fc2_output_quantizer, _, _, _)
Note over LayerNormMLP: Before fix: Skipped fc1_output_quantizer<br/>causing misalignment of all subsequent quantizers
LayerNormMLP->>FC1: Apply LayerNorm + FC1<br/>with fc1_input_quantizer & fc1_weight_quantizer
FC1-->>LayerNormMLP: fc1_out
LayerNormMLP->>Activation: Apply activation (fp32)
Activation-->>LayerNormMLP: act_out
Note over LayerNormMLP: After fix: fc2_input_quantizer &<br/>fc2_weight_quantizer now correctly aligned
LayerNormMLP->>FC2: Apply FC2 with Q/DQ pairs<br/>fc2_input_quantizer (act_out)<br/>fc2_weight_quantizer (fc2_weight)
FC2-->>LayerNormMLP: fc2_out
LayerNormMLP->>LayerNormMLP: Check fc2_output_quantizer<br/>(raises error if set)
LayerNormMLP-->>Client: Return fc2_out
ksivaman left a comment
LGTM
Thanks @victoroliv2! Could you sign off your commits (guide here)?
Force-pushed from bb69e9a to c3e6be4
Greptile Overview
Greptile Summary
This PR fixes a critical bug in LayerNormMLP's ONNX export where the second MLP layer was not receiving proper FP8 quantization. The issue was caused by incorrect unpacking of the 12 quantizers returned by _get_quantizers(), which caused quantizer misalignment and prevented Q/DQ pairs from being created for fc2.
Key Changes:
- Correctly unpacks all 12 quantizers (previously only 5 were unpacked, with *_ catching the rest)
- Adds explicit capturing of fc1_output_quantizer (was skipped before)
- Properly separates each quantizer with individual _ placeholders for unused grad quantizers
- Changes output_quantizer to fc2_output_quantizer for clarity and correctness
Impact:
This fix enables FP8 precision for both MLPs in TensorRT, allowing proper FP8 tactics exploration instead of falling back to fp32.
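A quick way to sanity-check this impact on an exported model is to count the quantize/dequantize nodes in the graph. The sketch below is an assumption-laden illustration: the file name is hypothetical, and the exact op names vary with the Transformer Engine version and opset (standard QuantizeLinear/DequantizeLinear or TRT FP8 custom ops).

```python
# Hedged sanity check: count quantize/dequantize nodes in an exported model.
# "layernorm_mlp_fp8.onnx" is a hypothetical path; op names may differ by
# Transformer Engine version (e.g. TRT FP8 custom ops vs. standard Q/DQ).
import onnx
from collections import Counter

model = onnx.load("layernorm_mlp_fp8.onnx")
op_counts = Counter(node.op_type for node in model.graph.node)
qdq = {op: n for op, n in op_counts.items() if "uantizeLinear" in op}
# With the fix, Q/DQ pairs should cover the inputs and weights of both fc1
# and fc2, not just the first GEMM.
print(qdq)
```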
Confidence Score: 5/5
- This PR is safe to merge with minimal risk - it fixes a clear bug with a straightforward correction
- The fix correctly aligns the unpacking with the return values from _get_quantizers(). The change is well-contained, only affecting ONNX export logic, and has existing test coverage. The bug was clearly identified (quantizer misalignment preventing FP8 for fc2), and the solution directly addresses it by properly unpacking all 12 quantizers.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/layernorm_mlp.py | 5/5 | Fixed critical FP8 quantization bug in ONNX export by correctly unpacking all 12 quantizers from _get_quantizers, enabling FP8 precision for fc2 weight |
Sequence Diagram
sequenceDiagram
participant Client
participant onnx_forward
participant _get_quantizers
participant onnx_layernorm
participant onnx_gemm as onnx_gemm (fc1)
participant activation
participant onnx_gemm2 as onnx_gemm (fc2)
Client->>onnx_forward: inp (input tensor)
onnx_forward->>_get_quantizers: get quantizers (fp8_output=False)
_get_quantizers-->>onnx_forward: 12 quantizers (fc1/fc2 input/weight/output + grads)
Note over onnx_forward: Unpack quantizers correctly<br/>(fixed in this PR)
onnx_forward->>onnx_layernorm: LayerNorm + FP8 cast (fc1_input_quantizer)
onnx_layernorm-->>onnx_forward: ln_out (quantized)
alt fc1_weight_quantizer is not None
onnx_forward->>onnx_forward: Q/DQ fc1_weight (FP8)
end
onnx_forward->>onnx_gemm: fc1_weight @ ln_out
onnx_gemm-->>onnx_forward: fc1_out
onnx_forward->>activation: apply activation (fp32)
activation-->>onnx_forward: act_out
alt fc2_weight_quantizer is not None
onnx_forward->>onnx_forward: Q/DQ fc2_weight (FP8)
Note right of onnx_forward: This now works correctly!
end
alt fc2_input_quantizer is not None
onnx_forward->>onnx_forward: Q/DQ act_out (FP8)
Note right of onnx_forward: This now works correctly!
end
onnx_forward->>onnx_gemm2: fc2_weight @ act_out
onnx_gemm2-->>onnx_forward: fc2_out
onnx_forward-->>Client: fc2_out (+ optional outputs)
@ksivaman signed!
Greptile Overview
Greptile Summary
This PR fixes a critical bug in the ONNX export of LayerNormMLP where the second MLP layer (FC2) was not receiving proper FP8 quantization. The root cause was incorrect unpacking of the 12-element tuple returned by _get_quantizers().
The Bug:
- The old code only unpacked 5 values from a 12-value tuple (6 quantizers for fc1, 6 for fc2)
- fc2_input_quantizer was incorrectly assigned to position 3 (actually fc1_output_quantizer)
- fc2_weight_quantizer was incorrectly assigned to position 4 (actually fc1_grad_input_quantizer)
- The actual fc2_input_quantizer (position 7), fc2_weight_quantizer (position 8), and fc2_output_quantizer (position 9) were never retrieved
- This prevented TensorRT from exploring FP8 tactics for the second MLP (see the reconstruction sketched below)
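The misalignment can be reproduced in isolation. The sketch below mocks the quantizers as strings and paraphrases the pre-fix unpacking from this summary; the exact original line in the repository may differ.

```python
# Mock of the 12-value return, in the order described in this review.
quantizers = [
    "fc1_input", "fc1_weight", "fc1_output",
    "fc1_grad_input", "fc1_grad_weight", "fc1_grad_output",
    "fc2_input", "fc2_weight", "fc2_output",
    "fc2_grad_input", "fc2_grad_weight", "fc2_grad_output",
]

# Pre-fix style (reconstructed): only five names are bound, so everything
# after fc1_weight shifts and the real fc2 quantizers are swallowed by *_.
(
    fc1_input_quantizer,   # gets fc1_input  (correct)
    fc1_weight_quantizer,  # gets fc1_weight (correct)
    fc2_input_quantizer,   # gets fc1_output      -> wrong quantizer
    fc2_weight_quantizer,  # gets fc1_grad_input  -> wrong quantizer
    output_quantizer,      # gets fc1_grad_weight -> wrong quantizer
    *_,                    # the real fc2 input/weight/output are discarded here
) = quantizers

assert fc2_input_quantizer == "fc1_output"       # misaligned
assert fc2_weight_quantizer == "fc1_grad_input"  # misaligned
```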
The Fix:
- Now correctly unpacks all 12 values with explicit placeholders (_) for unused gradient quantizers
- Properly retrieves fc2_input_quantizer, fc2_weight_quantizer, and fc2_output_quantizer from their correct positions
- Enables Q/DQ (quantize/dequantize) pairs for both FC1 and FC2 layers
- Successfully enables FP8 precision in TensorRT for both MLPs (verified by the author on an L4 GPU)
Confidence Score: 5/5
- This PR is safe to merge - it fixes a clear bug with a straightforward solution
- The fix correctly addresses the tuple unpacking mismatch. The change is minimal, well-scoped, and the logic is sound. The author has verified that FP8 tactics now work correctly on an L4 GPU with the DelayedScaling recipe.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/layernorm_mlp.py | 5/5 | Fixed critical bug in ONNX export where FC2 quantizers were unpacked from wrong positions, now correctly retrieves fc2_input_quantizer, fc2_weight_quantizer, and fc2_output_quantizer |
Sequence Diagram
sequenceDiagram
participant ONNX as onnx_forward
participant GQ as _get_quantizers
participant LN as LayerNorm
participant FC1 as FC1 (First MLP)
participant ACT as Activation
participant FC2 as FC2 (Second MLP)
ONNX->>GQ: Call _get_quantizers(False, is_grad_enabled)
GQ-->>ONNX: Return 12 quantizers (6 for fc1, 6 for fc2)
Note over ONNX: OLD: Only unpacked 5 values<br/>fc2 quantizers at wrong positions
Note over ONNX: NEW: Correctly unpack all 12 values<br/>fc2 quantizers at correct positions
ONNX->>LN: onnx_layernorm with fc1_input_quantizer
LN-->>ONNX: ln_out (FP8 quantized)
ONNX->>FC1: Quantize fc1_weight, compute GEMM
FC1-->>ONNX: fc1_out (convert to FP32)
ONNX->>ACT: Apply activation function (FP32)
ACT-->>ONNX: act_out (FP32)
Note over ONNX,FC2: BUG FIX: Now uses correct fc2 quantizers
ONNX->>FC2: Quantize fc2_weight (fc2_weight_quantizer)
ONNX->>FC2: Quantize act_out (fc2_input_quantizer)
ONNX->>FC2: Compute GEMM
FC2-->>ONNX: fc2_out (enables FP8 tactics in TensorRT)
/te-ci pytorch |
Hi @victoroliv2. Overall LGTM. Could you fix the linter error below?
Signed-off-by: Victor Oliveira <[email protected]>
Force-pushed from 054cfd1 to 2bf2ab5
Greptile Overview
Greptile Summary
Fixes incorrect tuple unpacking in onnx_forward method that was preventing FP8 quantization of the second MLP layer. The old code incorrectly assigned quantizers by skipping intermediate values, causing fc2_input_quantizer and fc2_weight_quantizer to receive wrong values from _get_quantizers(). The fix explicitly unpacks all 12 returned values with proper placeholder usage, ensuring fc2 quantizers are correctly assigned and Q/DQ pairs are created for both MLPs in TensorRT.
Confidence Score: 4/5
- Safe to merge - fixes critical bug in FP8 quantization without introducing new risks
- The fix correctly addresses tuple unpacking to match the 12-value return from _get_quantizers(). The old code was incorrectly mapping quantizers due to missing placeholder unpacking, causing fc2 to receive wrong quantizer objects. The new implementation explicitly unpacks all positions with proper underscore placeholders, ensuring fc2_input_quantizer (position 6), fc2_weight_quantizer (position 7), and fc2_output_quantizer (position 8) are correctly assigned. The fix is minimal, targeted, and the logic is straightforward. No new edge cases or error conditions are introduced.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/module/layernorm_mlp.py | 4/5 | Fixes tuple unpacking to correctly extract fc2 quantizers from _get_quantizers() return value, enabling FP8 quantization for second MLP |
Sequence Diagram
sequenceDiagram
participant ONNXFwd as onnx_forward
participant GetQ as _get_quantizers
participant FC1 as First MLP (fc1)
participant Act as Activation
participant FC2 as Second MLP (fc2)
ONNXFwd->>GetQ: Call _get_quantizers(False, is_grad_enabled)
GetQ-->>ONNXFwd: Return 12 quantizers (fc1: 0-5, fc2: 6-11)
Note over ONNXFwd: OLD: Unpacked positions 0,1,2,3,4<br/>NEW: Unpacks 0,1,_,_,_,_,6,7,8,_,_,_
ONNXFwd->>FC1: Quantize fc1_weight (Q/DQ)
ONNXFwd->>FC1: Apply LayerNorm + fc1_input_quantizer
FC1->>Act: Forward to activation (fp32)
Note over ONNXFwd,FC2: FIX: Now fc2 quantizers<br/>correctly assigned
ONNXFwd->>FC2: Quantize fc2_weight (Q/DQ)
ONNXFwd->>FC2: Quantize act_out via fc2_input_quantizer (Q/DQ)
FC2-->>ONNXFwd: Return fc2_out (FP8 enabled)
/te-ci pytorch |
Description
Currently, the Q/DQ pair is not created for the second MLP in LayerNormMLP. The end result is that only fp32 tactics are explored in TensorRT.
This fixes the bug; I am now able to get both MLPs running in FP8 precision (tested with the DelayedScaling recipe on an L4 GPU).
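For reference, here is a minimal export sketch along the lines of what the description says was tested. It is a hedged illustration rather than the PR's own test script: the te.onnx_export context manager, the warm-up pass, and the export arguments are assumptions based on typical Transformer Engine usage and may need adjusting for the installed version.

```python
# Hedged sketch: export a LayerNormMLP with FP8 (DelayedScaling) to ONNX.
# te.onnx_export, the warm-up pass, and the export call are assumptions from
# typical Transformer Engine usage, not taken from this PR.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

model = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096).cuda().eval()
inp = torch.randn(16, 1024, device="cuda")
recipe = DelayedScaling()

# Warm-up forward pass so the delayed-scaling FP8 scales are populated
# before export (assumed to be needed for meaningful Q/DQ scales).
with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    model(inp)

# Export with FP8 Q/DQ nodes; with this fix both fc1 and fc2 get Q/DQ pairs.
with torch.inference_mode(), te.onnx_export(True), te.fp8_autocast(
    enabled=True, fp8_recipe=recipe
):
    torch.onnx.export(model, (inp,), "layernorm_mlp_fp8.onnx")
```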
Fixes # (issue)
Type of change
Changes
Small fix to LayerNormMLP ONNX export.
Checklist: