@@ -141,6 +141,12 @@ def match_rope_pattern(gm: GraphModule) -> int:
         torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16),
         torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float16),
     ]
+    # float32 inputs can change the traced graph when the pattern contains a .float() cast
+    dummy_complex_2 = [
+        torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32),
+        torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32),
+        torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float32),
+    ]
     register_ad_pattern(
         search_fn=_explicit_rope_pattern,
         replace_fn=_explicit_rope_repl,
@@ -172,6 +178,16 @@ def match_rope_pattern(gm: GraphModule) -> int:
         },
         scalar_workaround={"unsqueeze_dim": 1},
     )
+    register_ad_pattern(
+        search_fn=_complex_rope_pattern,
+        replace_fn=_complex_rope_repl,
+        patterns=patterns,
+        dummy_args=dummy_complex_2,
+        op_ignore_types={
+            torch.ops.aten.reshape.default: (int,),
+        },
+        scalar_workaround={"unsqueeze_dim": 1},
+    )
 
     num_matches = patterns.apply(graph)
     canonicalize_graph(gm)
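
The second registration with `dummy_complex_2` is needed because the pattern tracer specializes on the dtype of the dummy inputs: a `.float()` inside the search function records an explicit cast node when traced with float16 inputs but is a no-op with float32 inputs, so each dtype produces a different pattern graph. A minimal sketch of that behavior (not part of this commit; it only uses PyTorch's `make_fx`, and the function `f` below is made up for illustration):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


# Stand-in for a pattern that contains a .float() cast, like _complex_rope_pattern.
def f(x):
    return x.float() * 2


# Tracing with a float16 dummy records the cast (typically as aten._to_copy);
# tracing with a float32 dummy does not, because .float() is a no-op there.
g16 = make_fx(f)(torch.randn(4, dtype=torch.float16, device="meta"))
g32 = make_fx(f)(torch.randn(4, dtype=torch.float32, device="meta"))
print(g16.graph)  # contains a cast node
print(g32.graph)  # no cast node -> a different graph, hence the second registration
```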