@@ -165,6 +165,8 @@ def __init__(
             )
             torch._dynamo.mark_static_address(conv_state)
             self.conv_cache.append(conv_state)
+            self.key_cache.append(torch.tensor([]))
+            self.value_cache.append(torch.tensor([]))
 
     def update(
         self,
@@ -190,35 +192,27 @@ def update(
             A tuple containing the updated key and value states.
         """
         # Update the cache
-        if key_states is not None:
-            if len(self.key_cache) <= layer_idx:
-                # There may be skipped layers, fill them with empty lists
-                for _ in range(len(self.key_cache), layer_idx):
-                    self.key_cache.append(torch.tensor([]))
-                    self.value_cache.append(torch.tensor([]))
-                self.key_cache.append(key_states)
-                self.value_cache.append(value_states)
-            elif (
-                not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
-            ):  # fills previously skipped layers; checking for tensor causes errors
-                self.key_cache[layer_idx] = key_states
-                self.value_cache[layer_idx] = value_states
-            else:
-                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+        if self.key_cache[layer_idx].numel() == 0:
+            self.key_cache[layer_idx] = key_states
+            self.value_cache[layer_idx] = value_states
+        else:
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
 
         return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-            device = self.conv_cache[layer_idx].device
-            self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
+            if self.key_cache[layer_idx].numel():
+                device = self.key_cache[layer_idx].device
+                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+                device = self.value_cache[layer_idx].device
+                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+            if self.conv_cache[layer_idx].numel():
+                device = self.conv_cache[layer_idx].device
+                self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""