
Commit 2bcf9f6

Fixes for EncoderDecoderCache (#40008)
* Add expectation to t5 for rocm 9.4
* Made EncoderDecoderCache compatible with nn.DataParallel
* Fixed t5gemma EncoderDecoderCache
* Added todos in autoformer
* Ruff
* Init is self-contained
* Review compliance
* Fixed kwargs init of EncoderDecoderCache
1 parent aa45824 commit 2bcf9f6

File tree

4 files changed: +38 -23 lines changed


src/transformers/cache_utils.py

Lines changed: 30 additions & 15 deletions

```diff
@@ -1464,13 +1464,31 @@ class EncoderDecoderCache(Cache):
     ```
     """
 
-    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
-        self.self_attention_cache = self_attention_cache
-        self.cross_attention_cache = cross_attention_cache
+    def __init__(self, *caches) -> None:
+        # For dp and ddp support, if only one argument is passed, it should be an iterable of tuples of tensors
+        if len(caches) == 1:
+            self.self_attention_cache = DynamicCache()
+            self.cross_attention_cache = DynamicCache()
+            # Populate cache from the iterable
+            for layer_idx, key_value_states in enumerate(caches[0]):
+                key_states, value_states = key_value_states[:2]
+                self.self_attention_cache.update(key_states, value_states, layer_idx)
+                if len(key_value_states) > 2:
+                    key_states, value_states = key_value_states[2:]
+                    self.cross_attention_cache.update(key_states, value_states, layer_idx)
+        # Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
+        elif len(caches) == 2:
+            if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
+                raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }")
+            self.self_attention_cache = caches[0]
+            self.cross_attention_cache = caches[1]
+        # Error case
+        else:
+            raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")
 
         self.is_updated = {}
-        for layer_idx in range(len(cross_attention_cache)):
-            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
+        for layer_idx in range(len(self.cross_attention_cache)):
+            self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
 
     def __repr__(self) -> str:
         return (
@@ -1527,21 +1545,18 @@ def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]:
 
     @classmethod
     def from_legacy_cache(
-        cls, past_key_values: tuple[tuple[torch.FloatTensor, torch.FloatTensor], ...]
+        cls, past_key_values: Optional[Iterable[tuple[torch.FloatTensor, ...]]]
     ) -> "EncoderDecoderCache":
         """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
+        cache = cls(DynamicCache(), DynamicCache())
         if past_key_values is None:
             logger.warning_once("past_key_values should not be None in from_legacy_cache()")
-        cache = cls(
-            self_attention_cache=DynamicCache(),
-            cross_attention_cache=DynamicCache(),
-        )
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx][:2]
+        else:
+            for layer_idx, key_value_states in enumerate(past_key_values):
+                key_states, value_states = key_value_states[:2]
                 cache.self_attention_cache.update(key_states, value_states, layer_idx)
-                if len(past_key_values[layer_idx]) > 2:
-                    key_states, value_states = past_key_values[layer_idx][2:]
+                if len(key_value_states) > 2:
+                    key_states, value_states = key_value_states[2:]
                     cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                     cache.is_updated[layer_idx] = True
         return cache
```
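
With this change, `EncoderDecoderCache` can be built either from the two caches passed positionally (the previous behavior, minus keyword arguments) or from a single iterable of per-layer tensor tuples in the legacy layout, which is also what `from_legacy_cache` now accepts. Below is a minimal sketch, not part of the commit, assuming a `transformers` version that includes this change and using arbitrary illustrative tensor shapes:

```python
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

batch, heads, seq_len, head_dim = 2, 4, 3, 8
k = torch.zeros(batch, heads, seq_len, head_dim)
v = torch.zeros(batch, heads, seq_len, head_dim)

# Path 1: two Cache objects, passed positionally.
self_attn, cross_attn = DynamicCache(), DynamicCache()
self_attn.update(k, v, layer_idx=0)
cross_attn.update(k, v, layer_idx=0)
cache_a = EncoderDecoderCache(self_attn, cross_attn)

# Path 2: a single iterable of per-layer tuples
# (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value),
# i.e. the legacy tuple layout. Per the code comment, this is the form that
# lets nn.DataParallel / DDP rebuild the cache from gathered tensors.
legacy_layers = [(k, v, k, v)]
cache_b = EncoderDecoderCache(legacy_layers)

print(len(cache_a.self_attention_cache), len(cache_b.cross_attention_cache))
```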

src/transformers/models/blip/modeling_blip_text.py

Lines changed: 1 addition & 3 deletions

```diff
@@ -446,9 +446,7 @@ def forward(
         elif isinstance(past_key_values, DynamicCache):
             past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
         elif past_key_values is None:
-            past_key_values = EncoderDecoderCache(
-                self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()
-            )
+            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
 
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
```
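
This call site changes because the constructor is now `def __init__(self, *caches)`: the two caches can only be passed positionally, so the keyword form used before would no longer match the signature. A short illustration, not part of the commit:

```python
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# Positional arguments match the new `*caches` signature.
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# The previous keyword form would now raise a TypeError
# ("unexpected keyword argument"):
# EncoderDecoderCache(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
```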

tests/models/t5/test_modeling_t5.py

Lines changed: 7 additions & 1 deletion

```diff
@@ -25,6 +25,7 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
 from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 from transformers.testing_utils import (
+    Expectations,
     cleanup,
     require_accelerate,
     require_sentencepiece,
@@ -1200,7 +1201,12 @@ def test_small_integration_test(self):
         loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss
         mtf_score = -(labels.shape[-1] * loss.item())
 
-        EXPECTED_SCORE = -19.0845
+        EXPECTED_SCORE = Expectations(
+            {
+                (None, None): -19.0845,
+                ("rocm", (9, 4)): -19.0846,
+            }
+        ).get_expectation()
         self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
 
     @slow
```
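
The expected score now goes through `transformers.testing_utils.Expectations`, where `(None, None)` carries the previous universal value and `("rocm", (9, 4))` covers the ROCm 9.4 run this commit adds; `get_expectation()` picks the entry matching the accelerator the test runs on. A toy, self-contained sketch of that selection idea (not the actual `Expectations` implementation):

```python
from typing import Any, Optional

def pick_expectation(
    table: dict[tuple[Optional[str], Any], float],
    device_type: Optional[str],
    version: Any,
) -> float:
    # Prefer an exact (device, version) match, then a device-only match,
    # then the (None, None) default.
    for key in ((device_type, version), (device_type, None), (None, None)):
        if key in table:
            return table[key]
    raise KeyError("no matching expectation")

table = {(None, None): -19.0845, ("rocm", (9, 4)): -19.0846}
print(pick_expectation(table, "rocm", (9, 4)))  # -19.0846
print(pick_expectation(table, "cuda", 8))       # falls back to -19.0845
```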

tests/models/t5gemma/test_modeling_t5gemma.py

Lines changed: 0 additions & 4 deletions

```diff
@@ -1386,10 +1386,6 @@ def test_flex_attention_with_grads(self):
         # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
         _ = model(**dummy_inputs)
 
-    @unittest.skip("EncoderDecoderCache can't be gathered because it is not iterable.")
-    def test_multi_gpu_data_parallel_forward(self):
-        pass
-
 
 class T5GemmaEncoderOnlyModelTester:
     config_class = T5GemmaConfig
```
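
With the skip removed, the common `test_multi_gpu_data_parallel_forward` test runs again for T5Gemma, since `nn.DataParallel` can now rebuild the `EncoderDecoderCache` when it gathers the replicas' outputs. A rough sketch of that kind of forward pass, not taken from the test suite, assuming at least two CUDA GPUs and the hypothetical tiny checkpoint name used here being available:

```python
import torch
from transformers import AutoModelForSeq2SeqLM

if torch.cuda.device_count() >= 2:
    # Checkpoint name is an illustrative assumption; any small seq2seq model works.
    model = AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5").cuda().eval()
    model = torch.nn.DataParallel(model)
    input_ids = torch.ones(4, 6, dtype=torch.long, device="cuda")
    with torch.no_grad():
        # DataParallel gathers each replica's outputs, including past_key_values,
        # onto the first device; the iterable-based EncoderDecoderCache.__init__
        # is what allows that gather to reconstruct the cache object.
        outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
    print(type(outputs.past_key_values))
```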
