Skip to content

Commit 9c30790

Browse files
committed
style
1 parent 34ec242 commit 9c30790

File tree

2 files changed

+66
-19
lines changed

2 files changed

+66
-19
lines changed

tests/openvino/test_seq2seq.py

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,15 @@
4242
pipeline,
4343
set_seed,
4444
)
45+
from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
4546
from transformers.onnx.utils import get_preprocessor
4647
from transformers.testing_utils import slow
4748
from transformers.utils import http_user_agent
4849
from utils_tests import MODEL_NAMES, TEST_IMAGE_URL
4950

5051
from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
5152
from optimum.exporters.openvino.stateful import model_has_state
53+
from optimum.exporters.tasks import TasksManager
5254
from optimum.intel import (
5355
OVModelForPix2Struct,
5456
OVModelForSeq2SeqLM,
@@ -58,6 +60,12 @@
5860
OVModelForVisualCausalLM,
5961
)
6062
from optimum.intel.openvino.modeling_seq2seq import OVDecoder, OVEncoder
63+
from optimum.intel.openvino.modeling_text2speech import (
64+
OVTextToSpeechDecoder,
65+
OVTextToSpeechEncoder,
66+
OVTextToSpeechPostNet,
67+
OVTextToSpeechVocoder,
68+
)
6169
from optimum.intel.openvino.modeling_visual_language import MODEL_PARTS_CLS_MAPPING, MODEL_TYPE_TO_CLS_MAPPING
6270
from optimum.intel.pipelines import pipeline as optimum_pipeline
6371
from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version
@@ -100,24 +108,46 @@ def check_openvino_model_attributes(self, openvino_model, use_cache: bool = True
100108
self.assertEqual(openvino_model.decoder.stateful, stateful)
101109
self.assertEqual(model_has_state(openvino_model.decoder.model), stateful)
102110

111+
def _test_find_untested_architectures(self):
112+
if len(self.SUPPORTED_ARCHITECTURES) != len(set(self.SUPPORTED_ARCHITECTURES)):
113+
raise ValueError(
114+
f"For the task `{self.TASK}`, some architectures are duplicated in the list of tested architectures: "
115+
f"{self.SUPPORTED_ARCHITECTURES}.\n"
116+
)
117+
118+
tested_architectures = set(self.SUPPORTED_ARCHITECTURES)
119+
transformers_architectures = set(CONFIG_MAPPING_NAMES.keys())
120+
ov_architectures = set(TasksManager.get_supported_model_type_for_task(task=self.TASK, exporter="openvino"))
121+
supported_architectures = ov_architectures & transformers_architectures
122+
123+
untested_architectures = supported_architectures - tested_architectures
124+
125+
if len(untested_architectures) > 0:
126+
raise ValueError(
127+
f"For the task `{self.TASK}`, the OpenVINO exporter supports {untested_architectures} which are not tested"
128+
)
129+
103130

104131
class OVModelForSeq2SeqLMIntegrationTest(OVSeq2SeqTestMixin):
105132
SUPPORTED_ARCHITECTURES = (
106133
"bart",
107-
# "bigbird_pegasus",
134+
"bigbird_pegasus",
108135
"blenderbot",
109136
"blenderbot-small",
110-
# "longt5",
137+
"encoder-decoder",
138+
"longt5",
111139
"m2m_100",
140+
"marian",
112141
"mbart",
113142
"mt5",
114143
"pegasus",
115144
"t5",
116145
)
117-
GENERATION_LENGTH = 100
118-
SPEEDUP_CACHE = 1.1
119146
OVMODEL_CLASS = OVModelForSeq2SeqLM
120147
AUTOMODEL_CLASS = AutoModelForSeq2SeqLM
148+
TASK = "text2text-generation"
149+
GENERATION_LENGTH = 100
150+
SPEEDUP_CACHE = 1.1
121151

122152
if not (is_openvino_version(">=", "2025.3.0") and is_openvino_version("<", "2025.5.0")):
123153
# There are known issues with marian model on OpenVINO 2025.3.x and 2025.4.x
@@ -129,6 +159,9 @@ class OVModelForSeq2SeqLMIntegrationTest(OVSeq2SeqTestMixin):
129159
if is_transformers_version(">=", "4.53.0"):
130160
SUPPORT_STATEFUL += ("pegasus",)
131161

162+
def test_find_untested_architectures(self):
163+
self._test_find_untested_architectures()
164+
132165
@parameterized.expand(SUPPORTED_ARCHITECTURES)
133166
def test_compare_to_transformers(self, model_arch):
134167
model_id = MODEL_NAMES[model_arch]
@@ -278,9 +311,9 @@ def test_compare_with_and_without_past_key_values(self):
278311

279312
class OVModelForSpeechSeq2SeqIntegrationTest(OVSeq2SeqTestMixin):
280313
SUPPORTED_ARCHITECTURES = ("whisper",)
281-
282314
OVMODEL_CLASS = OVModelForSpeechSeq2Seq
283315
AUTOMODEL_CLASS = AutoModelForSpeechSeq2Seq
316+
TASK = "automatic-speech-recognition"
284317

285318
def _generate_random_audio_data(self):
286319
np.random.seed(10)
@@ -916,31 +949,44 @@ def _get_vocoder(self, vocoder_id, model_arch):
916949
else:
917950
raise Exception("{} unknown model for text-to-speech".format(model_arch))
918951

952+
def check_openvino_model_attributes(self, openvino_model, use_cache: bool = True):
953+
self.assertIsInstance(openvino_model, self.OVMODEL_CLASS)
954+
self.assertIsInstance(openvino_model.config, PretrainedConfig)
955+
self.assertIsInstance(openvino_model.generation_config, GenerationConfig)
956+
957+
self.assertIsInstance(openvino_model.encoder, OVTextToSpeechEncoder)
958+
self.assertIsInstance(openvino_model.decoder, OVTextToSpeechDecoder)
959+
self.assertIsInstance(openvino_model.postnet, OVTextToSpeechPostNet)
960+
self.assertIsInstance(openvino_model.vocoder, OVTextToSpeechVocoder)
961+
self.assertIsInstance(openvino_model.encoder.model, openvino.Model)
962+
self.assertIsInstance(openvino_model.decoder.model, openvino.Model)
963+
self.assertIsInstance(openvino_model.postnet.model, openvino.Model)
964+
self.assertIsInstance(openvino_model.vocoder.model, openvino.Model)
965+
966+
self.assertEqual(openvino_model.use_cache, use_cache)
967+
self.assertEqual(model_has_state(openvino_model.decoder.model), use_cache)
968+
919969
@parameterized.expand(SUPPORTED_ARCHITECTURES)
920970
def test_compare_to_transformers(self, model_arch):
921971
set_seed(SEED)
922972
text_data = self._generate_text()
923973
speaker_embeddings = self._generate_speaker_embedding()
924974
model_id = MODEL_NAMES[model_arch]
925975

926-
if model_arch == "speecht5":
927-
# since Auto class for text-to-audio is not implemented in optimum
928-
# generate model classes for reference generation
929-
vocoder_id = "fxmarty/speecht5-hifigan-tiny"
930-
processor = self._get_processor(model_id, model_arch)
931-
932-
model = self.AUTOMODEL_CLASS.from_pretrained(model_id)
976+
# since Auto class for text-to-audio is not implemented in optimum
977+
# generate model classes for reference generation
978+
vocoder_id = "fxmarty/speecht5-hifigan-tiny"
979+
processor = self._get_processor(model_id, model_arch)
980+
vocoder = self._get_vocoder(vocoder_id, model_arch)
981+
model = self.AUTOMODEL_CLASS.from_pretrained(model_id)
933982

934-
vocoder = self._get_vocoder(vocoder_id, model_arch)
935-
inputs = processor(text=text_data, return_tensors="pt")
936-
ref_speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
937-
ref_speech = ref_speech.unsqueeze(0) if ref_speech.dim() == 1 else ref_speech
938-
else:
939-
raise Exception("{} unknown model for text-to-speech".format(model_arch))
983+
inputs = processor(text=text_data, return_tensors="pt")
984+
ref_speech = model.generate_speech(inputs["input_ids"], speaker_embeddings, vocoder=vocoder)
985+
ref_speech = ref_speech.unsqueeze(0) if ref_speech.dim() == 1 else ref_speech
940986

941987
ov_model = self.OVMODEL_CLASS.from_pretrained(model_id, vocoder=vocoder_id)
942988
ov_speech = ov_model.generate(input_ids=inputs["input_ids"], speaker_embeddings=speaker_embeddings)
943-
self.check_openvino_model_attributes(ov_model, use_cache=True, stateful=True)
989+
self.check_openvino_model_attributes(ov_model, use_cache=True)
944990
self.assertTrue(torch.allclose(ov_speech, ref_speech, atol=1e-3))
945991

946992
del vocoder

tests/openvino/utils_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
"donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel",
6767
"detr": "hf-internal-testing/tiny-random-DetrModel",
6868
"electra": "hf-internal-testing/tiny-random-electra",
69+
"encoder-decoder": "optimum-internal-testing/tiny-random-encoder-decoder-gpt2-bert",
6970
"esm": "hf-internal-testing/tiny-random-EsmModel",
7071
"exaone": "katuni4ka/tiny-random-exaone",
7172
"gemma": "fxmarty/tiny-random-GemmaForCausalLM",

0 commit comments

Comments
 (0)