Skip to content

Commit 0b8be89

Browse files
authored
[Cache] Fix from_legacy_cache (#605)
1 parent 721a205 commit 0b8be89

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

fla/models/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,8 @@ def from_legacy_cache(
377377
cache = cls(seen_tokens=seen_tokens, **kwargs)
378378
if isinstance(past_key_values, (list, tuple)):
379379
for i, st in enumerate(past_key_values):
380-
cache.append_new_layers(i)
380+
while len(cache.layers) <= i:
381+
cache.layers.append(cache.layer_class_to_replicate())
381382
cache.layers[i].state = dict(st)
382383
return cache
383384

0 commit comments

Comments
 (0)