Commit 2ae6ca2

Fix beam search test for latest optimum (#1290)
* fix beam search test for latest optimum
* more tests fixes
* fix seq2seq beam search and update mixtral
1 parent 7a716f8 commit 2ae6ca2

File tree

2 files changed: +26 -10 lines changed

tests/openvino/test_modeling.py

Lines changed: 23 additions & 7 deletions

@@ -1405,6 +1405,10 @@ def test_pipeline(self, model_arch):
         if model_arch == "qwen":
             tokenizer._convert_tokens_to_ids = lambda x: 0
 
+        additional_args = {}
+        if is_transformers_version(">=", "4.51"):
+            additional_args["use_model_defaults"] = False
+
         model = OVModelForCausalLM.from_pretrained(model_id, use_cache=False, compile=False, **model_kwargs)
         model.eval()
         model.config.encoder_no_repeat_ngram_size = 0
@@ -1414,7 +1418,7 @@ def test_pipeline(self, model_arch):
         pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
         inputs = "My name is Arthur and I live in"
         set_seed(SEED)
-        outputs = pipe(inputs, max_new_tokens=5)
+        outputs = pipe(inputs, max_new_tokens=5, **additional_args, do_sample=False)
         self.assertEqual(pipe.device, model.device)
         self.assertTrue(all(inputs in item["generated_text"] for item in outputs))
         ov_pipe = optimum_pipeline(
@@ -1425,7 +1429,7 @@ def test_pipeline(self, model_arch):
             tokenizer=tokenizer if model_arch == "qwen" else None,
         )
         set_seed(SEED)
-        ov_outputs = ov_pipe(inputs, max_new_tokens=5)
+        ov_outputs = ov_pipe(inputs, max_new_tokens=5, **additional_args, do_sample=False)
         self.assertEqual(outputs[-1]["generated_text"], ov_outputs[-1]["generated_text"])
         del ov_pipe
         del pipe
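
Context for the pipeline hunks above: from transformers 4.51 on, generate() can pull generation defaults from the model repository, so the test passes use_model_defaults=False and do_sample=False to pin both pipelines to plain greedy decoding before comparing their outputs. A minimal standalone sketch of the same pinning (not the test itself); the checkpoint id is a placeholder and transformers >= 4.51 is assumed:

from optimum.intel import OVModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "hf-internal-testing/tiny-random-gpt2"  # placeholder tiny checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

pt_pipe = pipeline("text-generation", model=AutoModelForCausalLM.from_pretrained(model_id), tokenizer=tokenizer)
ov_pipe = pipeline("text-generation", model=OVModelForCausalLM.from_pretrained(model_id, export=True), tokenizer=tokenizer)

# Pinned settings mirroring additional_args in the test: greedy decoding and no
# repo-supplied generation defaults (use_model_defaults exists in transformers >= 4.51).
gen_kwargs = {"max_new_tokens": 5, "do_sample": False, "use_model_defaults": False}
inputs = "My name is Arthur and I live in"
print(pt_pipe(inputs, **gen_kwargs)[-1]["generated_text"])
print(ov_pipe(inputs, **gen_kwargs)[-1]["generated_text"])  # expected to match the line above
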
@@ -1625,15 +1629,14 @@ def test_beam_search(self, model_arch):
         set_seed(SEED)
         with mock_torch_cuda_is_available("awq" in model_arch or "gptq" in model_arch):
             transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
-
         if model_arch == "arctic":
             transformers_model.to(torch.float32)
         additional_inputs = {}
         # gemma2 does not support dynamic cache, it is unfair to compare dynamic cache result vs hybrid cache, align cache representation in torch model
-        if model_arch == "gemma2":
+        if model_arch in ["gemma2", "gemma3-text"]:
             patch_update_causal_mask(transformers_model, "4.43.0")
             transformers_model._supports_cache_class = True
-            from transformers.cache_utils import DynamicCache
+            transformers_model.generation_config.cache_implementation = None
         tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenization_args = {}
         if is_transformers_version(">=", "4.45") and model_arch == "gpt_neo":
@@ -1645,8 +1648,17 @@ def test_beam_search(self, model_arch):
16451648
**tokenization_args,
16461649
)
16471650
ov_model_stateful.generation_config.eos_token_id = None
1651+
ov_model_stateful.generation_config.forced_eos_token_id = None
1652+
ov_model_stateful.generation_config.encoder_no_repeat_ngram_size = None
1653+
ov_model_stateful.generation_config.do_sample = False
16481654
ov_model_stateless.generation_config.eos_token_id = None
1655+
ov_model_stateless.generation_config.forced_eos_token_id = None
1656+
ov_model_stateless.generation_config.encoder_no_repeat_ngram_size = None
1657+
ov_model_stateless.generation_config.do_sample = False
16491658
transformers_model.generation_config.eos_token_id = None
1659+
transformers_model.generation_config.forced_eos_token_id = None
1660+
transformers_model.generation_config.encoder_no_repeat_ngram_size = None
1661+
transformers_model.generation_config.do_sample = False
16501662
ov_model_stateful.config.eos_token_id = None
16511663
ov_model_stateless.config.eos_token_id = None
16521664
transformers_model.config.eos_token_id = None
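
The overrides above neutralize the same generation_config fields on the stateful OpenVINO model, the stateless one, and the PyTorch reference, so no backend stops early at an EOS token, forces a final token, applies encoder n-gram blocking, or samples. A small sketch of that pattern as a helper; the test writes the assignments out inline, and the function name below is illustrative only:

def align_generation_config(*models):
    """Reset the fields the test neutralizes so every backend decodes the full max_new_tokens."""
    for m in models:
        m.generation_config.eos_token_id = None
        m.generation_config.forced_eos_token_id = None
        m.generation_config.encoder_no_repeat_ngram_size = None
        m.generation_config.do_sample = False
        m.config.eos_token_id = None

# Inside the test this would cover all three models under comparison:
# align_generation_config(ov_model_stateful, ov_model_stateless, transformers_model)
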
@@ -1657,10 +1669,14 @@ def test_beam_search(self, model_arch):
         for gen_config in gen_configs:
             if gen_config.do_sample and model_arch in ["baichuan2-13b", "olmo"]:
                 continue
+            if gen_config.num_beams > 1 and is_transformers_version(">=", "4.51.0") and model_arch in ["mixtral_awq"]:
+                continue
             set_seed(SEED)
 
-            if model_arch == "gemma2":
-                additional_inputs = {"past_key_values": DynamicCache()}
+            if model_arch in ["gemma2", "gemma3-text"]:
+                from transformers.cache_utils import DynamicCache
+
+                additional_inputs["past_key_values"] = DynamicCache()
             with patch_awq_for_inference("awq" in model_arch):
                 transformers_outputs = transformers_model.generate(
                     **tokens, generation_config=gen_config, **additional_inputs
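
For gemma2 and gemma3-text the PyTorch reference defaults to a hybrid cache, so the test clears generation_config.cache_implementation and hands generate() an explicit DynamicCache to keep the reference on the same dynamic-cache path as the OpenVINO stateful model. A rough standalone sketch of that alignment; the checkpoint id is a placeholder, and the causal-mask patching done by the test is omitted:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

model_id = "hf-internal-testing/tiny-random-Gemma2ForCausalLM"  # placeholder tiny checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
model.generation_config.cache_implementation = None  # drop the model-default hybrid cache

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokens = tokenizer("My name is Arthur and I live in", return_tensors="pt")

# An explicit DynamicCache keeps beam search on the dynamic-cache path that the
# stateful OpenVINO model is compared against in the test.
outputs = model.generate(**tokens, max_new_tokens=5, num_beams=4, do_sample=False, past_key_values=DynamicCache())
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
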

tests/openvino/utils_tests.py

Lines changed: 3 additions & 3 deletions

@@ -109,7 +109,7 @@
     "mistral": "echarlaix/tiny-random-mistral",
     "mistral-nemo": "katuni4ka/tiny-random-mistral-nemo",
     "mixtral": "TitanML/tiny-mixtral",
-    "mixtral_awq": "TitanML/tiny-mixtral-AWQ-4bit",
+    "mixtral_awq": "katuni4ka/tiny-mixtral-AWQ-4bit",
     "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
     "mobilenet_v1": "google/mobilenet_v1_0.75_192",
     "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
@@ -138,8 +138,8 @@
     "qwen2-moe": "katuni4ka/tiny-random-qwen1.5-moe",
     "qwen2_vl": "katuni4ka/tiny-random-qwen2vl",
     "qwen2_5_vl": "katuni4ka/tiny-random-qwen2.5-vl",
-    "qwen3": "snake7gun/tiny-random-qwen3",
-    "qwen3-moe": "snake7gun/tiny-random-qwen3moe",
+    "qwen3": "katuni4ka/tiny-random-qwen3",
+    "qwen3-moe": "katuni4ka/tiny-random-qwen3moe",
     "resnet": "hf-internal-testing/tiny-random-resnet",
     "roberta": "hf-internal-testing/tiny-random-roberta",
     "roformer": "hf-internal-testing/tiny-random-roformer",
