Skip to content

Commit 126c9fd

Browse files
committed
Merge branch 'rama/qwen-gqa' into rama/qwen-gqa2
2 parents a16e700 + 7ee86ce commit 126c9fd

File tree

1 file changed

+36
-5
lines changed
  • onnxscript/rewriter/ort_fusions

1 file changed

+36
-5
lines changed

onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,13 @@ def pattern(
163163
):
164164
# Reshape query from (B, S, D) to (B, S, H, D/H)
165165
query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"])
166+
# Qwen variant uses normalization of query/key before rotary embedding:
167+
# The normalization can happen before (eg., Qwen) or after the Transpose (eg., Gemma).
168+
query_BSHDh_normalized = op.SimplifiedLayerNormalization(
169+
query_BSHDh, pattern.ANY_VALUE, axis=-1, _outputs=["query_BSHDh_normalized"]
170+
)
171+
query_BSHDh = pattern.OrValue([query_BSHDh, query_BSHDh_normalized])
172+
166173
# Transpose from (B, S, H, D/H) to (B, H, S, D/H)
167174
query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
168175

@@ -174,6 +181,11 @@ def pattern(
174181

175182
# Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H)
176183
key_BSHkvDh = op.Reshape(key_BSDkv, pattern.ANY_VALUE, _outputs=["key_BSHkvDh"])
184+
key_BSHkvDh_normalized = op.SimplifiedLayerNormalization(
185+
key_BSHkvDh, pattern.ANY_VALUE, axis=-1, _outputs=["key_BSHkvDh_normalized"]
186+
)
187+
key_BSHkvDh = pattern.OrValue([key_BSHkvDh, key_BSHkvDh_normalized])
188+
177189
# Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
178190
key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
179191

@@ -258,8 +270,23 @@ def check(
258270
query_BSHDh,
259271
key_BSHkvDh,
260272
mask,
273+
query_BSHDh_normalized=None,
274+
query_BHSDh_normalized=None,
275+
key_BSHkvDh_normalized=None,
276+
key_BHkvSDh_normalized=None,
261277
**_,
262278
):
279+
result = pattern.MatchResult()
280+
if query_BSHDh_normalized is not None and query_BHSDh_normalized is not None:
281+
return result.fail(
282+
"Query normalized twice",
283+
[query_BSHDh_normalized, query_BHSDh_normalized],
284+
)
285+
if key_BSHkvDh_normalized is not None and key_BHkvSDh_normalized is not None:
286+
return result.fail(
287+
"Key normalized twice",
288+
[key_BSHkvDh_normalized, key_BHkvSDh_normalized],
289+
)
263290
bindings: dict[str, Dim] = {}
264291

265292
def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
@@ -282,7 +309,7 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
282309
# and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]:
283310
# or check Reshape's shape-input value
284311

285-
result = pattern.MatchResult()
312+
286313
num_heads = _ir_utils.get_dim(query_BSHDh, 2)
287314
kv_num_heads = _ir_utils.get_dim(key_BSHkvDh, 2)
288315
if not isinstance(num_heads, int):
@@ -334,7 +361,9 @@ def rewrite(
334361
mask,
335362
query_BSHDh,
336363
key_BSHkvDh,
364+
query_BSHDh_normalized=None,
337365
query_BHSDh_normalized=None,
366+
key_BSHkvDh_normalized=None,
338367
key_BHkvSDh_normalized=None,
339368
**_,
340369
):
@@ -356,9 +385,10 @@ def rewrite(
356385
max_seq_length = op.ReduceMax(seqlens_k, zero_int64_1d, keepdims=0)
357386
total_seq_length_int32 = op.Add(max_seq_length, one_int32_0d)
358387

359-
if query_BHSDh_normalized is not None:
388+
normalized_query = query_BHSDh_normalized or query_BSHDh_normalized
389+
if normalized_query is not None:
360390
# We apply normalization without the transpose, which is fused into GQA
361-
norm_node = query_BHSDh_normalized.producer()
391+
norm_node = normalized_query.producer()
362392
norm_attrs = norm_node.attributes
363393
norm_scale = norm_node.inputs[1]
364394
query_BSHDh_normalized = op.SimplifiedLayerNormalization(
@@ -367,9 +397,10 @@ def rewrite(
367397
reshape_BSHDh_to_BSD = op.Constant(value_ints=[0, 0, -1])
368398
query_BSD = op.Reshape(query_BSHDh_normalized, reshape_BSHDh_to_BSD)
369399

370-
if key_BHkvSDh_normalized is not None:
400+
normalized_key = key_BHkvSDh_normalized or key_BSHkvDh_normalized
401+
if normalized_key is not None:
371402
# We apply normalization without the transpose, which is fused into GQA
372-
norm_node = key_BHkvSDh_normalized.producer()
403+
norm_node = normalized_key.producer()
373404
norm_attrs = norm_node.attributes
374405
norm_scale = norm_node.inputs[1]
375406
key_BSHkvDh_normalized = op.SimplifiedLayerNormalization(

0 commit comments

Comments
 (0)