55import math
66import unittest
77
8+ import parameterized
89import numpy as np
910import onnx
1011import onnx_ir as ir
@@ -424,7 +425,7 @@ def __init__(self, *args, **kwargs):
424425 "key_scale" : np .random .rand (Dh ).astype (np .float32 ),
425426 }
426427
427- def source_model_script (self ):
428+ def source_model_script (self , with_past : bool , transpose_first : bool ):
428429 scale_factor = math .sqrt (math .sqrt (self .head_size ))
429430 minval = torch .finfo (torch .float32 ).min
430431 minval_tp = onnx .helper .make_tensor ("minval" , onnx .TensorProto .FLOAT , [1 ], [minval ])
@@ -458,16 +459,26 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
458459 # We convert them into BHSDh (i.e., BHSd) format. In this version, we have only
459460 # one sequence length (S) for all Q, K, and V (with no cache).
460461 query_BSHDh = op .Reshape (query , shape_BSHDh )
461- query_BHSDh = op .Transpose (query_BSHDh , perm = [0 , 2 , 1 , 3 ])
462- query_BHSDh_normalized = op .SimplifiedLayerNormalization (
463- query_BHSDh , query_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
464- )
465-
466462 key_BSHkvDh = op .Reshape (key , shape_BSHkvDh )
467- key_BHkvSDh = op .Transpose (key_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
468- key_BHkvSDh_normalized = op .SimplifiedLayerNormalization (
469- key_BHkvSDh , key_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
470- )
463+
464+ if transpose_first :
465+ query_BHSDh = op .Transpose (query_BSHDh , perm = [0 , 2 , 1 , 3 ])
466+ query_BHSDh_normalized = op .SimplifiedLayerNormalization (
467+ query_BHSDh , query_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
468+ )
469+ key_BHkvSDh = op .Transpose (key_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
470+ key_BHkvSDh_normalized = op .SimplifiedLayerNormalization (
471+ key_BHkvSDh , key_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
472+ )
473+ else :
474+ query_BSHDh_normalized = op .SimplifiedLayerNormalization (
475+ query_BSHDh , query_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
476+ )
477+ query_BHSDh_normalized = op .Transpose (query_BSHDh_normalized , perm = [0 , 2 , 1 , 3 ])
478+ key_BSHkvDh_normalized = op .SimplifiedLayerNormalization (
479+ key_BSHkvDh , key_scale , axis = - 1 , epsilon = 1e-06 , stash_type = 1
480+ )
481+ key_BHkvSDh_normalized = op .Transpose (key_BSHkvDh_normalized , perm = [0 , 2 , 1 , 3 ])
471482
472483 value_BSHkvDh = op .Reshape (value , shape_BSHkvDh )
473484 value_BHkvSDh = op .Transpose (value_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
@@ -489,9 +500,13 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
489500 cos ,
490501 sin ,
491502 )
492- key_seq_BHkvSkvDh = op .Concat (past_key , key_BHkvSDh_rope , axis = - 2 )
493503
494- value_seq_BHkvSkvDh = op .Concat (past_value , value_BHkvSDh , axis = - 2 )
504+ if with_past :
505+ key_seq_BHkvSkvDh = op .Concat (past_key , key_BHkvSDh_rope , axis = - 2 )
506+ value_seq_BHkvSkvDh = op .Concat (past_value , value_BHkvSDh , axis = - 2 )
507+ else :
508+ key_seq_BHkvSkvDh = key_BHkvSDh_rope
509+ value_seq_BHkvSkvDh = value_BHkvSDh
495510
496511 # Now, expand from shared heads to all heads
497512 key_BHkv1SDh = op .Unsqueeze (key_seq_BHkvSkvDh , 2 )
@@ -552,11 +567,17 @@ def gqa(query, key, value, past_key, past_value, cos, sin, query_scale, key_scal
552567
553568 return gqa
554569
555- def test_fusion (self ):
570+ @parameterized .parameterized .expand ([
571+ (True , True ), # with_past=True, transpose_first=True
572+ (True , False ), # with_past=True, transpose_first=False
573+ (False , True ), # with_past=False, transpose_first=True
574+ (False , False ), # with_past=False, transpose_first=False
575+ ])
576+ def test_fusion (self , with_past , transpose_first ):
556577 """Test that GQA fusion is successful on source model and produces an equivalent model."""
557578 inputs = self .inputs
558579
559- source_model = self .source_model_script ().to_model_proto (
580+ source_model = self .source_model_script (with_past , transpose_first ).to_model_proto (
560581 input_types = self .input_types ,
561582 output_types = self .output_types ,
562583 )
0 commit comments