Skip to content

Commit 7bcbfe3

Browse files
authored
Add attention operator and adapter for onert (#400)
* Fuse LlamaAttention into attention (onert). It fuses LlamaAttention from the TinyLlama model. The fused attention works as the onert attention op. TICO-DCO-1.0-Signed-off-by: Sanggyu Lee <sg5.lee@samsung.com>
1 parent 2828fa5 commit 7bcbfe3

File tree

5 files changed

+165
-0
lines changed

5 files changed

+165
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# DO NOT REMOVE THIS FILE
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Dict, List, TYPE_CHECKING
16+
17+
import torch
18+
19+
from transformers.cache_utils import DynamicCache
20+
from transformers.models.llama.modeling_llama import LlamaAttention
21+
22+
23+
def llama_attention_forward_adapter(
    self: LlamaAttention,
    hidden_states: torch.Tensor,
    position_embeddings: List[torch.Tensor],
    attention_mask: torch.Tensor,
    past_key_value: DynamicCache,
    cache_position: torch.Tensor,
    **kwargs,
):
    """Replacement ``forward`` for ``LlamaAttention`` that routes the whole
    layer through the ``circle_custom.attention`` custom op.

    ``past_key_value`` is a ``DynamicCache`` holding ``key_cache`` and
    ``value_cache`` lists. It is decomposed into plain per-layer tensors here
    because tico and circle do not know dict-like containers.
    """
    # Rotary embedding tables come packed as [cos, sin].
    cos = position_embeddings[0]
    sin = position_embeddings[1]

    # Per-layer KV-cache tensors for this attention layer.
    layer_key = past_key_value.key_cache[self.layer_idx]  # type: ignore[union-attr]
    layer_value = past_key_value.value_cache[self.layer_idx]  # type: ignore[union-attr]

    attn_output = torch.ops.circle_custom.attention(
        hidden_states,
        self.q_proj.weight,
        self.k_proj.weight,
        self.v_proj.weight,
        self.o_proj.weight,
        cos,
        sin,
        attention_mask,
        layer_key,
        layer_value,
        cache_position,
    )
    # HF attention modules return (attn_output, attn_weights); the fused op
    # does not expose attention weights, so None is returned in that slot.
    return attn_output, None
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Dict, List, TYPE_CHECKING
16+
17+
if TYPE_CHECKING:
18+
import torch._ops
19+
import torch.fx
20+
import torch
21+
from circle_schema import circle
22+
23+
from tico.serialize.circle_graph import CircleSubgraph
24+
from tico.serialize.operators.hashable_opcode import OpCode
25+
from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26+
from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27+
from tico.utils.validate_args_kwargs import CircleAttentionArgs
28+
29+
30+
@register_node_visitor
class AttentionVisitor(NodeVisitor):
    """Serializes ``circle_custom.attention`` fx nodes as circle ATTENTION operators."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.circle_custom.attention.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        # Validate the node's args/kwargs against the expected op signature;
        # the constructed dataclass itself is not needed afterwards.
        CircleAttentionArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        opcode_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes
        )
        operator = create_builtin_operator(
            self.graph, opcode_index, node.args, [node]
        )

        # Op-specific option: ATTENTION carries only the empty options struct.
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.AttentionOptions
        )
        operator.builtinOptions = circle.AttentionOptions.AttentionOptionsT()

        return operator

tico/utils/register_custom_op.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,40 @@ def _(
727727
return hidden_states.new_empty(hidden_states.size())
728728

729729

730+
def CircleAttention():
731+
@custom_op("circle_custom::attention", mutates_args=())
732+
def attention(
733+
hidden_states: torch.Tensor,
734+
wq: torch.Tensor,
735+
wk: torch.Tensor,
736+
wv: torch.Tensor,
737+
wo: torch.Tensor,
738+
position_cos: torch.Tensor,
739+
position_sin: torch.Tensor,
740+
attention_mask: torch.Tensor,
741+
past_key: torch.Tensor,
742+
past_value: torch.Tensor,
743+
cache_position: torch.Tensor,
744+
) -> torch.Tensor:
745+
return None
746+
747+
@register_fake("circle_custom::attention")
748+
def _(
749+
hidden_states: torch.Tensor,
750+
wq: torch.Tensor,
751+
wk: torch.Tensor,
752+
wv: torch.Tensor,
753+
wo: torch.Tensor,
754+
position_cos: torch.Tensor,
755+
position_sin: torch.Tensor,
756+
attention_mask: torch.Tensor,
757+
past_key: torch.Tensor,
758+
past_value: torch.Tensor,
759+
cache_position: torch.Tensor,
760+
) -> torch.Tensor:
761+
return hidden_states
762+
763+
730764
# Add custom ops to the torch namespace
731765
def RegisterOps():
732766
CircleResizeNearestNeighbor()
@@ -740,3 +774,4 @@ def RegisterOps():
740774
CircleInstanceNorm()
741775
CircleQuantizeMX()
742776
CircleRMSNorm()
777+
CircleAttention()

tico/utils/validate_args_kwargs.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,26 @@ class CatArgs:
171171
dim: int = 0
172172

173173

174+
@enforce_type
@dataclass
class CircleAttentionArgs:
    """
    For circle.BuiltinOperator.BuiltinOperator.ATTENTION

    Operand layout of the ``circle_custom::attention`` custom op; each field
    is the fx node feeding the operand at that position.
    """

    hidden_states: torch.fx.Node  # input hidden states
    wq: torch.fx.Node  # query projection weight
    wk: torch.fx.Node  # key projection weight
    wv: torch.fx.Node  # value projection weight
    wo: torch.fx.Node  # output projection weight
    position_cos: torch.fx.Node  # rotary embedding cos table
    position_sin: torch.fx.Node  # rotary embedding sin table
    attention_mask: torch.fx.Node
    past_key: torch.fx.Node  # cached keys for this layer
    past_value: torch.fx.Node  # cached values for this layer
    cache_position: torch.fx.Node
193+
174194
@enforce_type
175195
@dataclass
176196
class CircleRMSNormArgs:

0 commit comments

Comments
 (0)