|
20 | 20 |
|
21 | 21 | from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig |
22 | 22 | from optimum.exporters.onnx.model_configs import ( |
| 23 | + CodeGenOnnxConfig, |
23 | 24 | FalconOnnxConfig, |
24 | 25 | GemmaOnnxConfig, |
25 | 26 | LlamaOnnxConfig, |
|
44 | 45 | AquilaModelPatcher, |
45 | 46 | BaichuanModelPatcher, |
46 | 47 | ChatGLMModelPatcher, |
| 48 | + CodeGenModelPatcher, |
| 49 | + DBRXModelPatcher, |
47 | 50 | GemmaModelPatcher, |
48 | 51 | InternLM2Patcher, |
49 | 52 | InternLMModelPatcher, |
@@ -112,6 +115,15 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): |
112 | 115 | NORMALIZED_CONFIG_CLASS = NormalizedTextConfig |
113 | 116 |
|
114 | 117 |
|
| 118 | +@register_in_tasks_manager("qwen2-moe", *["text-generation", "text-generation-with-past"], library_name="transformers") |
| 119 | +class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): |
| 120 | + DEFAULT_ONNX_OPSET = 14 |
| 121 | + |
| 122 | + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) |
| 123 | + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator |
| 124 | + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig |
| 125 | + |
| 126 | + |
115 | 127 | @register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers") |
116 | 128 | class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): |
117 | 129 | DEFAULT_ONNX_OPSET = 14 |
@@ -738,3 +750,38 @@ def patch_model_for_export( |
738 | 750 | self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None |
739 | 751 | ) -> "ModelPatcher": |
740 | 752 | return InternLMModelPatcher(self, model, model_kwargs=model_kwargs) |
| 753 | + |
| 754 | + |
| 755 | +@register_in_tasks_manager( |
| 756 | + "codegen", |
| 757 | + *["feature-extraction", "feature-extraction-with-past", "text-generation", "text-generation-with-past"], |
| 758 | + library_name="transformers", |
| 759 | +) |
| 760 | +class CodeGenOpenVINOConfig(CodeGenOnnxConfig): |
| 761 | + def patch_model_for_export( |
| 762 | + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None |
| 763 | + ) -> "ModelPatcher": |
| 764 | + return CodeGenModelPatcher(self, model, model_kwargs=model_kwargs) |
| 765 | + |
| 766 | + |
| 767 | +@register_in_tasks_manager( |
| 768 | + "dbrx", |
| 769 | + *["text-generation", "text-generation-with-past"], |
| 770 | + library_name="transformers", |
| 771 | +) |
| 772 | +class DBRXOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): |
| 773 | + DEFAULT_ONNX_OPSET = 14 |
| 774 | + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( |
| 775 | + num_attention_heads="n_heads", |
| 776 | + hidden_size="d_model", |
| 777 | + num_layers="n_layers", |
| 778 | + num_key_value_heads="attn_config.kv_n_heads", |
| 779 | + allow_new=True, |
| 780 | + ) |
| 781 | + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) |
| 782 | + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator |
| 783 | + |
| 784 | + def patch_model_for_export( |
| 785 | + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None |
| 786 | + ) -> "ModelPatcher": |
| 787 | + return DBRXModelPatcher(self, model, model_kwargs=model_kwargs) |
0 commit comments