22# Licensed under the MIT License.
33
44import unittest
5- import onnx
6- from packaging import version
75
6+ import onnx
87import onnx_ir as ir
8+ from packaging import version
99
1010import onnxscript
1111import onnxscript .optimizer
12- from onnxscript import FLOAT , script
1312import onnxscript .rewriter .testing
13+ from onnxscript import FLOAT , script
1414from onnxscript .rewriter .rules .fusion ._gqa import fuse_gqa
1515
1616op = onnxscript .values .Opset ("" , 23 )
2020D = [64 ] # Head size
2121G = [2 ] # Number of groups
2222
23+
2324@script (ir_version = 10 )
2425def _gqa_script (
2526 query_BHSD : FLOAT [2 , 8 , 4 , 64 ], # B=2, H=8, S=4, D=64
@@ -29,38 +30,42 @@ def _gqa_script(
2930 past_value_BHkvPD : FLOAT [2 , 4 , 8 , 64 ], # B=2, Hkv=4, P=8, D=64
3031) -> FLOAT [2 , 8 , 4 , 64 ]:
3132 """Basic GQA pattern that should be fused into an Attention op."""
32-
33+
3334 # Concatenate past_key cache and current key
3435 present_key_BHkvStD = op .Concat (past_key_BHkvPD , key_BHkvSD , axis = - 2 ) # [B, Hkv, S+P, D]
35-
36- # Unsqueeze to add group dimension
36+
37+ # Unsqueeze to add group dimension
3738 present_key_BHkv1StD = op .Unsqueeze (present_key_BHkvStD , 2 ) # [B, Hkv, 1, S+P, D]
3839
3940 # Calculate shapes dynamically
4041 B = op .Shape (query_BHSD , start = 0 , end = 1 ) # [B]
4142 T = op .Shape (present_key_BHkvStD , start = 2 , end = 3 ) # [S+P]
42-
43+
4344 # Create expand shape [B, Hkv, G, S+P, D]
4445 expand_shape = op .Concat (B , Hkv , G , T , D , axis = 0 )
4546 present_key_BHkvGStD = op .Expand (present_key_BHkv1StD , expand_shape ) # [B, Hkv, G, S+P, D]
46-
47- # Create reshape shape [B, H, S+P, D]
47+
48+ # Create reshape shape [B, H, S+P, D]
4849 reshape_shape = op .Concat (B , H , T , D , axis = 0 )
4950 present_key_BHStD = op .Reshape (present_key_BHkvGStD , reshape_shape ) # [B, H, S+P, D]
50-
51+
5152 # Same for value
52- present_value_BHkvStD = op .Concat (past_value_BHkvPD , value_BHkvSD , axis = - 2 ) # [B, Hkv, S+P, D]
53+ present_value_BHkvStD = op .Concat (
54+ past_value_BHkvPD , value_BHkvSD , axis = - 2
55+ ) # [B, Hkv, S+P, D]
5356 present_value_BHkv1StD = op .Unsqueeze (present_value_BHkvStD , 2 ) # [B, Hkv, 1, S+P, D]
54- present_value_BHkvGStD = op .Expand (present_value_BHkv1StD , expand_shape ) # [B, Hkv, G, S+P, D]
57+ present_value_BHkvGStD = op .Expand (
58+ present_value_BHkv1StD , expand_shape
59+ ) # [B, Hkv, G, S+P, D]
5560 present_value_BHStD = op .Reshape (present_value_BHkvGStD , reshape_shape ) # [B, H, S+P, D]
56-
61+
5762 # Attention computation
5863 attention_BHSDh = op .Attention (
5964 query_BHSD ,
6065 present_key_BHStD ,
6166 present_value_BHStD ,
6267 )
63-
68+
6469 return attention_BHSDh
6570
6671
0 commit comments