
Commit 8fc6ecb

VLM: enable skipped tests (#35746)
* fix cached tests
* fix some tests
* fix pix2struct
* fix
1 parent d6897b4 · commit 8fc6ecb

10 files changed (+216, -20 lines)

src/transformers/models/blip_2/modeling_blip_2.py

Lines changed: 6 additions & 0 deletions
@@ -579,6 +579,9 @@ def _init_weights(self, module):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
             Whether to interpolate the pre-trained position encodings.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
 """

 BLIP2_IMAGE_TEXT_RETRIEVAL_INPUTS_DOCSTRING = r"""

@@ -2094,6 +2097,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         return_dict: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
+        use_cache: Optional[bool] = None,
     ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
         r"""
         Returns:

@@ -2217,6 +2221,7 @@ def forward(
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
+                use_cache=use_cache,
             )
             logits = outputs.logits if return_dict else outputs[0]
             loss = None

@@ -2242,6 +2247,7 @@ def forward(
                 output_hidden_states=output_hidden_states,
                 return_dict=True,  # toggle for easier access to loss/logits below
                 labels=labels,
+                use_cache=use_cache,
             )
             loss = outputs.loss
             logits = outputs.logits
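
A note on the new argument: it is simply forwarded to the underlying language model, so a plain forward pass can now return the language model's key/value states instead of silently dropping the flag. A minimal usage sketch, assuming the public Salesforce/blip2-opt-2.7b checkpoint and the standard COCO demo image (neither is part of this commit):

import requests
import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, text="Question: how many cats are there? Answer:", return_tensors="pt")

with torch.no_grad():
    # `use_cache` is passed through to the language model, so its key/value
    # states are computed during this forward pass.
    outputs = model(**inputs, use_cache=True, return_dict=True)
print(outputs.logits.shape)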

src/transformers/models/instructblip/modeling_instructblip.py

Lines changed: 6 additions & 0 deletions
@@ -441,6 +441,9 @@ def _init_weights(self, module):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
             Whether to interpolate the pre-trained position encodings.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
 """


@@ -1375,6 +1378,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         return_dict: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
+        use_cache: Optional[bool] = None,
     ) -> Union[Tuple, InstructBlipForConditionalGenerationModelOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@@ -1485,6 +1489,7 @@ def forward(
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
+                use_cache=use_cache,
             )
             logits = outputs.logits if return_dict else outputs[0]
             loss = None

@@ -1510,6 +1515,7 @@ def forward(
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
                 labels=labels,
+                use_cache=use_cache,
             )
             loss = outputs.loss if return_dict else outputs[0]
             logits = outputs.logits if return_dict else outputs[1]

src/transformers/models/instructblipvideo/modeling_instructblipvideo.py

Lines changed: 6 additions & 0 deletions
@@ -1265,6 +1265,9 @@ def forward(
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
         interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
             Whether to interpolate the pre-trained position encodings.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
 """


@@ -1369,6 +1372,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         return_dict: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
+        use_cache: Optional[bool] = None,
     ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@@ -1512,6 +1516,7 @@ def forward(
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
+                use_cache=use_cache,
             )
             logits = outputs.logits if return_dict else outputs[0]
             loss = None

@@ -1537,6 +1542,7 @@ def forward(
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
                 labels=labels,
+                use_cache=use_cache,
             )
             loss = outputs.loss if return_dict else outputs[0]
             logits = outputs.logits if return_dict else outputs[1]

src/transformers/models/instructblipvideo/modular_instructblipvideo.py

Lines changed: 3 additions & 0 deletions
@@ -188,6 +188,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         return_dict: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
+        use_cache: Optional[bool] = None,
     ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
         r"""
         ```python

@@ -322,6 +323,7 @@ def forward(
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
+                use_cache=use_cache,
             )
             logits = outputs.logits if return_dict else outputs[0]
             loss = None

@@ -347,6 +349,7 @@ def forward(
                 output_hidden_states=output_hidden_states,
                 return_dict=return_dict,
                 labels=labels,
+                use_cache=use_cache,
             )
             loss = outputs.loss if return_dict else outputs[0]
             logits = outputs.logits if return_dict else outputs[1]

src/transformers/models/kosmos2/modeling_kosmos2.py

Lines changed: 9 additions & 4 deletions
@@ -1694,6 +1694,7 @@ def prepare_inputs_for_generation(
         past_key_values=None,
         attention_mask=None,
         use_cache=None,
+        cache_position=None,
         **model_kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

@@ -1704,17 +1705,21 @@ def prepare_inputs_for_generation(
             attention_mask = input_ids.new_ones(input_shape)

         position_ids = None
+        if cache_position is None:
+            past_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+            cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)

-        # cut input_ids if past_key_values is used
         if past_key_values is not None:
             position_ids = create_position_ids_from_input_ids(
                 input_ids,
                 padding_idx=self.config.pad_token_id,
                 past_key_values_length=0,
-            )[:, -1:]
+            )
+
+            if input_ids.shape[1] != cache_position.shape[0]:
+                input_ids = input_ids[:, cache_position]
+                position_ids = position_ids[:, -input_ids.shape[1] :]

-            input_ids = input_ids[:, -1:]
-            # the image info. is already encoded into the past keys/values
             image_embeds = None
             image_embeds_position_mask = None
         elif image_embeds_position_mask is not None:
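
The rewritten `prepare_inputs_for_generation` follows the generic `cache_position` convention: the first generation step processes the whole prompt, and later steps keep only the tokens not yet covered by the KV cache. A standalone sketch of that slicing (tensor values are made up for illustration, not taken from the model):

import torch

def slice_for_step(input_ids: torch.Tensor, past_length: int):
    # Indices of the tokens that still have to be fed to the model,
    # i.e. everything after what the KV cache already covers.
    cache_position = torch.arange(past_length, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
    if input_ids.shape[1] != cache_position.shape[0]:
        input_ids = input_ids[:, cache_position]
    return input_ids, cache_position

prompt = torch.tensor([[101, 7, 42, 42, 9]])
print(slice_for_step(prompt, past_length=0))  # full prompt, cache_position = [0, 1, 2, 3, 4]
print(slice_for_step(prompt, past_length=4))  # only the newest token, cache_position = [4]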

tests/generation/test_utils.py

Lines changed: 32 additions & 13 deletions
@@ -516,7 +516,7 @@ def test_greedy_generate_dict_outputs_use_cache(self):
             if self.has_attentions:
                 config._attn_implementation = "eager"  # can't output attentions otherwise

-            if not hasattr(config, "use_cache"):
+            if not hasattr(config.get_text_config(), "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
             if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
                 self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")

@@ -651,7 +651,7 @@ def test_beam_search_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()

-            if not hasattr(config, "use_cache"):
+            if not hasattr(config.get_text_config(), "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
             if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
                 self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")

@@ -989,7 +989,7 @@ def test_contrastive_generate(self):
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()

             # NOTE: contrastive search only works with cache on at the moment.
-            if not hasattr(config, "use_cache"):
+            if not hasattr(config.get_text_config(), "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
             config.is_decoder = True

@@ -1018,7 +1018,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
             config, inputs_dict = self.prepare_config_and_inputs_for_generate()

             # NOTE: contrastive search only works with cache on at the moment.
-            if not hasattr(config, "use_cache"):
+            if not hasattr(config.get_text_config(), "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
             config.is_decoder = True
             if self.has_attentions:

@@ -1060,7 +1060,7 @@ def test_contrastive_generate_low_memory(self):
             config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)

             # NOTE: contrastive search only works with cache on at the moment.
-            if not hasattr(config, "use_cache"):
+            if not hasattr(config.get_text_config(), "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

             config.is_decoder = True
@@ -1179,6 +1179,10 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
                 "prophetnet",
                 "seamlessm4t",
                 "clvp",
+                "mllama",  # special cache sizes
+                "blip2",  # overridden `generate()`
+                "instructblip",
+                "instructblipvideo",
             ]
         ):
             self.skipTest(reason="May fix in the future: need model-specific fixes")

@@ -1187,7 +1191,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
         config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)

         # NOTE: assisted generation only works with cache on at the moment.
-        if not hasattr(config, "use_cache"):
+        if not hasattr(config.get_text_config(), "use_cache"):
             self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

         config.is_decoder = True

@@ -1254,6 +1258,10 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
                 "seamlessm4t",
                 "clvp",
                 "fuyu",
+                "mllama",  # special cache sizes
+                "blip2",  # overridden `generate()`
+                "instructblip",
+                "instructblipvideo",
             ]
         ):
             self.skipTest(reason="May fix in the future: need model-specific fixes")

@@ -1262,7 +1270,7 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
         config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)

         # NOTE: assisted generation only works with cache on at the moment.
-        if not hasattr(config, "use_cache"):
+        if not hasattr(config.get_text_config(), "use_cache"):
             self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

         config.is_decoder = True
@@ -1368,6 +1376,10 @@ def test_assisted_decoding_sample(self):
                 "prophetnet",
                 "seamlessm4t",
                 "clvp",
+                "mllama",  # special cache sizes
+                "blip2",  # overridden `generate()`
+                "instructblip",
+                "instructblipvideo",
             ]
         ):
             self.skipTest(reason="May fix in the future: need model-specific fixes")

@@ -1376,7 +1388,7 @@ def test_assisted_decoding_sample(self):
         config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)

         # NOTE: assisted generation only works with cache on at the moment.
-        if not hasattr(config, "use_cache"):
+        if not hasattr(config.get_text_config(), "use_cache"):
             self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

         config.is_decoder = True

@@ -1570,7 +1582,7 @@ def test_past_key_values_format(self):
             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

             # If it doesn't support cache, pass the test
-            if not hasattr(config, "use_cache"):
+            if not hasattr(config.get_text_config(), "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

             model = model_class(config).to(torch_device)

@@ -1605,7 +1617,14 @@ def test_past_key_values_format(self):

             # Encoder-Decoder checks
             if config.is_encoder_decoder:
-                encoder_num_attention_heads = config.encoder_attention_heads
+                # encoder-decoder models usually don't have text config
+                # below is needed only for Pix2Struct which we cannot modify now due to BC
+                config = config.get_text_config()
+                encoder_num_attention_heads = (
+                    config.encoder_attention_heads
+                    if hasattr(config, "encoder_attention_heads")
+                    else config.num_attention_heads
+                )
                 encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
                 batch_size, seq_length = inputs["decoder_input_ids"].shape
                 for i in range(num_hidden_layers):
@@ -1804,14 +1823,14 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
     def test_generate_continue_from_past_key_values(self):
         # Tests that we can continue generating from past key values, returned from a previous `generate` call
         for model_class in self.all_generative_model_classes:
-            if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt", "mllama"]):
                 self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
             if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
                 self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")

             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

-            if not hasattr(config, "use_cache"):
+            if not hasattr(config.get_text_config(), "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

             # Let's make it always:

@@ -2251,7 +2270,7 @@ def test_assisted_decoding_with_logits_to_keep(self):

         config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
         # NOTE: assisted generation only works with cache on at the moment.
-        if not hasattr(config, "use_cache"):
+        if not hasattr(config.get_text_config(), "use_cache"):
             self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
         config.use_cache = True
         config.is_decoder = True
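
The recurring change from `hasattr(config, "use_cache")` to `hasattr(config.get_text_config(), "use_cache")` is what un-skips these tests for composite VLMs: their top-level config does not carry `use_cache` itself, the nested language-model config does, and `get_text_config()` resolves to that nested config (or to the config itself for plain text models). A rough illustration of the pattern, with Blip2Config and GPT2Config chosen only as examples:

from transformers import Blip2Config, GPT2Config

composite = Blip2Config()
plain = GPT2Config()

# For a composite config the flag lives on the nested text config ...
print(hasattr(composite, "use_cache"), hasattr(composite.get_text_config(), "use_cache"))
# ... while for a plain text config `get_text_config()` just returns the config itself.
print(plain.get_text_config() is plain, hasattr(plain, "use_cache"))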

tests/models/aria/test_modeling_aria.py

Lines changed: 3 additions & 3 deletions
@@ -82,14 +82,14 @@ def __init__(
             moe_intermediate_size=4,
             moe_num_experts=4,
             moe_topk=2,
-            num_attention_heads=20,
+            num_attention_heads=8,
             num_experts_per_tok=3,
             num_hidden_layers=2,
-            num_key_value_heads=20,
+            num_key_value_heads=8,
             rope_theta=5000000,
             vocab_size=99,
             eos_token_id=2,
-            head_dim=2,
+            head_dim=4,
         ),
         is_training=True,
         vision_config=Idefics3VisionConfig(
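
For context, each layer of a llama-style legacy cache stores key/value tensors of shape (batch_size, num_key_value_heads, seq_len, head_dim), so the tester's head counts directly determine the shapes the generic cache tests compare against. A trivial sanity sketch with the updated values (batch and sequence sizes are made up):

num_key_value_heads, head_dim = 8, 4  # values from the updated Aria tester
batch_size, seq_len = 2, 7            # illustrative only

# Shape of the key (and of the value) tensor cached per attention layer.
expected_kv_shape = (batch_size, num_key_value_heads, seq_len, head_dim)
print(expected_kv_shape)  # (2, 8, 7, 4)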
