Skip to content

Commit a16e700

Browse files
committed
Add support for GQA without past
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 45b5189 commit a16e700

File tree

1 file changed

+6
-2
lines changed
  • onnxscript/rewriter/ort_fusions

1 file changed

+6
-2
lines changed

onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def pattern(
209209
# that share key/value.
210210

211211
key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
212+
# Concat with past_key is optional:
213+
key_seq_BHkvTDh = pattern.OrValue([key_seq_BHkvTDh, key_BHkvSDh_rope])
212214
key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2)
213215
key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE)
214216
key_seq_BHTDh = op.Reshape(
@@ -218,6 +220,8 @@ def pattern(
218220
# Concatenate past_value cache and current value, expand across heads
219221
# that share key/value.
220222
value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
223+
# Concat with past_value is optional:
224+
value_seq_BHkvTDh = pattern.OrValue([value_seq_BHkvTDh, value_BHkvSDh])
221225
value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2)
222226
value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE)
223227
value_seq_BHTDh = op.Reshape(
@@ -268,9 +272,9 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
268272
if no_match(value_BSDkv, ["B", "S", "Dkv"]):
269273
return False
270274

271-
if no_match(past_key, ["B", "Hkv", "P", "Dh"]):
275+
if past_key is not None and no_match(past_key, ["B", "Hkv", "P", "Dh"]):
272276
return False
273-
if no_match(past_value, ["B", "Hkv", "P", "Dv"]):
277+
if past_value is not None and no_match(past_value, ["B", "Hkv", "P", "Dv"]):
274278
return False
275279

276280
# TODO: verify Reshapes:

0 commit comments

Comments
 (0)