[#9753][feat] AutoDeploy: Implement add rms_norm fusion #9754

Merged
nvchenghaoz merged 3 commits into NVIDIA:main from nv-auto-deploy:chenghao/rms_add_1205
Dec 8, 2025

Conversation


@nvchenghaoz nvchenghaoz commented Dec 5, 2025

Add a transform that fuses add + rms_norm using the FlashInfer API.

The transform only matches patterns whose RMSNorm is the FlashInfer rms_norm op.

Q: Do we want to fuse add + Triton_rms_norm into flashinfer_fused_add_rms_norm, or implement a fused Triton op instead?

#9753
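For reference, the math of the fused op can be stated dependency-free. The sketch below models the semantics on plain Python lists (the function name is illustrative; the actual op dispatches to FlashInfer's CUDA kernel and mutates its inputs in place):

```python
import math

def fused_add_rms_norm_ref(x, residual, weight, eps=1e-6):
    """Reference semantics of fused add + RMSNorm on plain Python lists.

    Illustrative only -- the PR's op calls FlashInfer's CUDA kernel.
    Returns (normalized, summed): the sum x + residual goes to the
    residual slot, and RMSNorm of that sum goes to the x slot.
    """
    summed = [xi + ri for xi, ri in zip(x, residual)]
    # Root-mean-square over the hidden dimension, stabilized by eps
    rms = math.sqrt(sum(v * v for v in summed) / len(summed) + eps)
    normalized = [v / rms * w for v, w in zip(summed, weight)]
    return normalized, summed
```

The returned pair mirrors the op's convention: the normalized result replaces x and the pre-norm sum replaces residual, so the sum can feed the next layer's residual connection.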

Summary by CodeRabbit

  • New Features

    • Introduced a fused add + RMSNorm operation leveraging FlashInfer for improved performance.
    • Switched RMSNorm backend from Triton to FlashInfer for better optimization.
  • Tests

    • Added unit tests for the new fused add + RMSNorm operation.
    • Added integration tests validating the fusion transformation.


Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz nvchenghaoz requested a review from a team as a code owner December 5, 2025 19:01
@nvchenghaoz (Collaborator Author)

/bot run


coderabbitai bot commented Dec 5, 2025

📝 Walkthrough

This PR introduces support for a fused add + RMSNorm operation powered by FlashInfer. The changes include configuration updates to use FlashInfer for RMSNorm, a new custom operator implementing the fused operation, a graph transformation to detect and replace the pattern, and comprehensive unit tests for both the operator and transformation.

Changes

  • Configuration — tensorrt_llm/_torch/auto_deploy/config/default.yaml
    Updated the fuse_rmsnorm transform to use the flashinfer backend instead of triton. Added the new fuse_add_rms_norm transform at the post_load_fusion stage, enabled by default.
  • Custom operator — tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py
    Implements the fused add + RMSNorm custom operator. The in-place flashinfer_fused_add_rms_norm_inplace() flattens the tensors, calls FlashInfer's fused kernel, and reshapes back; the public wrapper flashinfer_fused_add_rms_norm() returns the modified tensors. Registered as a Torch custom op with mutating-argument semantics.
  • Transform — tensorrt_llm/_torch/auto_deploy/transform/library/fused_add_rms_norm.py
    Implements the FuseAddRMSNorm transformation class, which uses pattern matching to detect the sequence add → cast to bfloat16 → RMSNorm and replaces it with the fused custom operator.
  • Unit tests — tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_fused_add_rms_norm_op.py, tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py
    Test the custom operator against a reference RMSNorm implementation with tolerance checks, and verify that the transformation fuses the pattern in the graph while outputs match the original model.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Optimizer as InferenceOptimizer
    participant Transform as FuseAddRMSNorm
    participant Matcher as PatternMatcher
    participant Graph as GraphModule
    participant CustomOp as flashinfer_fused_add_rms_norm

    User->>Optimizer: Apply optimization
    Optimizer->>Transform: Execute transform
    Transform->>Matcher: Create pattern for add→cast→rms_norm
    Matcher->>Graph: Scan for pattern matches
    Graph-->>Matcher: Pattern found (nodes identified)
    Matcher-->>Transform: Match information
    Transform->>Graph: Replace matched nodes with custom op
    Graph->>CustomOp: Call fused_add_rms_norm
    CustomOp-->>Graph: Return result
    Transform-->>Optimizer: Return transformed GraphModule
    Optimizer-->>User: Optimized model ready
sequenceDiagram
    participant Input as Input Tensors<br/>(x, residual, weight)
    participant Wrapper as flashinfer_fused_add_rms_norm
    participant InPlace as _inplace operation
    participant Flatten as Flatten to 2D
    participant FlashInfer as flashinfer.norm<br/>.fused_add_rmsnorm
    participant Reshape as Reshape to Original

    Input->>Wrapper: (x, residual, weight, eps)
    Wrapper->>InPlace: Call in-place version
    InPlace->>Flatten: Flatten (batch×seq, hidden)
    Flatten->>FlashInfer: Call with enable_pdl flag
    FlashInfer->>FlashInfer: Compute residual = x + residual<br/>Apply RMSNorm
    FlashInfer->>Reshape: Return normalized result
    Reshape->>InPlace: Reshape to original shapes
    InPlace->>Wrapper: In-place mutation complete
    Wrapper-->>Input: Return (x, residual)
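The flatten → kernel → reshape flow in the diagram above can be sketched dependency-free. Below is a plain-Python model of the semantics only (nested lists stand in for tensors, and the helper names are illustrative; the real op mutates CUDA buffers in place via FlashInfer's kernel):

```python
import math

def rms_norm_row(row, weight, eps):
    # RMSNorm over a single hidden-dim row
    rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [v / rms * w for v, w in zip(row, weight)]

def fused_add_rms_norm_3d(x, residual, weight, eps=1e-6):
    """Mimic the flatten -> kernel -> reshape flow on nested lists.

    x, residual: shape (batch, seq, hidden) as nested lists. The real op
    views them as (batch*seq, hidden) before calling the fused kernel.
    """
    batch, seq = len(x), len(x[0])
    # "Flatten" (batch, seq, hidden) -> (batch*seq, hidden)
    x_flat = [row for sample in x for row in sample]
    r_flat = [row for sample in residual for row in sample]
    out_x, out_r = [], []
    for xr, rr in zip(x_flat, r_flat):
        summed = [a + b for a, b in zip(xr, rr)]  # residual <- x + residual
        out_r.append(summed)
        out_x.append(rms_norm_row(summed, weight, eps))  # x <- rmsnorm(sum)
    # "Reshape" back to (batch, seq, hidden)
    def reshape(flat):
        return [flat[i * seq:(i + 1) * seq] for i in range(batch)]
    return reshape(out_x), reshape(out_r)
```

Flattening to (batch*seq, hidden) lets the kernel treat every token row uniformly, regardless of the original batch/sequence layout.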

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • flashinfer_fused_add_rms_norm.py: Tensor reshaping logic, in-place semantics, and environment flag handling require careful verification for correctness and memory safety
  • fused_add_rms_norm.py: Pattern matching and graph replacement logic is intricate; verify pattern correctly identifies add→cast→rms_norm sequences and replacement routing is proper
  • Pattern matcher configuration: Dummy args, op ignore types, and scalar workaround settings need validation against actual graph patterns
  • Custom op registration: Mutating argument declarations and their interaction with the wrapper function require verification

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 23.08%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed — The title clearly and concisely describes the main feature (add RMSNorm fusion for AutoDeploy) and aligns with the changeset.
  • Description check ✅ Passed — The description explains the feature (add + rms_norm fusion via flashinfer) and its scope, but lacks detail on test coverage and does not follow the provided template.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py (1)

38-43: Dead code: view() calls have no effect.

The view() calls on lines 41-42 return new tensor views but the results are discarded. Since x_flat and residual_flat are views of x and residual, the in-place modification by flashinfer.norm.fused_add_rmsnorm already modifies the original tensors. These lines can be removed.

     flashinfer.norm.fused_add_rmsnorm(
         x_flat, residual_flat, weight, eps, enable_pdl=get_env_enable_pdl()
     )
-    x_flat.view(x_shape)
-    residual_flat.view(residual_shape)
     return
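The no-op pointed out here generalizes beyond torch: any call that returns a new object rather than mutating its receiver is inert when its result is discarded, exactly as `t.view(shape)` returns a new view without changing `t`. A plain-Python analogue of the same mistake, using `sorted` vs `list.sort`:

```python
xs = [3, 1, 2]

sorted(xs)   # returns a NEW sorted list; result discarded -> no effect on xs
assert xs == [3, 1, 2]

xs.sort()    # mutates in place, the way FlashInfer's fused kernel mutates its inputs
assert xs == [1, 2, 3]
```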
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 299601a and 57c5794.

📒 Files selected for processing (5)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_add_rms_norm.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_fused_add_rms_norm_op.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Always maintain the namespace when importing in Python, even if only one class or function from a module is used (e.g., use from package.subpackage import foo and then foo.SomeClass() instead of from package.subpackage.foo import SomeClass)
Python filenames should use snake_case (e.g., some_file.py)
Python class names should use PascalCase (e.g., class SomeClass)
Python function and method names should use snake_case (e.g., def my_awesome_function():)
Python local variable names should use snake_case, with prefix k for variable names that start with a number (e.g., k_99th_percentile = ...)
Python global variables should use upper snake_case with prefix G (e.g., G_MY_GLOBAL = ...)
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...)
Avoid shadowing variables declared in an outer scope in Python
Initialize all externally visible members of a Python class in the constructor
For Python interfaces that may be used outside a file, prefer docstrings over comments
Python comments should be reserved for code within a function, or interfaces that are local to a file
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx
Python attributes and variables can be documented inline with type and description (e.g., self.x = 5 followed by """<type>: Description of 'x'""" )
Avoid using reflection in Python when functionality can be easily achieved without reflection
When using try-except blocks in Python, limit the except clause to the smallest set of specific errors possible instead of catching all exceptions
When using try-except blocks in Python to handle multiple possible variable types (duck-typing), keep the body of the try as small as possible and use the else block to implement the logic

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_fused_add_rms_norm_op.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_add_rms_norm.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py
**/*.{cpp,h,cu,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header that includes the current year at the top

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_fused_add_rms_norm_op.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_add_rms_norm.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py
🧠 Learnings (5)
📓 Common learnings
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_fused_add_rms_norm_op.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_add_rms_norm.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py
📚 Learning: 2025-10-20T17:09:21.560Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_fused_add_rms_norm_op.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fused_add_rms_norm.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py
📚 Learning: 2025-09-09T09:40:45.658Z
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py
🧬 Code graph analysis (4)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_fused_add_rms_norm_op.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py (1)
  • flashinfer_fused_add_rms_norm (51-54)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_add_rms_norm.py (4)
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py (1)
  • flashinfer_fused_add_rms_norm (51-54)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (11-92)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (2)
  • ADPatternMatcherPass (61-67)
  • register_ad_pattern (99-182)
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (1)
  • dummy_args (153-158)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py (4)
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py (2)
  • flashinfer_fused_add_rms_norm (51-54)
  • flashinfer_fused_add_rms_norm_inplace (21-43)
tensorrt_llm/_torch/auto_deploy/export/export.py (1)
  • torch_export_to_gm (276-344)
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py (1)
  • InferenceOptimizer (23-78)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (198-221)
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py (1)
tensorrt_llm/_torch/flashinfer_utils.py (1)
  • get_env_enable_pdl (10-15)
🪛 Ruff (0.14.7)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_add_rms_norm.py

42-42: Unused method argument: cm

(ARG002)


43-43: Unused method argument: factory

(ARG002)


44-44: Unused method argument: shared_config

(ARG002)

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py

56-58: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py

47-47: Unused function argument: x

(ARG001)


47-47: Unused function argument: residual

(ARG001)


47-47: Unused function argument: weight

(ARG001)


47-47: Unused function argument: eps

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Check PR Title Format
  • GitHub Check: Check PR Checklist Resolution
  • GitHub Check: Pre-commit Check
🔇 Additional comments (10)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

129-136: LGTM!

The configuration changes correctly enable the new fused add + RMSNorm transform. The placement after fuse_rmsnorm is appropriate since fuse_add_rms_norm depends on flashinfer_rms_norm being present in the graph (as noted in the pattern matching).

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_fused_add_rms_norm_op.py (2)

18-47: LGTM!

The test correctly validates both numerical accuracy and in-place modification behavior. Consider extending dtype coverage to torch.float16 in future iterations if the FlashInfer op supports it.


1-6: Missing NVIDIA copyright header.

Per coding guidelines, all TensorRT-LLM code files should contain an NVIDIA copyright header at the top.

Add the copyright header at the beginning of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import pytest
 import torch
⛔ Skipped due to learnings
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Learnt from: EmmaQiaoCh
Repo: NVIDIA/TensorRT-LLM PR: 7370
File: tests/unittest/trt/model_api/test_model_quantization.py:24-27
Timestamp: 2025-08-29T14:07:45.863Z
Learning: In TensorRT-LLM's CI infrastructure, pytest skip markers (pytest.mark.skip) are properly honored even when test files have __main__ blocks that call test functions directly. The testing system correctly skips tests without requiring modifications to the __main__ block execution pattern.
tensorrt_llm/_torch/auto_deploy/transform/library/fused_add_rms_norm.py (2)

60-69: Verify output semantics match between pattern and replacement.

The pattern returns (norm, added) while the replacement calls flashinfer_fused_add_rms_norm which returns (x, residual). Based on the docstring of flashinfer_fused_add_rms_norm_inplace:

  • residual = x + residual (the sum)
  • x = rms_norm(residual, ...) (the normalized result)

So the fused op returns (normalized_output, sum_output) which matches (norm, added). The semantics align correctly.


39-45: LGTM!

The unused parameters (cm, factory, shared_config) are required by the BaseTransform interface contract. This is expected and not a concern.

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py (2)

46-48: LGTM!

The fake registration is correctly implemented for tracing support. The unused parameters are expected for fake kernels.


51-54: LGTM!

Clean wrapper implementation that provides a convenient interface returning the modified tensors.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_fused_add_rms_norm.py (3)

11-23: LGTM!

The test model correctly represents the pattern (add + cast + RMSNorm) that the transform is designed to fuse.


26-71: LGTM!

The test helper comprehensively validates the transformation:

  1. Exports the model to a graph module
  2. Applies the fuse_add_rms_norm transform
  3. Verifies the fused op is present in the graph
  4. Validates numerical correctness of outputs

1-8: Missing NVIDIA copyright header.

Per coding guidelines, all TensorRT-LLM code files should contain an NVIDIA copyright header at the top.

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 from torch.export import Dim
⛔ Skipped due to learnings
Learnt from: galagam
Repo: NVIDIA/TensorRT-LLM PR: 6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py:180-182
Timestamp: 2025-10-20T17:09:21.560Z
Learning: In tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py, the _gated_rmsnorm_replacement function does not need to cast the output of torch.ops.auto_deploy.torch_rmsnorm_gated back to the input dtype, even though the custom op returns fp32. The dtype handling is managed elsewhere or the fp32 output is acceptable for downstream consumers.
Learnt from: moraxu
Repo: NVIDIA/TensorRT-LLM PR: 6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.
Learnt from: fredricz-20070104
Repo: NVIDIA/TensorRT-LLM PR: 7645
File: tests/integration/test_lists/qa/llm_function_core.txt:648-648
Timestamp: 2025-09-09T09:40:45.658Z
Learning: In TensorRT-LLM test lists, it's common and intentional for the same test to appear in multiple test list files when they serve different purposes (e.g., llm_function_core.txt for comprehensive core functionality testing and llm_function_core_sanity.txt for quick sanity checks). This duplication allows tests to be run in different testing contexts.
Learnt from: Fridah-nv
Repo: NVIDIA/TensorRT-LLM PR: 6760
File: tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py:81-98
Timestamp: 2025-08-09T02:04:49.623Z
Learning: In TensorRT-LLM's auto_deploy module, torch.dtype values in configuration dictionaries must be stored as string representations (e.g., "float16" instead of torch.float16) because OmegaConf.merge does not support torch.dtype types. These string representations are converted to actual torch.dtype objects in downstream code.

@tensorrt-cicd (Collaborator)

PR_Github #27150 [ run ] triggered by Bot. Commit: 57c5794

@nvchenghaoz nvchenghaoz enabled auto-merge (squash) December 5, 2025 20:46
@tensorrt-cicd (Collaborator)

PR_Github #27150 [ run ] completed with state SUCCESS. Commit: 57c5794
/LLM/main/L0_MergeRequest_PR pipeline #20718 completed with status: 'FAILURE'

@nvchenghaoz (Collaborator Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #27165 [ run ] triggered by Bot. Commit: 57c5794

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz (Collaborator Author)

/bot kill

@tensorrt-cicd (Collaborator)

PR_Github #27167 [ kill ] triggered by Bot. Commit: 7e3e379

@tensorrt-cicd (Collaborator)

PR_Github #27165 [ run ] completed with state ABORTED. Commit: 57c5794
LLM/main/L0_MergeRequest_PR #20728 (Blue Ocean) completed with status: ABORTED

@tensorrt-cicd (Collaborator)

PR_Github #27167 [ kill ] completed with state SUCCESS. Commit: 7e3e379
Successfully killed previous jobs for commit 7e3e379

@nvchenghaoz (Collaborator Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #27168 [ run ] triggered by Bot. Commit: 7e3e379

@tensorrt-cicd (Collaborator)

PR_Github #27168 [ run ] completed with state SUCCESS. Commit: 7e3e379
/LLM/main/L0_MergeRequest_PR pipeline #20730 completed with status: 'FAILURE'

@nvchenghaoz (Collaborator Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #27224 [ run ] triggered by Bot. Commit: 7e3e379

@tensorrt-cicd (Collaborator)

PR_Github #27224 [ run ] completed with state SUCCESS. Commit: 7e3e379
/LLM/main/L0_MergeRequest_PR pipeline #20784 completed with status: 'FAILURE'

@nvchenghaoz (Collaborator Author)

/bot run

@nvchenghaoz (Collaborator Author)

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #27350 [ run ] triggered by Bot. Commit: cbbe6c9

@tensorrt-cicd (Collaborator)

PR_Github #27350 [ run ] completed with state SUCCESS. Commit: cbbe6c9
/LLM/main/L0_MergeRequest_PR pipeline #20895 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@nvchenghaoz nvchenghaoz merged commit 75f5446 into NVIDIA:main Dec 8, 2025
5 checks passed
@github-project-automation github-project-automation bot moved this from Backlog to Done in AutoDeploy Board Dec 9, 2025
usberkeley pushed a commit to usberkeley/TensorRT-LLM that referenced this pull request Dec 11, 2025
…#9754)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
codego7250 pushed a commit to codego7250/TensorRT-LLM that referenced this pull request Dec 11, 2025
…#9754)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
codego7250 pushed a commit to codego7250/TensorRT-LLM that referenced this pull request Dec 13, 2025
…#9754)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>