Skip to content

Commit cfc0a47

Browse files
committed
Minor fixes
Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 1135a8e commit cfc0a47

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

onnxscript/rewriter/ort_fusions/gqa_test.py

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,12 @@
55
import math
66
import unittest
77

8-
import parameterized
98
import numpy as np
109
import onnx
1110
import onnx_ir as ir
1211
import onnx_ir.passes.common.shape_inference as shape_inference
1312
import onnxruntime as ort
13+
import parameterized
1414
import torch
1515

1616
import onnxscript
@@ -362,14 +362,23 @@ def test_fusion(self):
362362
assert_allclose(outputs3, source_model_outputs)
363363

364364

365+
@parameterized.parameterized_class([
366+
{"with_past": True, "transpose_first": True},
367+
{"with_past": True, "transpose_first": False},
368+
{"with_past": False, "transpose_first": True},
369+
{"with_past": False, "transpose_first": False},
370+
])
365371
class GemmaGQAFusionTest(unittest.TestCase):
372+
with_past = True
373+
transpose_first = True
366374
def __init__(self, *args, **kwargs):
367375
super().__init__(*args, **kwargs)
376+
368377
# Config parameters
369378
self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1?
370379
self.seqlen = 8
371380
self.kv_seqlen = self.seqlen
372-
self.past_seqlen = 16
381+
self.past_seqlen = 16 if self.with_past else 0
373382
self.head_size = 16
374383
self.num_heads = 20
375384
self.kv_num_heads = 10
@@ -425,7 +434,9 @@ def __init__(self, *args, **kwargs):
425434
"key_scale": np.random.rand(Dh).astype(np.float32),
426435
}
427436

428-
def source_model_script(self, with_past: bool, transpose_first: bool):
437+
def source_model_script(self):
438+
with_past = self.with_past
439+
transpose_first = self.transpose_first
429440
scale_factor = math.sqrt(math.sqrt(self.head_size))
430441
minval = torch.finfo(torch.float32).min
431442
minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval])
@@ -567,17 +578,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
567578

568579
return gqa
569580

570-
@parameterized.parameterized.expand([
571-
(True, True), # with_past=True, transpose_first=True
572-
(True, False), # with_past=True, transpose_first=False
573-
(False, True), # with_past=False, transpose_first=True
574-
(False, False), # with_past=False, transpose_first=False
575-
])
576-
def test_fusion(self, with_past, transpose_first):
581+
def test_fusion(self):
577582
"""Test that GQA fusion is successful on source model and produces an equivalent model."""
578583
inputs = self.inputs
579584

580-
source_model = self.source_model_script(with_past, transpose_first).to_model_proto(
585+
source_model = self.source_model_script().to_model_proto(
581586
input_types=self.input_types,
582587
output_types=self.output_types,
583588
)

0 commit comments

Comments (0)