Skip to content

Commit 8127f65

Browse files
openvino-dev-samples authored and mvafin committed
Add Ernie 4.5 OpenVINO support (#1366)
* rebase to non-remote model
* update min transformers lib for ernie 4.5
* add export test
* switch to OVDecoderModelPatcher
1 parent 47a12dd commit 8127f65

File tree

6 files changed

+32
-0
lines changed

6 files changed

+32
-0
lines changed

docs/source/openvino/models.mdx

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@ Here is the list of the supported architectures :
4747
- Deepseek_v2
4848
- Deepseek_v3
4949
- DistilBert
50+
- Ernie4.5
5051
- Electra
5152
- Encoder Decoder
5253
- ESM

optimum/exporters/openvino/model_configs.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4512,3 +4512,17 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
45124512
)
45134513

45144514
return dummy_inputs
4515+
4516+
4517+
@register_in_tasks_manager("ernie4_5", *["text-generation", "text-generation-with-past"], library_name="transformers")
4518+
class ErnieOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
4519+
MIN_TRANSFORMERS_VERSION = "4.54.0"
4520+
4521+
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
4522+
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
4523+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
4524+
4525+
def patch_model_for_export(
4526+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
4527+
) -> "ModelPatcher":
4528+
return OVDecoderModelPatcher(self, model, model_kwargs=model_kwargs)

tests/openvino/test_export.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,9 @@ class ExportModelTest(unittest.TestCase):
9898
{"stable-diffusion-3": OVStableDiffusion3Pipeline, "flux": OVFluxPipeline, "ltx-video": OVLTXPipeline}
9999
)
100100

101+
if is_transformers_version(">=", "4.54"):
102+
SUPPORTED_ARCHITECTURES.update({"ernie4_5": OVModelForCausalLM})
103+
101104
GENERATIVE_MODELS = ("pix2struct", "t5", "bart", "gpt2", "whisper", "llava", "speecht5")
102105

103106
def _openvino_export(

tests/openvino/test_exporters_cli.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -114,6 +114,14 @@ class OVCLIExportTestCase(unittest.TestCase):
114114
("text-to-video", "ltx-video"),
115115
]
116116
)
117+
118+
if is_transformers_version(">=", "4.54"):
119+
SUPPORTED_ARCHITECTURES.extend(
120+
[
121+
("text-generation-with-past", "ernie4_5"),
122+
]
123+
)
124+
117125
EXPECTED_NUMBER_OF_TOKENIZER_MODELS = {
118126
"gpt2": 2 if is_tokenizers_version("<", "0.20") or is_openvino_version(">=", "2024.5") else 0,
119127
"t5": 0 if is_openvino_version("<", "2025.1") else 2, # 2025.1 brings support for unigram tokenizers
@@ -137,6 +145,7 @@ class OVCLIExportTestCase(unittest.TestCase):
137145
"clip": 2 if is_tokenizers_version("<", "0.20.0") or is_openvino_version(">=", "2024.5") else 0,
138146
"mamba": 2,
139147
"falcon-mamba": 2,
148+
"ernie4_5": 2,
140149
}
141150

142151
TOKENIZER_CHAT_TEMPLATE_TESTS_MODELS = {

tests/openvino/test_modeling.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1183,6 +1183,9 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
11831183
if is_transformers_version(">=", "4.51.3"):
11841184
SUPPORTED_ARCHITECTURES += ("glm4",)
11851185

1186+
if is_transformers_version(">=", "4.54.0"):
1187+
SUPPORTED_ARCHITECTURES += ("ernie4_5",)
1188+
11861189
GENERATION_LENGTH = 100
11871190
REMOTE_CODE_MODELS = (
11881191
"chatglm",
@@ -1270,6 +1273,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
12701273
"qwen3_moe": 2,
12711274
"mamba": 0,
12721275
"falcon-mamba": 0,
1276+
"ernie4_5": 2,
12731277
}
12741278

12751279
# TODO: remove gptq/awq from here

tests/openvino/utils_tests.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@
6363
"donut": "fxmarty/tiny-doc-qa-vision-encoder-decoder",
6464
"donut-swin": "hf-internal-testing/tiny-random-DonutSwinModel",
6565
"detr": "hf-internal-testing/tiny-random-DetrModel",
66+
"ernie4_5": "optimum-internal-testing/tiny-random-Ernie4_5ForCausalLM",
6667
"electra": "hf-internal-testing/tiny-random-electra",
6768
"esm": "hf-internal-testing/tiny-random-EsmModel",
6869
"exaone": "katuni4ka/tiny-random-exaone",

0 commit comments

Comments
 (0)