Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,12 @@ transforms:
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
fuse_rmsnorm:
stage: post_load_fusion
rmsnorm_backend: triton
rmsnorm_backend: flashinfer
gated_rmsnorm_backend: triton
requires_shape_prop: true

fuse_add_rms_norm:
stage: post_load_fusion
enabled: true
############################################################################################
# VISUALIZE GRAPH
############################################################################################
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flashinfer
import torch

from ...flashinfer_utils import get_env_enable_pdl


@torch.library.custom_op(
    "auto_deploy::flashinfer_fused_add_rms_norm_inplace", mutates_args={"x", "residual"}
)
def flashinfer_fused_add_rms_norm_inplace(
    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float
) -> None:
    """Fused Add + RMSNorm using FlashInfer, computed in-place.

    After the call:
        residual <- x + residual                 (the elementwise sum)
        x        <- rms_norm(residual, weight, eps)  (the normalized sum)

    Args:
        x: Input of shape (..., hidden_size); overwritten with the normalized
            result.
        residual: Residual with the same shape as ``x``; overwritten with the
            elementwise sum ``x + residual``.
        weight: RMSNorm scale of shape (hidden_size,).
        eps: Numerical-stability epsilon.

    Returns:
        None. Both ``x`` and ``residual`` are mutated in place.
    """
    # FlashInfer expects 2D inputs (batch*seq_len, hidden_size). These flat
    # tensors are views sharing storage with the originals, so the in-place
    # kernel below updates `x` and `residual` directly — no copy-back needed.
    x_flat = x.view(-1, x.shape[-1])
    residual_flat = residual.view(-1, residual.shape[-1])

    flashinfer.norm.fused_add_rmsnorm(
        x_flat, residual_flat, weight, eps, enable_pdl=get_env_enable_pdl()
    )


@flashinfer_fused_add_rms_norm_inplace.register_fake
def _(x, residual, weight, eps):
    # Fake (meta) implementation for tracing/export: the real op mutates
    # `x` and `residual` in place and returns nothing, so the fake also
    # returns None and allocates no outputs.
    return


