3838 Tensor wo,
3939 Tensor position_cos,
4040 Tensor position_sin,
41- Tensor? attention_mask,
4241 Tensor past_key,
4342 Tensor past_value,
4443 int layer_idx,
@@ -60,7 +59,6 @@ def attention_llama(*args, **kwargs):
6059 o_proj ,
6160 position_cos ,
6261 position_sin ,
63- attention_mask ,
6462 past_key ,
6563 past_value ,
6664 layer_idx ,
@@ -69,7 +67,7 @@ def attention_llama(*args, **kwargs):
6967 return hidden_states
7068
7169
72- from typing import List , Optional
70+ from typing import List
7371
7472from transformers .cache_utils import DynamicCache
7573from transformers .models .llama .modeling_llama import LlamaAttention
@@ -79,8 +77,7 @@ def llama_attention_forward_adapter(
7977 self : LlamaAttention ,
8078 hidden_states : torch .Tensor ,
8179 position_embeddings : List [torch .Tensor ],
82- attention_mask : Optional [torch .Tensor ],
83- past_key_value : Optional [DynamicCache ],
80+ past_key_value : DynamicCache ,
8481 cache_position : torch .Tensor ,
8582 ** kwargs ,
8683):
@@ -97,13 +94,12 @@ def llama_attention_forward_adapter(
9794 self .o_proj .weight ,
9895 position_embeddings [0 ], # cos
9996 position_embeddings [1 ], # sin
100- attention_mask ,
10197 # key_cache is a list of cache for each decoder layer.
10298 # Assumption: key cache is continuous
10399 #
104100 # k_cache[0] | k_cache[1] | ... | k_cache[n]
105- key_cache [0 ],
106- value_cache [0 ], # Same to value_cache
101+ key_cache [self . layer_idx ],
102+ value_cache [self . layer_idx ], # Same for value_cache
107103 self .layer_idx ,
108104 cache_position ,
109105 ),
@@ -132,7 +128,6 @@ def define_node(
132128 wo ,
133129 position_cos ,
134130 position_sin ,
135- attention_mask ,
136131 past_key ,
137132 past_value ,
138133 cache_position ,
0 commit comments