Commit 9470d65

Fix low memory beam search (#34746)
* fix
* higher max positions in tests
1 parent: 145fbd4

File tree: 3 files changed, +10 -5 lines

src/transformers/cache_utils.py (9 additions, 3 deletions)

@@ -528,7 +528,7 @@ def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int
         cache = cls()
         for idx in range(len(splits[0])):
             key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
-            value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
+            value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
             if key_cache != []:
                 layer_keys = torch.cat(key_cache, dim=0)
                 layer_values = torch.cat(value_cache, dim=0)
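The fix in the first hunk is a one-identifier typo that is easy to miss: the old line rebuilt value_cache from each split's key_cache, so a recombined cache carried key tensors in its value slots. A minimal standalone sketch (invented tensor shapes, not code from the commit) shows why the buggy and fixed recombinations diverge:

import torch

# Hypothetical per-split KV tensors shaped [batch, heads, seq_len, head_dim].
key_splits = [torch.ones(1, 2, 4, 8), torch.full((1, 2, 4, 8), 2.0)]
value_splits = [torch.zeros(1, 2, 4, 8), torch.full((1, 2, 4, 8), -1.0)]

# Buggy recombination: value cache rebuilt from the *key* tensors.
bad_values = torch.cat(key_splits, dim=0)
# Fixed recombination: value cache rebuilt from the value tensors.
good_values = torch.cat(value_splits, dim=0)

print(torch.equal(bad_values, good_values))  # False: the old code silently corrupted values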
@@ -1523,7 +1523,10 @@ def crop(self, maximum_length: int):
         self.check_dynamic_cache(self.crop.__name__)
         self.self_attention_cache.crop(maximum_length)
 
-    def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
+    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
+    def batch_split(
+        self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
+    ) -> "List[EncoderDecoderCache]":
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
         `_split_model_inputs()` in `generation.utils`"""
         self.check_dynamic_cache(self.batch_split.__name__)

@@ -1536,7 +1539,10 @@ def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDec
         return out
 
     @classmethod
-    def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
+    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
+    def from_batch_splits(
+        cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int = None
+    ) -> "EncoderDecoderCache":
         """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
         `generation.utils`"""
         self_attention_cache = DynamicCache()
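Both split helpers also gain a deprecated num_hidden_layers argument through the @deprecate_kwarg decorator from transformers' utilities. As a rough, hypothetical re-implementation (not the library's actual helper), the pattern amounts to warning whenever the dying kwarg is still passed:

import functools
import warnings

def deprecate_kwarg(name, version):
    # Hypothetical minimal stand-in for transformers' helper of the same name.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if kwargs.get(name) is not None:
                warnings.warn(
                    f"`{name}` is deprecated and will be removed in v{version}.",
                    FutureWarning,
                )
            return func(*args, **kwargs)
        return wrapper
    return decorator

@deprecate_kwarg("num_hidden_layers", version="4.47.0")
def batch_split(full_batch_size, split_size, num_hidden_layers=None):
    # Illustrative body only: the start index of each batch split.
    return list(range(0, full_batch_size, split_size))

batch_split(8, 2, num_hidden_layers=12)  # emits a FutureWarning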

tests/generation/test_utils.py (0 additions, 1 deletion)

@@ -1046,7 +1046,6 @@ def test_contrastive_generate_low_memory(self):
         self.assertListEqual(low_output.tolist(), high_output.tolist())
 
     @pytest.mark.generate
-    @unittest.skip("Started to break with https://github.com/huggingface/transformers/pull/33703")
     def test_beam_search_low_memory(self):
         # Check that choosing 'low_memory' does not change the model output
         for model_class in self.all_generative_model_classes:
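Removing the skip re-enables the check that beam search with low_memory=True (which splits the batch and runs beams sequentially) matches the default beam search output token for token. A rough usage sketch of that property follows; the model choice and generation settings are illustrative, not taken from the test:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Low-memory beam search should not change outputs", return_tensors="pt")
high = model.generate(**inputs, num_beams=4, max_new_tokens=16, low_memory=False)
low = model.generate(**inputs, num_beams=4, max_new_tokens=16, low_memory=True)

assert high.tolist() == low.tolist()  # the equality the re-enabled test asserts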

tests/models/blip_2/test_modeling_blip_2.py (1 addition, 1 deletion)

@@ -330,7 +330,7 @@ def __init__(
         hidden_act="gelu",
         hidden_dropout_prob=0.1,
         attention_probs_dropout_prob=0.1,
-        max_position_embeddings=20,
+        max_position_embeddings=512,
         eos_token_id=2,
         pad_token_id=1,
         bos_token_id=0,
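The tiny Blip-2 test config previously capped max_position_embeddings at 20, presumably too tight once the re-enabled test generates past that length; 512 leaves ample headroom. A toy illustration (plain torch, not Blip-2 code) of what overrunning a position table looks like:

import torch
import torch.nn as nn

pos_emb = nn.Embedding(20, 8)    # old cap: 20 learned positions
positions = torch.arange(24)     # a decode that needs 24 positions

try:
    pos_emb(positions)           # indices 20..23 fall outside the table
except IndexError as err:
    print(f"position lookup failed: {err}")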
