
Conversation


@victoroliv2 victoroliv2 commented Jan 8, 2026

Description

Currently the Q/DQ pair is not created for the second MLP in LayerNormMLP. The end result is that only FP32 tactics are explored in TensorRT.

This fixes the bug, and I'm now able to get both MLPs running in FP8 precision (tested with the DelayedScaling recipe on an L4 GPU).
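
As a quick sanity check (not part of this change), the exported model can be inspected to confirm that Q/DQ nodes now cover both GEMMs. A minimal sketch, assuming a placeholder file path; depending on the TE version the ops may be standard or TE-specific QuantizeLinear variants:

    # Hedged sketch: count Q/DQ ops in an already-exported model.
    # "layernorm_mlp.onnx" is a placeholder path for your own export.
    from collections import Counter
    import onnx

    model = onnx.load("layernorm_mlp.onnx")
    qdq_ops = Counter(
        node.op_type
        for node in model.graph.node
        if "QuantizeLinear" in node.op_type  # also matches DequantizeLinear
    )
    # With the fix, inputs and weights of both fc1 and fc2 should be covered,
    # rather than fc1 alone.
    print(qdq_ops)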

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Small fix to LayerNormMLP ONNX export.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Fixed a critical bug in the ONNX export path for LayerNormMLP where incorrect tuple unpacking prevented FP8 quantization from being applied to the second MLP layer (fc2). The _get_quantizers() method returns 12 quantizers (6 for fc1 and 6 for fc2), but the old code was only unpacking 5 values, skipping fc1_output_quantizer. This caused all subsequent quantizers to be misaligned, meaning fc2_input_quantizer and fc2_weight_quantizer received incorrect values, preventing the Q/DQ pairs from being created for fc2 in TensorRT.
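
A minimal sketch of the mismatch described above (not the exact code in layernorm_mlp.py; the 12-element ordering and the variable names are assumed from this summary, and the helper name is hypothetical):

    def _unpack_onnx_quantizers(quantizers):
        """Hypothetical helper illustrating the corrected 12-way unpacking."""
        # Old, buggy shape (roughly): five names plus a catch-all, so the fc2_*
        # names actually received fc1's output and gradient quantizers:
        #   (fc1_input, fc1_weight, fc2_input, fc2_weight, output, *_) = quantizers
        (
            fc1_input_quantizer,
            fc1_weight_quantizer,
            fc1_output_quantizer,
            _, _, _,  # fc1 gradient quantizers, unused during ONNX export
            fc2_input_quantizer,
            fc2_weight_quantizer,
            fc2_output_quantizer,
            _, _, _,  # fc2 gradient quantizers, unused during ONNX export
        ) = quantizers
        return (
            fc1_input_quantizer, fc1_weight_quantizer, fc1_output_quantizer,
            fc2_input_quantizer, fc2_weight_quantizer, fc2_output_quantizer,
        )

With a 12-element tuple from _get_quantizers(False, is_grad_enabled), the fc2 entries now land on the intended slots instead of fc1's gradient slots.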

Key changes:

  • Correctly unpacks all 12 quantizers from _get_quantizers(), explicitly extracting fc1_output_quantizer and using proper placeholder unpacking with _ for unused gradient quantizers
  • Renames the final variable from output_quantizer to fc2_output_quantizer for consistency with the codebase naming conventions
  • The fix ensures both fc1 and fc2 layers can now run in FP8 precision during ONNX/TensorRT inference

Confidence Score: 5/5

  • This PR is safe to merge - it fixes a clear bug with a straightforward, well-understood solution
  • The fix addresses a definite bug in tuple unpacking that caused quantizers to be misaligned. The change is minimal, focused, and aligns the ONNX export code with the correct return signature of _get_quantizers(). The fix has been tested by the author on L4 GPU with DelayedScaling recipe, confirming both MLPs now run in FP8 precision.
  • No files require special attention

Important Files Changed

File Analysis

Filename: transformer_engine/pytorch/module/layernorm_mlp.py
Score: 5/5
Overview: Fixed critical tuple unpacking bug that prevented FP8 quantization for fc2 layer in ONNX export by correctly extracting all 12 quantizers from _get_quantizers()

Sequence Diagram

sequenceDiagram
    participant Client
    participant LayerNormMLP
    participant GetQuantizers as _get_quantizers()
    participant FC1 as FC1 Layer (GEMM1)
    participant Activation
    participant FC2 as FC2 Layer (GEMM2)
    
    Client->>LayerNormMLP: onnx_forward(inp, is_grad_enabled)
    LayerNormMLP->>GetQuantizers: Request quantizers(False, is_grad_enabled)
    
    Note over GetQuantizers: Returns 12 quantizers:<br/>fc1: input, weight, output, grad_input, grad_weight, grad_output<br/>fc2: input, weight, output, grad_input, grad_weight, grad_output
    
    GetQuantizers-->>LayerNormMLP: (fc1_input_quantizer, fc1_weight_quantizer,<br/>fc1_output_quantizer, _, _, _,<br/>fc2_input_quantizer, fc2_weight_quantizer,<br/>fc2_output_quantizer, _, _, _)
    
    Note over LayerNormMLP: Before fix: Skipped fc1_output_quantizer<br/>causing misalignment of all subsequent quantizers
    
    LayerNormMLP->>FC1: Apply LayerNorm + FC1<br/>with fc1_input_quantizer & fc1_weight_quantizer
    FC1-->>LayerNormMLP: fc1_out
    
    LayerNormMLP->>Activation: Apply activation (fp32)
    Activation-->>LayerNormMLP: act_out
    
    Note over LayerNormMLP: After fix: fc2_input_quantizer &<br/>fc2_weight_quantizer now correctly aligned
    
    LayerNormMLP->>FC2: Apply FC2 with Q/DQ pairs<br/>fc2_input_quantizer (act_out)<br/>fc2_weight_quantizer (fc2_weight)
    FC2-->>LayerNormMLP: fc2_out
    
    LayerNormMLP->>LayerNormMLP: Check fc2_output_quantizer<br/>(raises error if set)
    
    LayerNormMLP-->>Client: Return fc2_out

ksivaman previously approved these changes Jan 8, 2026

Member

@ksivaman ksivaman left a comment

LGTM

Member

ksivaman commented Jan 8, 2026

Thanks @victoroliv2! Could you sign off your commits (guide here)?

@victoroliv2 victoroliv2 force-pushed the fix-layernorm-mlp-fp8 branch from bb69e9a to c3e6be4 on January 8, 2026 at 20:58
Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR fixes a critical bug in LayerNormMLP's ONNX export where the second MLP layer was not receiving proper FP8 quantization. The issue was caused by incorrect unpacking of the 12 quantizers returned by _get_quantizers(), which caused quantizer misalignment and prevented Q/DQ pairs from being created for fc2.

Key Changes:

  • Correctly unpacks all 12 quantizers (previously only unpacked 5 with *_ catching the rest)
  • Adds explicit capturing of fc1_output_quantizer (was skipped before)
  • Properly separates each quantizer with individual _ placeholders for unused grad quantizers
  • Changes output_quantizer to fc2_output_quantizer for clarity and correctness

Impact:
This fix enables FP8 precision for both MLPs in TensorRT, allowing proper FP8 tactics exploration instead of falling back to fp32.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk - it fixes a clear bug with a straightforward correction
  • The fix correctly aligns the unpacking with the return values from _get_quantizers(). The change is well-contained, only affecting ONNX export logic, and has existing test coverage. The bug was clearly identified (quantizer misalignment preventing FP8 for fc2), and the solution directly addresses it by properly unpacking all 12 quantizers.
  • No files require special attention

Important Files Changed

File Analysis

Filename: transformer_engine/pytorch/module/layernorm_mlp.py
Score: 5/5
Overview: Fixed critical FP8 quantization bug in ONNX export by correctly unpacking all 12 quantizers from _get_quantizers, enabling FP8 precision for fc2 weight

Sequence Diagram

sequenceDiagram
    participant Client
    participant onnx_forward
    participant _get_quantizers
    participant onnx_layernorm
    participant onnx_gemm as onnx_gemm (fc1)
    participant activation
    participant onnx_gemm2 as onnx_gemm (fc2)

    Client->>onnx_forward: inp (input tensor)
    onnx_forward->>_get_quantizers: get quantizers (fp8_output=False)
    _get_quantizers-->>onnx_forward: 12 quantizers (fc1/fc2 input/weight/output + grads)
    
    Note over onnx_forward: Unpack quantizers correctly<br/>(fixed in this PR)
    
    onnx_forward->>onnx_layernorm: LayerNorm + FP8 cast (fc1_input_quantizer)
    onnx_layernorm-->>onnx_forward: ln_out (quantized)
    
    alt fc1_weight_quantizer is not None
        onnx_forward->>onnx_forward: Q/DQ fc1_weight (FP8)
    end
    
    onnx_forward->>onnx_gemm: fc1_weight @ ln_out
    onnx_gemm-->>onnx_forward: fc1_out
    
    onnx_forward->>activation: apply activation (fp32)
    activation-->>onnx_forward: act_out
    
    alt fc2_weight_quantizer is not None
        onnx_forward->>onnx_forward: Q/DQ fc2_weight (FP8)
        Note right of onnx_forward: This now works correctly!
    end
    
    alt fc2_input_quantizer is not None
        onnx_forward->>onnx_forward: Q/DQ act_out (FP8)
        Note right of onnx_forward: This now works correctly!
    end
    
    onnx_forward->>onnx_gemm2: fc2_weight @ act_out
    onnx_gemm2-->>onnx_forward: fc2_out
    
    onnx_forward-->>Client: fc2_out (+ optional outputs)

@victoroliv2
Author

@ksivaman signed!

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR fixes a critical bug in the ONNX export of LayerNormMLP where the second MLP layer (FC2) was not receiving proper FP8 quantization. The root cause was incorrect unpacking of the 12-element tuple returned by _get_quantizers().

The Bug:

  • The old code only unpacked 5 values from a 12-value tuple (6 quantizers for fc1, 6 for fc2)
  • fc2_input_quantizer was incorrectly assigned to position 3 (actually fc1_output_quantizer)
  • fc2_weight_quantizer was incorrectly assigned to position 4 (actually fc1_grad_input_quantizer)
  • The actual fc2_input_quantizer (position 7), fc2_weight_quantizer (position 8), and fc2_output_quantizer (position 9) were never retrieved
  • This prevented TensorRT from exploring FP8 tactics for the second MLP

The Fix:

  • Now correctly unpacks all 12 values with explicit placeholders (_) for unused gradient quantizers
  • Properly retrieves fc2_input_quantizer, fc2_weight_quantizer, and fc2_output_quantizer from their correct positions
  • Enables Q/DQ (quantize/dequantize) pairs for both FC1 and FC2 layers
  • Successfully enables FP8 precision in TensorRT for both MLPs (verified by the author on an L4 GPU); a regression-test sketch for this contract follows below
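
A regression-test sketch for the 12-quantizer contract that the ONNX export path relies on (not part of this PR; assumes a CUDA-capable GPU, arbitrary layer sizes, and that _get_quantizers returns its full 12-element tuple even outside fp8_autocast):

    import transformer_engine.pytorch as te

    def test_layernorm_mlp_returns_twelve_quantizers():
        # Guards the ordering assumed by the ONNX export path:
        # 6 quantizers for fc1 followed by 6 for fc2.
        module = te.LayerNormMLP(hidden_size=128, ffn_hidden_size=512)
        quantizers = module._get_quantizers(False, False)  # fp8_output=False, is_grad_enabled=False
        assert len(quantizers) == 12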

Confidence Score: 5/5

  • This PR is safe to merge - it fixes a clear bug with a straightforward solution
  • The fix correctly addresses the tuple unpacking mismatch. The change is minimal, well-scoped, and the logic is sound. The author has verified that FP8 tactics now work correctly on L4 GPU with DelayedScaling recipe.
  • No files require special attention

Important Files Changed

File Analysis

Filename: transformer_engine/pytorch/module/layernorm_mlp.py
Score: 5/5
Overview: Fixed critical bug in ONNX export where FC2 quantizers were unpacked from wrong positions, now correctly retrieves fc2_input_quantizer, fc2_weight_quantizer, and fc2_output_quantizer

Sequence Diagram

sequenceDiagram
    participant ONNX as onnx_forward
    participant GQ as _get_quantizers
    participant LN as LayerNorm
    participant FC1 as FC1 (First MLP)
    participant ACT as Activation
    participant FC2 as FC2 (Second MLP)
    
    ONNX->>GQ: Call _get_quantizers(False, is_grad_enabled)
    GQ-->>ONNX: Return 12 quantizers (6 for fc1, 6 for fc2)
    Note over ONNX: OLD: Only unpacked 5 values<br/>fc2 quantizers at wrong positions
    Note over ONNX: NEW: Correctly unpack all 12 values<br/>fc2 quantizers at correct positions
    
    ONNX->>LN: onnx_layernorm with fc1_input_quantizer
    LN-->>ONNX: ln_out (FP8 quantized)
    
    ONNX->>FC1: Quantize fc1_weight, compute GEMM
    FC1-->>ONNX: fc1_out (convert to FP32)
    
    ONNX->>ACT: Apply activation function (FP32)
    ACT-->>ONNX: act_out (FP32)
    
    Note over ONNX,FC2: BUG FIX: Now uses correct fc2 quantizers
    ONNX->>FC2: Quantize fc2_weight (fc2_weight_quantizer)
    ONNX->>FC2: Quantize act_out (fc2_input_quantizer)
    ONNX->>FC2: Compute GEMM
    FC2-->>ONNX: fc2_out (enables FP8 tactics in TensorRT)

Member

ptrendx commented Jan 8, 2026

/te-ci pytorch

Member

ptrendx commented Jan 8, 2026

Hi @victoroliv2. Overall LGTM. Could you fix the linter error below:

transformer_engine/pytorch/module/layernorm_mlp.py:2251:12: W0612: Unused variable 'fc1_output_quantizer' (unused-variable)

?
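
One lightweight way to clear W0612 is to stop binding a name for the unused slot (a sketch of the in-method fragment, so self and is_grad_enabled come from the enclosing onnx_forward scope; an inline "# pylint: disable=unused-variable" on that line would work as well):

    # Sketch: same 12-way unpacking, with the fc1 output slot bound to the
    # throwaway "_" so pylint no longer reports an unused variable.
    (
        fc1_input_quantizer,
        fc1_weight_quantizer,
        _,  # fc1 output quantizer, not used on the ONNX export path
        _, _, _,
        fc2_input_quantizer,
        fc2_weight_quantizer,
        fc2_output_quantizer,
        _, _, _,
    ) = self._get_quantizers(False, is_grad_enabled)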

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Fixes incorrect tuple unpacking in onnx_forward method that was preventing FP8 quantization of the second MLP layer. The old code incorrectly assigned quantizers by skipping intermediate values, causing fc2_input_quantizer and fc2_weight_quantizer to receive wrong values from _get_quantizers(). The fix explicitly unpacks all 12 returned values with proper placeholder usage, ensuring fc2 quantizers are correctly assigned and Q/DQ pairs are created for both MLPs in TensorRT.

Confidence Score: 4/5

  • Safe to merge - fixes critical bug in FP8 quantization without introducing new risks
  • The fix correctly addresses tuple unpacking to match the 12-value return from _get_quantizers(). The old code was incorrectly mapping quantizers due to missing placeholder unpacking, causing fc2 to receive wrong quantizer objects. The new implementation explicitly unpacks all positions with proper underscore placeholders, ensuring fc2_input_quantizer (position 6), fc2_weight_quantizer (position 7), and fc2_output_quantizer (position 8) are correctly assigned. The fix is minimal, targeted, and the logic is straightforward. No new edge cases or error conditions are introduced.
  • No files require special attention

Important Files Changed

File Analysis

Filename: transformer_engine/pytorch/module/layernorm_mlp.py
Score: 4/5
Overview: Fixes tuple unpacking to correctly extract fc2 quantizers from _get_quantizers() return value, enabling FP8 quantization for second MLP

Sequence Diagram

sequenceDiagram
    participant ONNXFwd as onnx_forward
    participant GetQ as _get_quantizers
    participant FC1 as First MLP (fc1)
    participant Act as Activation
    participant FC2 as Second MLP (fc2)
    
    ONNXFwd->>GetQ: Call _get_quantizers(False, is_grad_enabled)
    GetQ-->>ONNXFwd: Return 12 quantizers (fc1: 0-5, fc2: 6-11)
    
    Note over ONNXFwd: OLD: Unpacked positions 0,1,2,3,4<br/>NEW: Unpacks 0,1,_,_,_,_,6,7,8,_,_,_
    
    ONNXFwd->>FC1: Quantize fc1_weight (Q/DQ)
    ONNXFwd->>FC1: Apply LayerNorm + fc1_input_quantizer
    FC1->>Act: Forward to activation (fp32)
    
    Note over ONNXFwd,FC2: FIX: Now fc2 quantizers<br/>correctly assigned
    
    ONNXFwd->>FC2: Quantize fc2_weight (Q/DQ)
    ONNXFwd->>FC2: Quantize act_out via fc2_input_quantizer (Q/DQ)
    FC2-->>ONNXFwd: Return fc2_out (FP8 enabled)

@victoroliv2
Author

/te-ci pytorch
