
Commit 8e55921

feat(auto_deploy): add L2 norm pattern matcher and fusion transform
Signed-off-by: Karthik Vetrivel <kvetrivel@nvidia.com>
1 parent 3a89495 commit 8e55921

File tree: 3 files changed, +306 -0 lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 5 additions & 0 deletions

@@ -50,6 +50,8 @@ transforms:
     expected_layout: bsnd
   match_rmsnorm_pattern:
     stage: pattern_matcher
+  match_l2norm_pattern:
+    stage: pattern_matcher
   ############################################################################################
   # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
   ############################################################################################
@@ -139,6 +141,9 @@ transforms:
     rmsnorm_backend: flashinfer
     gated_rmsnorm_backend: triton
     requires_shape_prop: true
+  fuse_l2norm:
+    stage: post_load_fusion
+    backend: fla
   fuse_add_rms_norm:
     stage: post_load_fusion
     enabled: true
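
Taken together, these config entries wire up a two-stage flow: match_l2norm_pattern (pattern_matcher stage) canonicalizes eager L2-norm math into torch.ops.auto_deploy.torch_l2norm, and fuse_l2norm (post_load_fusion stage) later swaps that op for the configured backend (fla by default). As a minimal sketch in plain PyTorch, with no auto_deploy dependencies, this is the eager computation the matcher is looking for:

# Minimal sketch (plain PyTorch, no auto_deploy imports): the eager L2-norm
# computation targeted by the pattern matcher, x / sqrt(sum(x^2, dim=-1) + eps),
# performed in float32 and cast back to the input dtype.
import torch


def eager_l2norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    x32 = x.to(torch.float32)
    out = x32 * torch.rsqrt((x32 * x32).sum(dim=-1, keepdim=True) + eps)
    return out.to(x.dtype)


x = torch.randn(2, 8, dtype=torch.float16)
x32 = x.to(torch.float32)
ref = (x32 / torch.sqrt((x32 * x32).sum(-1, keepdim=True) + 1e-6)).to(torch.float16)
torch.testing.assert_close(eager_l2norm(x), ref, atol=1e-3, rtol=1e-3)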
Lines changed: 200 additions & 0 deletions

@@ -0,0 +1,200 @@
"""Graph transform to optimize L2Norm execution using FLA Triton kernels."""

from typing import Literal, Tuple, Type

import torch
from pydantic import Field
from torch.fx import GraphModule, Node

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface

# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
from ...utils.node_utils import is_op
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import (
    BaseTransform,
    SharedConfig,
    TransformConfig,
    TransformInfo,
    TransformRegistry,
)

_BACKEND_OPS = {
    "fla": torch.ops.auto_deploy.fla_l2norm.default,
    "torch": torch.ops.auto_deploy.torch_l2norm.default,
}


def _l2_norm_pattern(data: torch.Tensor, eps: float) -> torch.Tensor:
    """Implements the L2Norm pattern for pattern matching.

    L2 normalization: x / sqrt(sum(x^2) + eps)

    Args:
        data: Input tensor to normalize.
        eps: Small constant for numerical stability.

    Returns:
        L2 normalized tensor.
    """
    input_dtype = data.dtype
    data = data.to(torch.float32)
    sum_sq = (data * data).sum(dim=-1, keepdim=True)
    data = data * torch.rsqrt(sum_sq + eps)
    return data.to(input_dtype)


def _l2_norm_pattern_no_dtype_cast(data: torch.Tensor, eps: float) -> torch.Tensor:
    """Implements the L2Norm pattern without dtype casting for pattern matching.

    Some models may already operate in float32 and skip the dtype cast.

    Args:
        data: Input tensor to normalize.
        eps: Small constant for numerical stability.

    Returns:
        L2 normalized tensor.
    """
    sum_sq = (data * data).sum(dim=-1, keepdim=True)
    return data * torch.rsqrt(sum_sq + eps)


def _l2_norm_to_torch_l2norm(data: torch.Tensor, eps: float) -> torch.Tensor:
    """Replace L2Norm pattern with torch_l2norm op (standardized representation).

    Args:
        data: Input tensor to normalize.
        eps: Small constant for numerical stability.

    Returns:
        L2 normalized tensor using torch_l2norm.
    """
    return torch.ops.auto_deploy.torch_l2norm(data, eps)


@TransformRegistry.register("match_l2norm_pattern")
class MatchL2NormPattern(BaseTransform):
    """Matches L2Norm patterns in the graph and replaces them with torch_l2norm op.

    This transform runs in the pattern_matcher stage and standardizes L2Norm patterns
    to use torch_l2norm op, which can later be fused to a specific backend in the
    post_load_fusion stage.

    Args:
        gm: Input graph module to transform.

    Returns:
        Transformed graph module with standardized torch_l2norm operations.
    """

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        graph = gm.graph
        patterns = ADPatternMatcherPass()

        bs = 2
        hidden_size = 512

        def dummy_args(input_dtype: torch.dtype, eps: float = 1e-6):
            return [
                torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
                eps,
            ]

        configs = [
            torch.bfloat16,
            torch.float16,
            torch.float32,
        ]

        search_fns = [
            _l2_norm_pattern,
            _l2_norm_pattern_no_dtype_cast,
        ]
        for search_fn in search_fns:
            for input_dtype in configs:
                register_ad_pattern(
                    search_fn=search_fn,
                    replace_fn=_l2_norm_to_torch_l2norm,
                    patterns=patterns,
                    dummy_args=dummy_args(input_dtype),
                    op_ignore_types={},
                    scalar_workaround={"eps": 1e-6},
                    skip_duplicates=True,
                )

        cnt = patterns.apply(graph)

        info = TransformInfo(
            skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
        )

        return gm, info


class FuseL2NormConfig(TransformConfig):
    """Configuration for the L2Norm fusion transform."""

    backend: Literal["torch", "fla"] = Field(
        default="fla",
        description="Backend to use for L2Norm computation ('fla' or 'torch').",
    )


@TransformRegistry.register("fuse_l2norm")
class FuseL2Norm(BaseTransform):
    """Fuses torch_l2norm ops with the selected backend implementation.

    This transform runs in the post_load_fusion stage and replaces torch_l2norm ops
    with the specified backend implementation (fla or torch).

    Args:
        gm: Input graph module to transform.
        backend: Backend to use for L2Norm computation ("fla" or "torch").

    Returns:
        Transformed graph module with backend-specific L2Norm operations.
    """

    config: FuseL2NormConfig

    @classmethod
    def get_config_class(cls) -> Type[TransformConfig]:
        return FuseL2NormConfig

    def _apply(
        self,
        gm: GraphModule,
        cm: CachedSequenceInterface,
        factory: ModelFactory,
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        graph = gm.graph
        target_op = _BACKEND_OPS[self.config.backend]
        cnt = 0

        for node in list(graph.nodes):
            if is_op(node, torch.ops.auto_deploy.torch_l2norm):
                with graph.inserting_after(node):
                    new_node: Node = graph.call_function(
                        target_op,
                        args=node.args,
                        kwargs=node.kwargs,
                    )
                new_node.meta = node.meta.copy()
                node.replace_all_uses_with(new_node)
                graph.erase_node(node)
                cnt += 1

        info = TransformInfo(
            skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0
        )

        return gm, info
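
For context, the fusion pass above is a straightforward FX node swap rather than a second inductor-style pattern match. A minimal, self-contained sketch of that mechanic (assumption: plain torch.fx, replacing torch.relu with torch.sigmoid purely for illustration, no auto_deploy ops involved):

# Sketch of the node-replacement mechanics FuseL2Norm._apply relies on:
# insert the backend call next to the matched node, migrate its users, erase it.
import torch
from torch.fx import symbolic_trace


def f(x):
    return torch.relu(x) + 1.0


gm = symbolic_trace(f)
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.relu:
        with gm.graph.inserting_after(node):
            new_node = gm.graph.call_function(torch.sigmoid, args=node.args, kwargs=node.kwargs)
        new_node.meta = node.meta.copy()      # carry over shape/dtype metadata
        node.replace_all_uses_with(new_node)  # re-point downstream users to the new op
        gm.graph.erase_node(node)             # original node now has no users
gm.recompile()
print(gm(torch.zeros(3)))  # tensor([1.5, 1.5, 1.5]) since sigmoid(0) + 1 = 1.5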
Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
import pytest
import torch
from _graph_test_helpers import run_test_transformed_gm
from torch.export import Dim

from tensorrt_llm._torch.auto_deploy.custom_ops.l2norm import *  # noqa
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op


class L2Norm(torch.nn.Module):
    """L2 normalization module that normalizes along the last dimension."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        input_dtype = x.dtype
        x = x.to(torch.float32)
        sum_sq = (x * x).sum(dim=-1, keepdim=True)
        x = x * torch.rsqrt(sum_sq + self.eps)
        return x.to(input_dtype)


class L2NormNoCast(torch.nn.Module):
    """L2 normalization module without dtype casting (for float32 inputs)."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        sum_sq = (x * x).sum(dim=-1, keepdim=True)
        return x * torch.rsqrt(sum_sq + self.eps)


class TestModel(torch.nn.Module):
    def __init__(self, eps: float = 1e-6, use_no_cast: bool = False):
        super().__init__()
        self.linear1 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)
        if use_no_cast:
            self.l2_norm = L2NormNoCast(eps)
        else:
            self.l2_norm = L2Norm(eps)
        self.linear2 = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.float16)

    def forward(self, x):
        x = self.linear1(x)
        x = self.l2_norm(x)
        x = self.linear2(x)
        return x


def _run_test(model, op, variant):
    def checker(gm):
        return any(is_op(n, op) for n in gm.graph.nodes)

    x = torch.randn(2, 1024, device="cuda", dtype=torch.float16)
    dynamic_shapes = {0: Dim.DYNAMIC}
    gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True)
    gm_transformed = InferenceOptimizer(
        None,
        {
            "match_l2norm_pattern": {
                "stage": "pattern_matcher",
            },
            "fuse_l2norm": {
                "stage": "post_load_fusion",
                "backend": variant,
            },
        },
    )(None, gm)

    run_test_transformed_gm(
        model,
        x,
        gm_transformed,
        checker,
        lambda num_p_og: num_p_og,
        dynamic_shapes=dynamic_shapes,
    )

    new_input = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
    y_transformed = gm_transformed(new_input)
    y_model = model(new_input)
    torch.testing.assert_close(y_transformed, y_model, atol=1e-3, rtol=1e-3)


@pytest.mark.parametrize("eps", [1e-2, 1e-6])
@pytest.mark.parametrize(
    "variant, op",
    [
        ("fla", torch.ops.auto_deploy.fla_l2norm.default),
        ("torch", torch.ops.auto_deploy.torch_l2norm.default),
    ],
)
def test_l2norm_fusion(eps, variant, op):
    model = TestModel(eps)
    _run_test(model, op, variant)
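
Since the test builds the model and inputs on "cuda", it needs a GPU-capable environment. A hypothetical way to run just these parametrized cases by keyword:

# Hypothetical invocation sketch: select only the L2Norm fusion tests (requires a CUDA device).
import pytest

raise SystemExit(pytest.main(["-q", "-k", "test_l2norm_fusion"]))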
