|
5 | 5 | import math |
6 | 6 | import unittest |
7 | 7 |
|
8 | | -import parameterized |
9 | 8 | import numpy as np |
10 | 9 | import onnx |
11 | 10 | import onnx_ir as ir |
12 | 11 | import onnx_ir.passes.common.shape_inference as shape_inference |
13 | 12 | import onnxruntime as ort |
| 13 | +import parameterized |
14 | 14 | import torch |
15 | 15 |
|
16 | 16 | import onnxscript |
@@ -362,14 +362,23 @@ def test_fusion(self): |
362 | 362 | assert_allclose(outputs3, source_model_outputs) |
363 | 363 |
|
364 | 364 |
|
| 365 | +@parameterized.parameterized_class([ |
| 366 | + {"with_past": True, "transpose_first": True}, |
| 367 | + {"with_past": True, "transpose_first": False}, |
| 368 | + {"with_past": False, "transpose_first": True}, |
| 369 | + {"with_past": False, "transpose_first": False}, |
| 370 | +]) |
365 | 371 | class GemmaGQAFusionTest(unittest.TestCase): |
| 372 | + with_past = True |
| 373 | + transpose_first = True |
366 | 374 | def __init__(self, *args, **kwargs): |
367 | 375 | super().__init__(*args, **kwargs) |
| 376 | + |
368 | 377 | # Config parameters |
369 | 378 | self.batchsize = 1 # Note: GQA (cpu) seems to require batch-size 1? |
370 | 379 | self.seqlen = 8 |
371 | 380 | self.kv_seqlen = self.seqlen |
372 | | - self.past_seqlen = 16 |
| 381 | + self.past_seqlen = 16 if self.with_past else 0 |
373 | 382 | self.head_size = 16 |
374 | 383 | self.num_heads = 20 |
375 | 384 | self.kv_num_heads = 10 |
@@ -425,7 +434,9 @@ def __init__(self, *args, **kwargs): |
425 | 434 | "key_scale": np.random.rand(Dh).astype(np.float32), |
426 | 435 | } |
427 | 436 |
|
428 | | - def source_model_script(self, with_past: bool, transpose_first: bool): |
| 437 | + def source_model_script(self): |
| 438 | + with_past = self.with_past |
| 439 | + transpose_first = self.transpose_first |
429 | 440 | scale_factor = math.sqrt(math.sqrt(self.head_size)) |
430 | 441 | minval = torch.finfo(torch.float32).min |
431 | 442 | minval_tp = onnx.helper.make_tensor("minval", onnx.TensorProto.FLOAT, [1], [minval]) |
@@ -567,17 +578,11 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal |
567 | 578 |
|
568 | 579 | return gqa |
569 | 580 |
|
570 | | - @parameterized.parameterized.expand([ |
571 | | - (True, True), # with_past=True, transpose_first=True |
572 | | - (True, False), # with_past=True, transpose_first=False |
573 | | - (False, True), # with_past=False, transpose_first=True |
574 | | - (False, False), # with_past=False, transpose_first=False |
575 | | - ]) |
576 | | - def test_fusion(self, with_past, transpose_first): |
| 581 | + def test_fusion(self): |
577 | 582 | """Test that GQA fusion is successful on source model and produces an equivalent model.""" |
578 | 583 | inputs = self.inputs |
579 | 584 |
|
580 | | - source_model = self.source_model_script(with_past, transpose_first).to_model_proto( |
| 585 | + source_model = self.source_model_script().to_model_proto( |
581 | 586 | input_types=self.input_types, |
582 | 587 | output_types=self.output_types, |
583 | 588 | ) |
|
0 commit comments