
Commit 913695f

add a new rope pattern for llama4 scout (#97)

Signed-off-by: Frida Hou <[email protected]>

1 parent: 4f8f767

File tree

  • tensorrt_llm/_torch/auto_deploy/transformations/library

1 file changed: +16 −0 lines changed

tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py

Lines changed: 16 additions & 0 deletions
@@ -141,6 +141,12 @@ def match_rope_pattern(gm: GraphModule) -> int:
         torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16),
         torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float16),
     ]
+    # float32 input can change the graph when there's .float() in pattern
+    dummy_complex_2 = [
+        torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32),
+        torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32),
+        torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float32),
+    ]
     register_ad_pattern(
         search_fn=_explicit_rope_pattern,
         replace_fn=_explicit_rope_repl,
@@ -172,6 +178,16 @@ def match_rope_pattern(gm: GraphModule) -> int:
         },
         scalar_workaround={"unsqueeze_dim": 1},
     )
+    register_ad_pattern(
+        search_fn=_complex_rope_pattern,
+        replace_fn=_complex_rope_repl,
+        patterns=patterns,
+        dummy_args=dummy_complex_2,
+        op_ignore_types={
+            torch.ops.aten.reshape.default: (int,),
+        },
+        scalar_workaround={"unsqueeze_dim": 1},
+    )
 
     num_matches = patterns.apply(graph)
     canonicalize_graph(gm)
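
The float32 dummy arguments are the substance of the fix: as the in-diff comment notes, a pattern containing .float() traces to a different graph depending on input dtype, so the same complex-RoPE pattern is registered a second time with fp32 dummy args to match both graph variants. A minimal sketch of the effect, assuming a Llama-style cast-then-view_as_complex tail (the module and names below are illustrative, not the actual _complex_rope_pattern):

import torch
from torch.export import export


# Illustrative stand-in for a complex-multiplication RoPE tail: Llama-style
# implementations cast to float32 before torch.view_as_complex, which only
# accepts float inputs. This is NOT the actual _complex_rope_pattern.
class CastThenRotate(torch.nn.Module):
    def forward(self, x):
        x = x.float()  # the .float() the in-diff comment refers to
        xc = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
        return torch.view_as_real(xc).flatten(-2)


# fp16 input: the .float() is traced as an explicit dtype-cast node.
ep_fp16 = export(CastThenRotate(), (torch.randn(2, 8, dtype=torch.float16),))

# fp32 input: .float() is a no-op, so the cast typically drops out of the
# trace, leaving a structurally different graph that an fp16-derived
# pattern would not match.
ep_fp32 = export(CastThenRotate(), (torch.randn(2, 8, dtype=torch.float32),))

print(ep_fp16.graph_module.code)
print(ep_fp32.graph_module.code)

Registering the pattern with both fp16 and fp32 dummy args makes the matcher trace both graph shapes, so fp32 call sites, such as the one this commit targets for Llama 4 Scout, are also rewritten.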
