@@ -1464,13 +1464,31 @@ class EncoderDecoderCache(Cache):
         ```
     """

-    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
-        self.self_attention_cache = self_attention_cache
-        self.cross_attention_cache = cross_attention_cache
+    def __init__(self, *caches) -> None:
+        # For dp and ddp support, if only one argument is passed, it should be an iterable of tuples of tensors
+        if len(caches) == 1:
+            self.self_attention_cache = DynamicCache()
+            self.cross_attention_cache = DynamicCache()
+            # Populate cache from the iterable
+            for layer_idx, key_value_states in enumerate(caches[0]):
+                key_states, value_states = key_value_states[:2]
+                self.self_attention_cache.update(key_states, value_states, layer_idx)
+                if len(key_value_states) > 2:
+                    key_states, value_states = key_value_states[2:]
+                    self.cross_attention_cache.update(key_states, value_states, layer_idx)
+        # Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
+        elif len(caches) == 2:
+            if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
+                raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }")
+            self.self_attention_cache = caches[0]
+            self.cross_attention_cache = caches[1]
+        # Error case
+        else:
+            raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")

         self.is_updated = {}
-        for layer_idx in range(len(cross_attention_cache)):
-            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
+        for layer_idx in range(len(self.cross_attention_cache)):
+            self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)

     def __repr__(self) -> str:
         return (
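For reference, a minimal sketch (not part of the diff) of the two construction paths the reworked `__init__` accepts. The tensor shapes below are hypothetical, chosen only to illustrate the per-layer tuple layout:

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache

# Path 1: two Cache arguments, as before.
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

# Path 2 (new, for dp/ddp): a single iterable of per-layer tensor tuples,
# each (self_attn_key, self_attn_value[, cross_attn_key, cross_attn_value]).
k = v = torch.zeros(1, 8, 4, 64)     # (batch, heads, seq_len, head_dim), hypothetical shape
ck = cv = torch.zeros(1, 8, 10, 64)  # cross-attention states over a hypothetical encoder length of 10
cache = EncoderDecoderCache([(k, v, ck, cv)])
```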
@@ -1527,21 +1545,18 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]:

     @classmethod
     def from_legacy_cache(
-        cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]
+        cls, past_key_values: Optional[Iterable[tuple[torch.FloatTensor, ...]]]
     ) -> "EncoderDecoderCache":
         """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
+        cache = cls(DynamicCache(), DynamicCache())
         if past_key_values is None:
             logger.warning_once("past_key_values should not be None in from_legacy_cache()")
-        cache = cls(
-            self_attention_cache=DynamicCache(),
-            cross_attention_cache=DynamicCache(),
-        )
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx][:2]
+        else:
+            for layer_idx, key_value_states in enumerate(past_key_values):
+                key_states, value_states = key_value_states[:2]
                 cache.self_attention_cache.update(key_states, value_states, layer_idx)
-                if len(past_key_values[layer_idx]) > 2:
-                    key_states, value_states = past_key_values[layer_idx][2:]
+                if len(key_value_states) > 2:
+                    key_states, value_states = key_value_states[2:]
                     cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                 cache.is_updated[layer_idx] = True
         return cache
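A similarly hedged sketch of the `from_legacy_cache` conversion: the legacy format is one tuple of 2 or 4 tensors per layer, and the shapes below are hypothetical:

```python
import torch
from transformers import EncoderDecoderCache

k = v = torch.zeros(1, 8, 4, 64)     # self-attention key/value, hypothetical shape
ck = cv = torch.zeros(1, 8, 10, 64)  # cross-attention key/value, hypothetical shape
legacy = ((k, v, ck, cv),)           # one layer in the legacy 4-tuple format

cache = EncoderDecoderCache.from_legacy_cache(legacy)
assert cache.is_updated[0]                               # marked updated during conversion
assert cache.self_attention_cache.get_seq_length(0) == 4 # decoder sequence length recovered
```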