Skip to content

Commit 0331251

Browse files
committed
Run lint
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent 738c869 commit 0331251

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

onnxscript/rewriter/rules/fusion/_gqa_test.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
# Licensed under the MIT License.
33

44
import unittest
5-
import onnx
6-
from packaging import version
75

6+
import onnx
87
import onnx_ir as ir
8+
from packaging import version
99

1010
import onnxscript
1111
import onnxscript.optimizer
12-
from onnxscript import FLOAT, script
1312
import onnxscript.rewriter.testing
13+
from onnxscript import FLOAT, script
1414
from onnxscript.rewriter.rules.fusion._gqa import fuse_gqa
1515

1616
op = onnxscript.values.Opset("", 23)
@@ -20,6 +20,7 @@
2020
D = [64] # Head size
2121
G = [2] # Number of groups
2222

23+
2324
@script(ir_version=10)
2425
def _gqa_script(
2526
query_BHSD: FLOAT[2, 8, 4, 64], # B=2, H=8, S=4, D=64
@@ -29,38 +30,42 @@ def _gqa_script(
2930
past_value_BHkvPD: FLOAT[2, 4, 8, 64], # B=2, Hkv=4, P=8, D=64
3031
) -> FLOAT[2, 8, 4, 64]:
3132
"""Basic GQA pattern that should be fused into an Attention op."""
32-
33+
3334
# Concatenate past_key cache and current key
3435
present_key_BHkvStD = op.Concat(past_key_BHkvPD, key_BHkvSD, axis=-2) # [B, Hkv, S+P, D]
35-
36-
# Unsqueeze to add group dimension
36+
37+
# Unsqueeze to add group dimension
3738
present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2) # [B, Hkv, 1, S+P, D]
3839

3940
# Calculate shapes dynamically
4041
B = op.Shape(query_BHSD, start=0, end=1) # [B]
4142
T = op.Shape(present_key_BHkvStD, start=2, end=3) # [S+P]
42-
43+
4344
# Create expand shape [B, Hkv, G, S+P, D]
4445
expand_shape = op.Concat(B, Hkv, G, T, D, axis=0)
4546
present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, expand_shape) # [B, Hkv, G, S+P, D]
46-
47-
# Create reshape shape [B, H, S+P, D]
47+
48+
# Create reshape shape [B, H, S+P, D]
4849
reshape_shape = op.Concat(B, H, T, D, axis=0)
4950
present_key_BHStD = op.Reshape(present_key_BHkvGStD, reshape_shape) # [B, H, S+P, D]
50-
51+
5152
# Same for value
52-
present_value_BHkvStD = op.Concat(past_value_BHkvPD, value_BHkvSD, axis=-2) # [B, Hkv, S+P, D]
53+
present_value_BHkvStD = op.Concat(
54+
past_value_BHkvPD, value_BHkvSD, axis=-2
55+
) # [B, Hkv, S+P, D]
5356
present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2) # [B, Hkv, 1, S+P, D]
54-
present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, expand_shape) # [B, Hkv, G, S+P, D]
57+
present_value_BHkvGStD = op.Expand(
58+
present_value_BHkv1StD, expand_shape
59+
) # [B, Hkv, G, S+P, D]
5560
present_value_BHStD = op.Reshape(present_value_BHkvGStD, reshape_shape) # [B, H, S+P, D]
56-
61+
5762
# Attention computation
5863
attention_BHSDh = op.Attention(
5964
query_BHSD,
6065
present_key_BHStD,
6166
present_value_BHStD,
6267
)
63-
68+
6469
return attention_BHSDh
6570

6671

0 commit comments

Comments
 (0)