-
Notifications
You must be signed in to change notification settings - Fork 107
Expand file tree
/
Copy path: _gqa.py
More file actions
114 lines (96 loc) · 3.78 KB
/
_gqa.py
File metadata and controls
114 lines (96 loc) · 3.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations
from typing import Union
import onnx_ir as ir
import onnxscript.rewriter._fusion_utils as _fusion_utils
from onnxscript.rewriter import _basics, pattern
Dim = Union[int, ir.SymbolicDim]
class OnnxGroupQueryAttention(pattern.RewriteRuleClassBase):
def __init__(self):
super().__init__("ONNXGQA", remove_nodes=False)
def pattern(
self,
op,
query_BHSD,
key_BHkvSD,
value_BHkvSD,
past_key_BHkvSpD,
past_value_BHkvSpD,
):
# Concatenate past_key cache and current key, expand across heads
# that share key/value.
present_key_BHkvStD = op.Concat(past_key_BHkvSpD, key_BHkvSD, axis=-2)
present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2)
present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, pattern.ANY_VALUE)
present_key_BHStD = op.Reshape(
present_key_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_key_BHStD"]
)
# Concatenate past_value cache and current value, expand across heads
# that share key/value.
present_value_BHkvStD = op.Concat(past_value_BHkvSpD, value_BHkvSD, axis=-2)
present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2)
present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, pattern.ANY_VALUE)
present_value_BHStD = op.Reshape(
present_value_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_value_BHStD"]
)
attention_BHSDh = op.Attention(
query_BHSD,
present_key_BHStD,
present_value_BHStD,
pattern.Var("mask", can_match_none=True),
_outputs=["attention_BHSDh"],
)
return attention_BHSDh, present_key_BHkvStD, present_value_BHkvStD
def check(
self,
context: _basics.MatchContext,
query_BHSD,
key_BHkvSD,
value_BHkvSD,
past_key_BHkvSpD,
past_value_BHkvSpD,
present_key_BHStD,
present_value_BHStD,
**_,
):
bindings: dict[str, Dim] = {}
# Check that inputs to new Attention node have expected shapes
_fusion_utils.check_shape(bindings, query_BHSD, ["B", "H", "S", "D"])
_fusion_utils.check_shape(bindings, key_BHkvSD, ["B", "Hkv", "S", "D"])
_fusion_utils.check_shape(bindings, value_BHkvSD, ["B", "Hkv", "S", "D"])
_fusion_utils.check_shape(bindings, past_key_BHkvSpD, ["B", "Hkv", "P", "D"])
_fusion_utils.check_shape(bindings, past_value_BHkvSpD, ["B", "Hkv", "P", "D"])
# We need to check that the Expand/Reshape arguments are as expected.
# As a substitute, we check that the outputs of Expand=>Reshape have expected shapes.
# TODO (rama): May be better to check the actual Expand/Reshape arguments.
_fusion_utils.check_shape(bindings, present_key_BHStD, ["B", "H", "S+P", "D"])
_fusion_utils.check_shape(bindings, present_value_BHStD, ["B", "H", "S+P", "D"])
return True
def rewrite(
self,
op,
query_BHSD,
key_BHkvSD,
value_BHkvSD,
past_key_BHkvSpD,
past_value_BHkvSpD,
mask,
attention_BHSDh,
**_,
):
original_attention_node = attention_BHSDh.producer()
original_attrs = original_attention_node.attributes
return op.Attention(
query_BHSD,
key_BHkvSD,
value_BHkvSD,
mask,
past_key_BHkvSpD,
past_value_BHkvSpD,
**original_attrs,
_outputs=3,
)
_basic_gqa_rule = OnnxGroupQueryAttention.rule()
gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])
fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)