Skip to content

Commit c69fe32

Browse files
eaidova and echarlaix authored
Add support export for new architectures (#720)
* update codegen config for support codegen2 * add support DBRX * add qwen2moe support * fix test models * baichuan sdpa * apply review comments * Apply suggestions from code review Co-authored-by: Ella Charlaix <[email protected]> * Apply suggestions from code review Co-authored-by: Ella Charlaix <[email protected]> --------- Co-authored-by: Ella Charlaix <[email protected]>
1 parent 7a929e8 commit c69fe32

File tree

4 files changed

+396
-1
lines changed

4 files changed

+396
-1
lines changed

optimum/exporters/openvino/model_configs.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
2222
from optimum.exporters.onnx.model_configs import (
23+
CodeGenOnnxConfig,
2324
FalconOnnxConfig,
2425
GemmaOnnxConfig,
2526
LlamaOnnxConfig,
@@ -44,6 +45,8 @@
4445
AquilaModelPatcher,
4546
BaichuanModelPatcher,
4647
ChatGLMModelPatcher,
48+
CodeGenModelPatcher,
49+
DBRXModelPatcher,
4750
GemmaModelPatcher,
4851
InternLM2Patcher,
4952
InternLMModelPatcher,
@@ -112,6 +115,15 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
112115
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
113116

114117

118+
@register_in_tasks_manager(
    "qwen2-moe",
    "text-generation",
    "text-generation-with-past",
    library_name="transformers",
)
class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    """Export configuration registered for the ``"qwen2-moe"`` model type.

    The class is registered for the ``text-generation`` and
    ``text-generation-with-past`` tasks of the ``transformers`` library.
    It only overrides class-level settings of
    ``TextDecoderWithPositionIdsOnnxConfig``; no methods are redefined here.
    """

    # Model config attributes are read through the generic text normalizer.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DEFAULT_ONNX_OPSET = 14

    # Past key/values dummies come from the Mistral generator; the same class
    # is also listed among the dummy input generators.
    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DUMMY_PKV_GENERATOR_CLASS)
125+
126+
115127
@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers")
116128
class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
117129
DEFAULT_ONNX_OPSET = 14
@@ -738,3 +750,38 @@ def patch_model_for_export(
738750
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
739751
) -> "ModelPatcher":
740752
return InternLMModelPatcher(self, model, model_kwargs=model_kwargs)
753+
754+
755+
@register_in_tasks_manager(
    "codegen",
    "feature-extraction",
    "feature-extraction-with-past",
    "text-generation",
    "text-generation-with-past",
    library_name="transformers",
)
class CodeGenOpenVINOConfig(CodeGenOnnxConfig):
    """Export configuration registered for the ``"codegen"`` model type.

    Registered for the feature-extraction and text-generation tasks (each
    with and without past) of the ``transformers`` library. Everything is
    inherited from ``CodeGenOnnxConfig`` except the model patcher used
    during export.
    """

    def patch_model_for_export(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "ModelPatcher":
        """Build the patcher that wraps ``model`` for export.

        Args:
            model: The model instance about to be exported.
            model_kwargs: Optional extra keyword arguments, forwarded
                unchanged to the patcher.

        Returns:
            A ``CodeGenModelPatcher`` constructed from this config, ``model``
            and ``model_kwargs``.
        """
        patcher = CodeGenModelPatcher(self, model, model_kwargs=model_kwargs)
        return patcher
765+
766+
767+
@register_in_tasks_manager(
    "dbrx",
    "text-generation",
    "text-generation-with-past",
    library_name="transformers",
)
class DBRXOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    """Export configuration registered for the ``"dbrx"`` model type.

    Registered for the ``text-generation`` and ``text-generation-with-past``
    tasks of the ``transformers`` library. It customises the normalized
    config attribute names, the dummy input generators and the model patcher
    used during export.
    """

    DEFAULT_ONNX_OPSET = 14

    # Past key/values dummies come from the Mistral generator; the same class
    # is also listed among the dummy input generators.
    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DUMMY_PKV_GENERATOR_CLASS)

    # Point the normalized attribute names at the attribute names given here.
    # NOTE(review): "attn_config.kv_n_heads" looks like a dotted path into a
    # nested attention sub-config -- confirm against the normalizer.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
        hidden_size="d_model",
        num_layers="n_layers",
        num_attention_heads="n_heads",
        num_key_value_heads="attn_config.kv_n_heads",
        allow_new=True,
    )

    def patch_model_for_export(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "ModelPatcher":
        """Build the patcher that wraps ``model`` for export.

        Args:
            model: The model instance about to be exported.
            model_kwargs: Optional extra keyword arguments, forwarded
                unchanged to the patcher.

        Returns:
            A ``DBRXModelPatcher`` constructed from this config, ``model``
            and ``model_kwargs``.
        """
        patcher = DBRXModelPatcher(self, model, model_kwargs=model_kwargs)
        return patcher

0 commit comments

Comments
 (0)