@@ -516,7 +516,7 @@ def test_greedy_generate_dict_outputs_use_cache(self):
516
516
if self .has_attentions :
517
517
config ._attn_implementation = "eager" # can't output attentions otherwise
518
518
519
- if not hasattr (config , "use_cache" ):
519
+ if not hasattr (config . get_text_config () , "use_cache" ):
520
520
self .skipTest (reason = f"{ model_class .__name__ } doesn't support caching" )
521
521
if any (model_name in model_class .__name__ .lower () for model_name in ["rwkv" ]):
522
522
self .skipTest (reason = "Won't fix: model with non-standard dictionary output shapes" )
@@ -651,7 +651,7 @@ def test_beam_search_generate_dict_outputs_use_cache(self):
651
651
for model_class in self .all_generative_model_classes :
652
652
config , inputs_dict = self .prepare_config_and_inputs_for_generate ()
653
653
654
- if not hasattr (config , "use_cache" ):
654
+ if not hasattr (config . get_text_config () , "use_cache" ):
655
655
self .skipTest (reason = f"{ model_class .__name__ } doesn't support caching" )
656
656
if any (model_name in model_class .__name__ .lower () for model_name in ["rwkv" ]):
657
657
self .skipTest (reason = "Won't fix: model with non-standard dictionary output shapes" )
@@ -989,7 +989,7 @@ def test_contrastive_generate(self):
989
989
config , inputs_dict = self .prepare_config_and_inputs_for_generate ()
990
990
991
991
# NOTE: contrastive search only works with cache on at the moment.
992
- if not hasattr (config , "use_cache" ):
992
+ if not hasattr (config . get_text_config () , "use_cache" ):
993
993
self .skipTest (reason = f"{ model_class .__name__ } doesn't support caching" )
994
994
config .is_decoder = True
995
995
@@ -1018,7 +1018,7 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
1018
1018
config , inputs_dict = self .prepare_config_and_inputs_for_generate ()
1019
1019
1020
1020
# NOTE: contrastive search only works with cache on at the moment.
1021
- if not hasattr (config , "use_cache" ):
1021
+ if not hasattr (config . get_text_config () , "use_cache" ):
1022
1022
self .skipTest (reason = f"{ model_class .__name__ } doesn't support caching" )
1023
1023
config .is_decoder = True
1024
1024
if self .has_attentions :
@@ -1060,7 +1060,7 @@ def test_contrastive_generate_low_memory(self):
1060
1060
config , inputs_dict = self .prepare_config_and_inputs_for_generate (batch_size = 1 )
1061
1061
1062
1062
# NOTE: contrastive search only works with cache on at the moment.
1063
- if not hasattr (config , "use_cache" ):
1063
+ if not hasattr (config . get_text_config () , "use_cache" ):
1064
1064
self .skipTest (reason = f"{ model_class .__name__ } doesn't support caching" )
1065
1065
1066
1066
config .is_decoder = True
@@ -1179,6 +1179,10 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
1179
1179
"prophetnet" ,
1180
1180
"seamlessm4t" ,
1181
1181
"clvp" ,
1182
+ "mllama" , # special cache sizes
1183
+ "blip2" , # overridden `generate()`
1184
+ "instructblip" ,
1185
+ "instructblipvideo" ,
1182
1186
]
1183
1187
):
1184
1188
self .skipTest (reason = "May fix in the future: need model-specific fixes" )
@@ -1187,7 +1191,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
1187
1191
config , inputs_dict = self .prepare_config_and_inputs_for_generate (batch_size = 1 )
1188
1192
1189
1193
# NOTE: assisted generation only works with cache on at the moment.
1190
- if not hasattr (config , "use_cache" ):
1194
+ if not hasattr (config . get_text_config () , "use_cache" ):
1191
1195
self .skipTest (reason = f"{ model_class .__name__ } doesn't support caching" )
1192
1196
1193
1197
config .is_decoder = True
@@ -1254,6 +1258,10 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
1254
1258
"seamlessm4t" ,
1255
1259
"clvp" ,
1256
1260
"fuyu" ,
1261
+ "mllama" , # special cache sizes
1262
+ "blip2" , # overridden `generate()`
1263
+ "instructblip" ,
1264
+ "instructblipvideo" ,
1257
1265
]
1258
1266
):
1259
1267
self .skipTest (reason = "May fix in the future: need model-specific fixes" )
@@ -1262,7 +1270,7 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
1262
1270
config , inputs_dict = self .prepare_config_and_inputs_for_generate (batch_size = 1 )
1263
1271
1264
1272
# NOTE: assisted generation only works with cache on at the moment.
1265
- if not hasattr (config , "use_cache" ):
1273
+ if not hasattr (config . get_text_config () , "use_cache" ):
1266
1274
self .skipTest (reason = f"{ model_class .__name__ } doesn't support caching" )
1267
1275
1268
1276
config .is_decoder = True
@@ -1368,6 +1376,10 @@ def test_assisted_decoding_sample(self):
1368
1376
"prophetnet" ,
1369
1377
"seamlessm4t" ,
1370
1378
"clvp" ,
1379
+ "mllama" , # special cache sizes
1380
+ "blip2" , # overridden `generate()`
1381
+ "instructblip" ,
1382
+ "instructblipvideo" ,
1371
1383
]
1372
1384
):
1373
1385
self .skipTest (reason = "May fix in the future: need model-specific fixes" )
@@ -1376,7 +1388,7 @@ def test_assisted_decoding_sample(self):
1376
1388
config , inputs_dict = self .prepare_config_and_inputs_for_generate (batch_size = 1 )
1377
1389
1378
1390
# NOTE: assisted generation only works with cache on at the moment.
1379
- if not hasattr (config , "use_cache" ):
1391
+ if not hasattr (config . get_text_config () , "use_cache" ):
1380
1392
self .skipTest (reason = f"{ model_class .__name__ } doesn't support caching" )
1381
1393
1382
1394
config .is_decoder = True
@@ -1570,7 +1582,7 @@ def test_past_key_values_format(self):
1570
1582
config , inputs = self .model_tester .prepare_config_and_inputs_for_common ()
1571
1583
1572
1584
# If it doesn't support cache, pass the test
1573
- if not hasattr (config , "use_cache" ):
1585
+ if not hasattr (config . get_text_config () , "use_cache" ):
1574
1586
self .skipTest (reason = f"{ model_class .__name__ } doesn't support caching" )
1575
1587
1576
1588
model = model_class (config ).to (torch_device )
@@ -1605,7 +1617,14 @@ def test_past_key_values_format(self):
1605
1617
1606
1618
# Encoder-Decoder checks
1607
1619
if config .is_encoder_decoder :
1608
- encoder_num_attention_heads = config .encoder_attention_heads
1620
+ # encoder-decoder models usually don't have text config
1621
+ # below is needed only for Pix2Struct which we cannot modify now due to BC
1622
+ config = config .get_text_config ()
1623
+ encoder_num_attention_heads = (
1624
+ config .encoder_attention_heads
1625
+ if hasattr (config , "encoder_attention_heads" )
1626
+ else config .num_attention_heads
1627
+ )
1609
1628
encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
1610
1629
batch_size , seq_length = inputs ["decoder_input_ids" ].shape
1611
1630
for i in range (num_hidden_layers ):
@@ -1804,14 +1823,14 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
1804
1823
def test_generate_continue_from_past_key_values (self ):
1805
1824
# Tests that we can continue generating from past key values, returned from a previous `generate` call
1806
1825
for model_class in self .all_generative_model_classes :
1807
- if any (model_name in model_class .__name__ .lower () for model_name in ["imagegpt" ]):
1826
+ if any (model_name in model_class .__name__ .lower () for model_name in ["imagegpt" , "mllama" ]):
1808
1827
self .skipTest (reason = "Won't fix: old model with unique inputs/caches/other" )
1809
1828
if any (model_name in model_class .__name__ .lower () for model_name in ["umt5" ]):
1810
1829
self .skipTest (reason = "TODO: needs modeling or test input preparation fixes for compatibility" )
1811
1830
1812
1831
config , inputs = self .model_tester .prepare_config_and_inputs_for_common ()
1813
1832
1814
- if not hasattr (config , "use_cache" ):
1833
+ if not hasattr (config . get_text_config () , "use_cache" ):
1815
1834
self .skipTest (reason = f"{ model_class .__name__ } doesn't support caching" )
1816
1835
1817
1836
# Let's make it always:
@@ -2251,7 +2270,7 @@ def test_assisted_decoding_with_logits_to_keep(self):
2251
2270
2252
2271
config , inputs_dict = self .prepare_config_and_inputs_for_generate (batch_size = 1 )
2253
2272
# NOTE: assisted generation only works with cache on at the moment.
2254
- if not hasattr (config , "use_cache" ):
2273
+ if not hasattr (config . get_text_config () , "use_cache" ):
2255
2274
self .skipTest (reason = f"{ model_class .__name__ } doesn't support caching" )
2256
2275
config .use_cache = True
2257
2276
config .is_decoder = True
0 commit comments