@@ -163,6 +163,13 @@ def pattern(
163163 ):
164164 # Reshape query from (B, S, D) to (B, S, H, D/H)
165165 query_BSHDh = op .Reshape (query_BSD , pattern .ANY_VALUE , _outputs = ["query_BSHDh" ])
166+ # Qwen variant uses normalization of query/key before rotary embedding:
167+ # The normalization can happen before (eg., Qwen) or after the Transpose (eg., Gemma).
168+ query_BSHDh_normalized = op .SimplifiedLayerNormalization (
169+ query_BSHDh , pattern .ANY_VALUE , axis = - 1 , _outputs = ["query_BSHDh_normalized" ]
170+ )
171+ query_BSHDh = pattern .OrValue ([query_BSHDh , query_BSHDh_normalized ])
172+
166173 # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
167174 query_BHSDh = op .Transpose (query_BSHDh , perm = [0 , 2 , 1 , 3 ])
168175
@@ -174,6 +181,11 @@ def pattern(
174181
175182 # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H)
176183 key_BSHkvDh = op .Reshape (key_BSDkv , pattern .ANY_VALUE , _outputs = ["key_BSHkvDh" ])
184+ key_BSHkvDh_normalized = op .SimplifiedLayerNormalization (
185+ key_BSHkvDh , pattern .ANY_VALUE , axis = - 1 , _outputs = ["key_BSHkvDh_normalized" ]
186+ )
187+ key_BSHkvDh = pattern .OrValue ([key_BSHkvDh , key_BSHkvDh_normalized ])
188+
177189 # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
178190 key_BHkvSDh = op .Transpose (key_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
179191
@@ -258,8 +270,23 @@ def check(
258270 query_BSHDh ,
259271 key_BSHkvDh ,
260272 mask ,
273+ query_BSHDh_normalized = None ,
274+ query_BHSDh_normalized = None ,
275+ key_BSHkvDh_normalized = None ,
276+ key_BHkvSDh_normalized = None ,
261277 ** _ ,
262278 ):
279+ result = pattern .MatchResult ()
280+ if query_BSHDh_normalized is not None and query_BHSDh_normalized is not None :
281+ return result .fail (
282+ "Query normalized twice" ,
283+ [query_BSHDh_normalized , query_BHSDh_normalized ],
284+ )
285+ if key_BSHkvDh_normalized is not None and key_BHkvSDh_normalized is not None :
286+ return result .fail (
287+ "Key normalized twice" ,
288+ [key_BSHkvDh_normalized , key_BHkvSDh_normalized ],
289+ )
263290 bindings : dict [str , Dim ] = {}
264291
265292 def no_match (val : ir .Value , dims : Sequence [str ]) -> bool :
@@ -282,7 +309,7 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
282309 # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]:
283310 # or check Reshape's shape-input value
284311
285- result = pattern . MatchResult ()
312+
286313 num_heads = _ir_utils .get_dim (query_BSHDh , 2 )
287314 kv_num_heads = _ir_utils .get_dim (key_BSHkvDh , 2 )
288315 if not isinstance (num_heads , int ):
@@ -334,7 +361,9 @@ def rewrite(
334361 mask ,
335362 query_BSHDh ,
336363 key_BSHkvDh ,
364+ query_BSHDh_normalized = None ,
337365 query_BHSDh_normalized = None ,
366+ key_BSHkvDh_normalized = None ,
338367 key_BHkvSDh_normalized = None ,
339368 ** _ ,
340369 ):
@@ -356,9 +385,10 @@ def rewrite(
356385 max_seq_length = op .ReduceMax (seqlens_k , zero_int64_1d , keepdims = 0 )
357386 total_seq_length_int32 = op .Add (max_seq_length , one_int32_0d )
358387
359- if query_BHSDh_normalized is not None :
388+ normalized_query = query_BHSDh_normalized or query_BSHDh_normalized
389+ if normalized_query is not None :
360390 # We apply normalization without the transpose, which is fused into GQA
361- norm_node = query_BHSDh_normalized .producer ()
391+ norm_node = normalized_query .producer ()
362392 norm_attrs = norm_node .attributes
363393 norm_scale = norm_node .inputs [1 ]
364394 query_BSHDh_normalized = op .SimplifiedLayerNormalization (
@@ -367,9 +397,10 @@ def rewrite(
367397 reshape_BSHDh_to_BSD = op .Constant (value_ints = [0 , 0 , - 1 ])
368398 query_BSD = op .Reshape (query_BSHDh_normalized , reshape_BSHDh_to_BSD )
369399
370- if key_BHkvSDh_normalized is not None :
400+ normalized_key = key_BHkvSDh_normalized or key_BSHkvDh_normalized
401+ if normalized_key is not None :
371402 # We apply normalization without the transpose, which is fused into GQA
372- norm_node = key_BHkvSDh_normalized .producer ()
403+ norm_node = normalized_key .producer ()
373404 norm_attrs = norm_node .attributes
374405 norm_scale = norm_node .inputs [1 ]
375406 key_BSHkvDh_normalized = op .SimplifiedLayerNormalization (
0 commit comments