File tree Expand file tree Collapse file tree 1 file changed +7
-5
lines changed
tico/serialize/operators/onert Expand file tree Collapse file tree 1 file changed +7
-5
lines changed Original file line number Diff line number Diff line change 4141 Tensor attention_mask,
4242 Tensor past_key,
4343 Tensor past_value,
44- int layer_idx ,
45- Tensor cache_position
44+ Tensor cache_position ,
45+ int layer_idx
4646) -> Tensor
4747"""
4848)
@@ -63,8 +63,8 @@ def attention_llama(*args, **kwargs):
6363 attention_mask ,
6464 past_key ,
6565 past_value ,
66- layer_idx ,
6766 cache_position ,
67+ layer_idx ,
6868 ) = args
6969 return hidden_states
7070
@@ -104,8 +104,8 @@ def llama_attention_forward_adapter(
104104 # k_cache[0] | k_cache[1] | ... | k_cache[n]
105105 key_cache [self .layer_idx ],
106106 value_cache [self .layer_idx ], # Same to value_cache
107- self .layer_idx ,
108107 cache_position ,
108+ self .layer_idx ,
109109 ),
110110 None ,
111111 )
@@ -143,7 +143,9 @@ def define_node(
143143 circle .BuiltinOperator .BuiltinOperator .ATTENTION , self ._op_codes
144144 )
145145
146- inputs = node .args
146+ # remove last arg (= layer_idx) from inputs.
147+ # layer_idx is attention op's param, not input.
148+ inputs = node .args [:- 1 ]
147149 outputs = [node ]
148150 operator = create_builtin_operator (self .graph , op_index , inputs , outputs )
149151
You can’t perform that action at this time.
0 commit comments