Skip to content

Commit 375c62f

Browse files
author
Sanggyu Lee
committed
Fix wrong arg order and move layer_idx from inputs to params
1 parent 5117bb8 commit 375c62f

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

tico/serialize/operators/onert/op_attention.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
Tensor attention_mask,
4242
Tensor past_key,
4343
Tensor past_value,
44-
int layer_idx,
45-
Tensor cache_position
44+
Tensor cache_position,
45+
int layer_idx
4646
) -> Tensor
4747
"""
4848
)
@@ -63,8 +63,8 @@ def attention_llama(*args, **kwargs):
6363
attention_mask,
6464
past_key,
6565
past_value,
66-
layer_idx,
6766
cache_position,
67+
layer_idx,
6868
) = args
6969
return hidden_states
7070

@@ -104,8 +104,8 @@ def llama_attention_forward_adapter(
104104
# k_cache[0] | k_cache[1] | ... | k_cache[n]
105105
key_cache[self.layer_idx],
106106
value_cache[self.layer_idx], # Same as value_cache
107-
self.layer_idx,
108107
cache_position,
108+
self.layer_idx,
109109
),
110110
None,
111111
)
@@ -143,7 +143,9 @@ def define_node(
143143
circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes
144144
)
145145

146-
inputs = node.args
146+
# Remove the last arg (= layer_idx) from inputs.
147+
# layer_idx is attention op's param, not input.
148+
inputs = node.args[:-1]
147149
outputs = [node]
148150
operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
149151

0 commit comments

Comments
 (0)