@@ -121,9 +121,9 @@ def call(
121
121
input_shape = ops .shape (hidden_states )[:- 1 ]
122
122
hidden_shape = (* input_shape , self .num_attention_heads , self .head_dim )
123
123
124
- query_states = ops .reshape (self .q_proj (hidden_states ),hidden_shape )
124
+ query_states = ops .reshape (self .q_proj (hidden_states ), hidden_shape )
125
125
# (batch, num_heads, seq_len, head_dim)
126
- query_states = ops .transpose (query_states , axes = (0 , 2 , 1 , 3 ))
126
+ query_states = ops .transpose (query_states , axes = (0 , 2 , 1 , 3 ))
127
127
128
128
def _compute_kv_values (x_input ):
129
129
kv_hidden_shape = (
@@ -132,13 +132,9 @@ def _compute_kv_values(x_input):
132
132
self .head_dim ,
133
133
)
134
134
135
- key_states_raw = ops .reshape (
136
- self .k_proj (x_input ),
137
- kv_hidden_shape
138
- )
135
+ key_states_raw = ops .reshape (self .k_proj (x_input ), kv_hidden_shape )
139
136
value_states_raw = ops .reshape (
140
- self .v_proj (x_input ),
141
- kv_hidden_shape
137
+ self .v_proj (x_input ), kv_hidden_shape
142
138
)
143
139
144
140
key_states = ops .transpose (key_states_raw , axes = (0 , 2 , 1 , 3 ))
@@ -155,7 +151,9 @@ def _compute_kv_values(x_input):
155
151
key_states = key_cache
156
152
value_states = value_cache
157
153
else :
158
- print ("self_attention_cache_update_index is not None, computing kv values" )
154
+ print (
155
+ "self_attention_cache_update_index is not None, computing kv values"
156
+ )
159
157
key_update , value_update = _compute_kv_values (hidden_states )
160
158
update_idx_tensor = ops .convert_to_tensor (
161
159
self_attention_cache_update_index , dtype = "int32"
@@ -417,7 +415,7 @@ def call(
417
415
position_embeddings: Optional tuple of (cos, sin) tensors for RoPE.
418
416
training: Whether the layer is in training mode.
419
417
"""
420
- self_attention_cache = kwargs .get (' self_attention_cache' , None )
418
+ self_attention_cache = kwargs .get (" self_attention_cache" , None )
421
419
422
420
residual = hidden_states
423
421
hidden_states = self .input_layernorm (hidden_states )
0 commit comments