Skip to content

Commit 1135a8e

Browse files
committed
Add test case for recent update to GQA fusion
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 3846705 commit 1135a8e

File tree

1 file changed

+35
-14
lines changed

1 file changed

+35
-14
lines changed

onnxscript/rewriter/ort_fusions/gqa_test.py

Lines changed: 35 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import math
66
import unittest
77

8+
import parameterized
89
import numpy as np
910
import onnx
1011
import onnx_ir as ir
@@ -424,7 +425,7 @@ def __init__(self, *args, **kwargs):
424425
"key_scale": np.random.rand(Dh).astype(np.float32),
425426
}
426427

427-
def source_model_script(self):
428+
def source_model_script(self, with_past: bool, transpose_first: bool):
428429
scale_factor = math.sqrt(math.sqrt(self.head_size))
429430
minval = torch.finfo(torch.float32).min
430431
minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval])
@@ -458,16 +459,26 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
458459
# We convert them into BHSDh (i.e., BHSd) format. In this version, we have only
459460
# one sequence length (S) for all Q, K, and V (with no cache).
460461
query_BSHDh = op.Reshape(query, shape_BSHDh)
461-
query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
462-
query_BHSDh_normalized = op.SimplifiedLayerNormalization(
463-
query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
464-
)
465-
466462
key_BSHkvDh = op.Reshape(key, shape_BSHkvDh)
467-
key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
468-
key_BHkvSDh_normalized = op.SimplifiedLayerNormalization(
469-
key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
470-
)
463+
464+
if transpose_first:
465+
query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
466+
query_BHSDh_normalized = op.SimplifiedLayerNormalization(
467+
query_BHSDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
468+
)
469+
key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
470+
key_BHkvSDh_normalized = op.SimplifiedLayerNormalization(
471+
key_BHkvSDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
472+
)
473+
else:
474+
query_BSHDh_normalized = op.SimplifiedLayerNormalization(
475+
query_BSHDh, query_scale, axis=-1, epsilon=1e-06, stash_type=1
476+
)
477+
query_BHSDh_normalized = op.Transpose(query_BSHDh_normalized, perm=[0, 2, 1, 3])
478+
key_BSHkvDh_normalized = op.SimplifiedLayerNormalization(
479+
key_BSHkvDh, key_scale, axis=-1, epsilon=1e-06, stash_type=1
480+
)
481+
key_BHkvSDh_normalized = op.Transpose(key_BSHkvDh_normalized, perm=[0, 2, 1, 3])
471482

472483
value_BSHkvDh = op.Reshape(value, shape_BSHkvDh)
473484
value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])
@@ -489,9 +500,13 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
489500
cos,
490501
sin,
491502
)
492-
key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
493503

494-
value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
504+
if with_past:
505+
key_seq_BHkvSkvDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
506+
value_seq_BHkvSkvDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
507+
else:
508+
key_seq_BHkvSkvDh = key_BHkvSDh_rope
509+
value_seq_BHkvSkvDh = value_BHkvSDh
495510

496511
# Now, expand from shared heads to all heads
497512
key_BHkv1SDh = op.Unsqueeze(key_seq_BHkvSkvDh, 2)
@@ -552,11 +567,17 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
552567

553568
return gqa
554569

555-
def test_fusion(self):
570+
@parameterized.parameterized.expand([
571+
(True, True), # with_past=True, transpose_first=True
572+
(True, False), # with_past=True, transpose_first=False
573+
(False, True), # with_past=False, transpose_first=True
574+
(False, False), # with_past=False, transpose_first=False
575+
])
576+
def test_fusion(self, with_past, transpose_first):
556577
"""Test that GQA fusion is successful on source model and produces an equivalent model."""
557578
inputs = self.inputs
558579

559-
source_model = self.source_model_script().to_model_proto(
580+
source_model = self.source_model_script(with_past, transpose_first).to_model_proto(
560581
input_types=self.input_types,
561582
output_types=self.output_types,
562583
)

0 commit comments

Comments
 (0)