Skip to content

Commit ae0b0c3

Browse files
author
Sanggyu Lee
committed
Remove attention_mask and make kv_cache mandatory, not optional
1 parent 086d86c commit ae0b0c3

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

tico/serialize/operators/onert/op_attention.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
Tensor wo,
3939
Tensor position_cos,
4040
Tensor position_sin,
41-
Tensor? attention_mask,
4241
Tensor past_key,
4342
Tensor past_value,
4443
int layer_idx,
@@ -60,7 +59,6 @@ def attention_llama(*args, **kwargs):
6059
o_proj,
6160
position_cos,
6261
position_sin,
63-
attention_mask,
6462
past_key,
6563
past_value,
6664
layer_idx,
@@ -69,7 +67,7 @@ def attention_llama(*args, **kwargs):
6967
return hidden_states
7068

7169

72-
from typing import List, Optional
70+
from typing import List
7371

7472
from transformers.cache_utils import DynamicCache
7573
from transformers.models.llama.modeling_llama import LlamaAttention
@@ -79,8 +77,7 @@ def llama_attention_forward_adapter(
7977
self: LlamaAttention,
8078
hidden_states: torch.Tensor,
8179
position_embeddings: List[torch.Tensor],
82-
attention_mask: Optional[torch.Tensor],
83-
past_key_value: Optional[DynamicCache],
80+
past_key_value: DynamicCache,
8481
cache_position: torch.Tensor,
8582
**kwargs,
8683
):
@@ -97,13 +94,12 @@ def llama_attention_forward_adapter(
9794
self.o_proj.weight,
9895
position_embeddings[0], # cos
9996
position_embeddings[1], # sin
100-
attention_mask,
10197
# key_cache is a list of cache for each decoder layer.
10298
# Assumption: key cache is continuous
10399
#
104100
# k_cache[0] | k_cache[1] | ... | k_cache[n]
105-
key_cache[0],
106-
value_cache[0], # Same as value_cache
101+
key_cache[self.layer_idx],
102+
value_cache[self.layer_idx], # Same as value_cache
107103
self.layer_idx,
108104
cache_position,
109105
),
@@ -132,7 +128,6 @@ def define_node(
132128
wo,
133129
position_cos,
134130
position_sin,
135-
attention_mask,
136131
past_key,
137132
past_value,
138133
cache_position,

0 commit comments

Comments
 (0)