
Commit 75f5446

[#9753][feat] AutoDeploy: Implement add rms_norm fusion (#9754)

Signed-off-by: Chenghao Zhang <[email protected]>
1 parent da074be

File tree

6 files changed: +273 −2 lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 4 additions & 2 deletions
@@ -128,10 +128,12 @@ transforms:
   # TODO (lucaslie): add backend selection as part of configurable inference optimizers
   fuse_rmsnorm:
     stage: post_load_fusion
-    rmsnorm_backend: triton
+    rmsnorm_backend: flashinfer
     gated_rmsnorm_backend: triton
     requires_shape_prop: true
-
+  fuse_add_rms_norm:
+    stage: post_load_fusion
+    enabled: true
   ############################################################################################
   # VISUALIZE GRAPH
   ############################################################################################
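With this change, the default rmsnorm_backend flips from triton to flashinfer, and the new fuse_add_rms_norm transform is enabled by default. Both settings can still be overridden per run through the "transforms" entry of the LLM arguments, as the integration-test change below does. A minimal sketch, assuming the usual dict-style LLM arguments (the surrounding keys are illustrative, and the enabled-False override is presumed from the config key above):

llm_args = {
    # ... other engine arguments ...
    "transforms": {
        # pin the standalone RMSNorm fusion back to the triton backend
        "fuse_rmsnorm": {"rmsnorm_backend": "triton"},
        # presumably opts out of the new add + RMSNorm fusion
        "fuse_add_rms_norm": {"enabled": False},
    },
}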
tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_fused_add_rms_norm.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flashinfer
import torch

from ...flashinfer_utils import get_env_enable_pdl


@torch.library.custom_op(
    "auto_deploy::flashinfer_fused_add_rms_norm_inplace", mutates_args={"x", "residual"}
)
def flashinfer_fused_add_rms_norm_inplace(
    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float
) -> None:
    """Fused Add + RMSNorm using FlashInfer (in-place).

    Computes in-place:
        residual = x + residual                (sum)
        x = rms_norm(residual, weight, eps)    (normalized)

    Returns None.
    """
    # FlashInfer expects 2D inputs of shape (batch * seq_len, hidden_size).
    # The flattened tensors below are views of x and residual, so the
    # kernel's in-place update propagates to the original tensors.
    x_flat = x.view(-1, x.shape[-1])
    residual_flat = residual.view(-1, residual.shape[-1])

    flashinfer.norm.fused_add_rmsnorm(
        x_flat, residual_flat, weight, eps, enable_pdl=get_env_enable_pdl()
    )


@flashinfer_fused_add_rms_norm_inplace.register_fake
def _(x, residual, weight, eps):
    return


def flashinfer_fused_add_rms_norm(x, residual, weight, eps):
    """Wrapper that calls the in-place op and returns the modified tensors."""
    torch.ops.auto_deploy.flashinfer_fused_add_rms_norm_inplace(x, residual, weight, eps)
    return x, residual
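For reference, the semantics the fused kernel provides can be written in plain PyTorch. A minimal sketch (the helper name is illustrative; the float32 upcast mirrors the rms_norm_ref helper in the unit test further below):

import torch

def fused_add_rms_norm_ref(x, residual, weight, eps):
    # residual <- x + residual, then x <- RMSNorm(residual) * weight
    residual = x + residual
    h = residual.to(torch.float32)
    h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + eps)
    return weight * h.to(residual.dtype), residual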
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformation for fusing Add + Cast + RMSNorm."""

from typing import Tuple

import torch
from torch.fx import GraphModule

from ...custom_ops.flashinfer_fused_add_rms_norm import flashinfer_fused_add_rms_norm
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


@TransformRegistry.register("fuse_add_rms_norm")
class FuseAddRMSNorm(BaseTransform):
    """Fuse (add + cast + RMSNorm) into one fused op.

    Matches:
        x = add(input, residual)
        y = x.to(dtype)
        z = flashinfer_rms_norm(y, weight, eps)

    Replaces with:
        z, x = flashinfer_fused_add_rms_norm(input, residual, weight, eps)
    """

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        patterns = ADPatternMatcherPass()

        # Dummy shapes for tracing the pattern
        bsz, hidden = 2, 128
        dummy_args = [
            torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16),  # x (bf16)
            torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16),  # residual (bf16)
            torch.randn(hidden, device="meta", dtype=torch.bfloat16),  # weight
            1e-5,  # eps
        ]

        op_ignore_types = {torch.ops.aten.to.dtype: (torch.dtype,)}
        scalar_workaround = {"eps": 1e-5}

        def _fused_add_norm_pattern(x, residual, weight, eps):
            added = torch.ops.aten.add.Tensor(x, residual)
            cast = torch.ops.aten.to.dtype(added, torch.bfloat16)
            # Note: we assume flashinfer_rms_norm is the target
            norm = torch.ops.auto_deploy.flashinfer_rms_norm.default(cast, weight, eps)
            return norm, added

        def _fused_add_norm_replacement(x, residual, weight, eps):
            # Use the Python wrapper directly, not via torch.ops.auto_deploy
            return flashinfer_fused_add_rms_norm(x, residual, weight, eps)

        # Register pattern
        register_ad_pattern(
            search_fn=_fused_add_norm_pattern,
            replace_fn=_fused_add_norm_replacement,
            patterns=patterns,
            dummy_args=dummy_args,
            op_ignore_types=op_ignore_types,
            scalar_workaround=scalar_workaround,
        )

        num_matches = patterns.apply(gm.graph)

        info = TransformInfo(
            skipped=False,
            num_matches=num_matches,
            is_clean=num_matches == 0,
            has_valid_shapes=num_matches == 0,
        )
        return gm, info
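Concretely, the rewrite collapses three graph nodes into a single in-place call. A sketch of the node sequence before and after, mirroring the class docstring (variable names illustrative):

# Before: three nodes produce (norm, added)
added = torch.ops.aten.add.Tensor(x, residual)
cast = torch.ops.aten.to.dtype(added, torch.bfloat16)
norm = torch.ops.auto_deploy.flashinfer_rms_norm.default(cast, weight, eps)

# After: one fused call; the wrapper mutates x and residual in place and
# returns them, so consumers of (norm, added) rebind to the mutated tensors
norm, added = flashinfer_fused_add_rms_norm(x, residual, weight, eps)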

tests/integration/defs/examples/test_ad_speculative_decoding.py

Lines changed: 3 additions & 0 deletions
@@ -81,6 +81,9 @@ def run_with_autodeploy(model, speculative_model_dir, batch_size):
         "world_size": 1,
         "kv_cache_config": kv_cache_config,
         "disable_overlap_scheduler": True,
+        "transforms": {
+            "fuse_rmsnorm": {"rmsnorm_backend": "triton"},
+        },
         "max_num_tokens": 64,
     }

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
import pytest
import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_fused_add_rms_norm import (
    flashinfer_fused_add_rms_norm,
)


def rms_norm_ref(x, weight, eps):
    """Reference implementation of RMSNorm using PyTorch ops."""
    input_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(input_dtype)


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [128, 1024])
def test_flashinfer_fused_add_rms_norm_kernel(dtype, hidden_size):
    bsz = 4
    seq_len = 128
    eps = 1e-6

    # Create inputs
    x = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype)
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

    # Clone for the reference computation
    x_ref = x.clone()
    residual_ref = residual.clone()

    residual_ref_out = x_ref + residual_ref
    x_ref_out = rms_norm_ref(residual_ref_out, weight, eps)

    # Run our fused op
    x_out, residual_out = flashinfer_fused_add_rms_norm(x, residual, weight, eps)

    rtol, atol = 1e-2, 1e-2

    torch.testing.assert_close(residual_out, residual_ref_out, rtol=rtol, atol=atol)
    torch.testing.assert_close(x_out, x_ref_out, rtol=rtol, atol=atol)

    # Verify the in-place modification happened
    assert x is x_out
    assert residual is residual_out
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
import torch
from torch.export import Dim

from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_fused_add_rms_norm import *  # noqa
from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import *  # noqa
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op


class TestModel(torch.nn.Module):
    def __init__(self, hidden_size=128, eps=1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.ones(hidden_size, device="cuda", dtype=torch.bfloat16)
        )
        self.eps = eps

    def forward(self, x, residual):
        added = x + residual
        cast = added.to(torch.bfloat16)
        norm = torch.ops.auto_deploy.flashinfer_rms_norm(cast, self.weight, self.eps)
        return norm, added


def _run_test(model):
    # The replacement uses the flashinfer_fused_add_rms_norm Python wrapper,
    # which calls the in-place op auto_deploy::flashinfer_fused_add_rms_norm_inplace.
    op = torch.ops.auto_deploy.flashinfer_fused_add_rms_norm_inplace

    def checker(gm):
        return any(is_op(n, op) for n in gm.graph.nodes)

    bsz, seq_len, hidden = 2, 8, 128
    # Inputs should be bfloat16
    x = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
    residual = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)

    # Dynamic shapes
    ds_x = {0: Dim("batch_size", max=8)}
    ds_res = {0: Dim("batch_size", max=8)}

    gm = torch_export_to_gm(model, args=(x, residual), dynamic_shapes=(ds_x, ds_res), clone=True)

    gm_transformed = InferenceOptimizer(
        None,
        {
            "fuse_add_rms_norm": {
                "stage": "post_load_fusion",
            },
        },
    )(None, gm)

    # Check that the transform happened
    if not checker(gm_transformed):
        raise AssertionError(
            "flashinfer_fused_add_rms_norm_inplace op not found in transformed graph"
        )

    # Validation: clone inputs because the fused op is in-place.
    x_in = x.clone()
    res_in = residual.clone()

    # The fused op is in-place, so x_in and res_in will be modified.
    # gm_transformed returns (x_in, res_in), i.e. the mutated tensors.
    y_transformed = gm_transformed(x_in, res_in)

    y_model = model(x.clone(), residual.clone())
    torch.testing.assert_close(y_transformed[0], y_model[0], atol=1e-2, rtol=1e-2)
    torch.testing.assert_close(y_transformed[1], y_model[1], atol=1e-2, rtol=1e-2)


def test_fuse_add_rms_norm():
    model = TestModel()
    _run_test(model)