def flashinfer_fused_add_rms_norm(x, residual, weight, eps):
    """Run the in-place fused add+RMSNorm op and hand back the mutated tensors.

    Returns ``(x, residual)`` — the very same objects passed in, where ``x``
    now holds the normalized output and ``residual`` holds the sum.
    """
    fused_op = torch.ops.auto_deploy.flashinfer_fused_add_rms_norm_inplace
    fused_op(x, residual, weight, eps)
    return x, residual
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformation for fusing Add + Cast + RMSNorm."""

from typing import Tuple

import torch
from torch.fx import GraphModule

from ...custom_ops.flashinfer_fused_add_rms_norm import flashinfer_fused_add_rms_norm
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry


@TransformRegistry.register("fuse_add_rms_norm")
class FuseAddRMSNorm(BaseTransform):
    """Fuse (add + cast + RMSNorm) into one fused op.

    Matches the traced sequence:
        x = add(input, residual)
        y = x.to(dtype)
        z = flashinfer_rms_norm(y, weight, eps)
    Replaces with:
        z, x = flashinfer_fused_add_rms_norm(input, residual, weight, eps)

    Note: the fused op mutates its inputs in place (x becomes the normalized
    output, residual becomes the sum).
    """

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        patterns = ADPatternMatcherPass()

        # Dummy shapes for tracing the pattern graph; only the op sequence
        # and dtypes matter for matching, not the actual sizes.
        bsz, hidden = 2, 128
        dummy_args = [
            torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16),  # x (bf16)
            torch.randn(bsz, hidden, device="meta", dtype=torch.bfloat16),  # residual (bf16)
            torch.randn(hidden, device="meta", dtype=torch.bfloat16),  # weight
            1e-5,  # eps
        ]

        # Ignore the concrete dtype argument of aten.to so the pattern
        # matches regardless of which dtype the model casts to.
        op_ignore_types = {torch.ops.aten.to.dtype: (torch.dtype,)}
        # eps is a Python scalar; the pattern matcher needs a placeholder
        # value to trace it through.
        scalar_workaround = {"eps": 1e-5}

        def _fused_add_norm_pattern(x, residual, weight, eps):
            # Search graph: add -> cast -> flashinfer_rms_norm; both the
            # normalized output and the pre-cast sum are consumed downstream.
            added = torch.ops.aten.add.Tensor(x, residual)
            cast = torch.ops.aten.to.dtype(added, torch.bfloat16)
            # Note: we assume flashinfer_rms_norm is the target
            norm = torch.ops.auto_deploy.flashinfer_rms_norm.default(cast, weight, eps)
            return norm, added

        def _fused_add_norm_replacement(x, residual, weight, eps):
            # Use the python wrapper directly, not via torch.ops.auto_deploy
            return flashinfer_fused_add_rms_norm(x, residual, weight, eps)

        # Register pattern
        register_ad_pattern(
            search_fn=_fused_add_norm_pattern,
            replace_fn=_fused_add_norm_replacement,
            patterns=patterns,
            dummy_args=dummy_args,
            op_ignore_types=op_ignore_types,
            scalar_workaround=scalar_workaround,
        )

        num_matches = patterns.apply(gm.graph)

        # Any match rewrites the graph, so it is no longer canonicalized and
        # shape metadata must be re-propagated downstream.
        info = TransformInfo(
            skipped=False,
            num_matches=num_matches,
            is_clean=num_matches == 0,
            has_valid_shapes=num_matches == 0,
        )
        return gm, info
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def run_with_autodeploy(model, speculative_model_dir, batch_size):
"world_size": 1,
"kv_cache_config": kv_cache_config,
"disable_overlap_scheduler": True,
"transforms": {
"fuse_rmsnorm": {"rmsnorm_backend": "triton"},
},
"max_num_tokens": 64,
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest
import torch

from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_fused_add_rms_norm import (
flashinfer_fused_add_rms_norm,
)


def rms_norm_ref(x, weight, eps):
    """Reference RMSNorm: normalize in fp32, then cast back to the input dtype."""
    orig_dtype = x.dtype
    x_fp32 = x.float()
    inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
    normalized = x_fp32 * inv_rms
    return weight * normalized.to(orig_dtype)


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [128, 1024])
def test_flashinfer_fused_add_rms_norm_kernel(dtype, hidden_size):
    """Compare the fused in-place kernel against the pure-PyTorch reference."""
    bsz, seq_len, eps = 4, 128, 1e-6

    x = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype)
    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, device="cuda", dtype=dtype)

    # Compute the reference on clones, since the fused op mutates its inputs.
    expected_residual = x.clone() + residual.clone()
    expected_x = rms_norm_ref(expected_residual, weight, eps)

    x_out, residual_out = flashinfer_fused_add_rms_norm(x, residual, weight, eps)

    tol = {"rtol": 1e-2, "atol": 1e-2}
    torch.testing.assert_close(residual_out, expected_residual, **tol)
    torch.testing.assert_close(x_out, expected_x, **tol)

    # The op works in place: the returned tensors are the inputs themselves.
    assert x_out is x
    assert residual_out is residual
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import torch
from torch.export import Dim

from tensorrt_llm._torch.auto_deploy.custom_ops.flashinfer_fused_add_rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.custom_ops.rms_norm import * # noqa
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op


class TestModel(torch.nn.Module):
    """Minimal module exhibiting the add -> cast -> rms_norm pattern to fuse."""

    def __init__(self, hidden_size=128, eps=1e-5):
        super().__init__()
        weight_init = torch.ones(hidden_size, device="cuda", dtype=torch.bfloat16)
        self.weight = torch.nn.Parameter(weight_init)
        self.eps = eps

    def forward(self, x, residual):
        summed = x + residual
        summed_bf16 = summed.to(torch.bfloat16)
        normed = torch.ops.auto_deploy.flashinfer_rms_norm(summed_bf16, self.weight, self.eps)
        return normed, summed


def _run_test(model):
    """Export *model*, run the fuse_add_rms_norm transform, and validate.

    Asserts that the fused in-place op appears in the transformed graph and
    that the transformed module's outputs match the original model's.
    """
    # The replacement uses flashinfer_fused_add_rms_norm python wrapper which calls the inplace op
    # auto_deploy::flashinfer_fused_add_rms_norm_inplace
    op = torch.ops.auto_deploy.flashinfer_fused_add_rms_norm_inplace

    def checker(gm):
        # True iff any node in the graph calls the fused in-place op.
        return any(is_op(n, op) for n in gm.graph.nodes)

    bsz, seq_len, hidden = 2, 8, 128
    # Inputs should be bfloat16
    x = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)
    residual = torch.randn(bsz, seq_len, hidden, device="cuda", dtype=torch.bfloat16)

    # Dynamic shapes: only the batch dimension (dim 0) is dynamic for both inputs.
    ds_x = {0: Dim("batch_size", max=8)}
    ds_res = {0: Dim("batch_size", max=8)}

    gm = torch_export_to_gm(model, args=(x, residual), dynamic_shapes=(ds_x, ds_res), clone=True)

    # Run only the fuse_add_rms_norm transform on the exported graph.
    gm_transformed = InferenceOptimizer(
        None,
        {
            "fuse_add_rms_norm": {
                "stage": "post_load_fusion",
            },
        },
    )(None, gm)

    # Check if transform happened
    if not checker(gm_transformed):
        raise AssertionError(
            "flashinfer_fused_add_rms_norm_inplace op not found in transformed graph"
        )

    # Validation
    # Clone inputs because the fused op is inplace
    x_in = x.clone()
    res_in = residual.clone()

    # The fused op is inplace, so inputs x_in and res_in will be modified.
    # gm_transformed returns (x_in, res_in) which are the modified tensors.
    y_transformed = gm_transformed(x_in, res_in)

    # Loose tolerances account for bf16 precision differences between the
    # fused kernel and the unfused reference path.
    y_model = model(x.clone(), residual.clone())
    torch.testing.assert_close(y_transformed[0], y_model[0], atol=1e-2, rtol=1e-2)
    torch.testing.assert_close(y_transformed[1], y_model[1], atol=1e-2, rtol=1e-2)


def test_fuse_add_rms_norm():
    """End-to-end: export TestModel and verify the add+RMSNorm fusion fires."""
    _run_test(TestModel())