Skip to content

Commit 222748e

Browse files
authored
fix conversion for text embeddings for fp16 models (#968)
* fix conversion for text embeddings for fp16 models * fix rebasing issue * apply review comments * Update tests/openvino/utils_tests.py
1 parent d357376 commit 222748e

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

optimum/exporters/openvino/model_configs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
GptNeoxJapaneseModelPatcher,
7070
GptNeoxModelPatcher,
7171
IBertModelPatcher,
72+
InputEmbeddingPatcher,
7273
InternLM2Patcher,
7374
InternLMModelPatcher,
7475
InternVLChatImageEmbeddingModelPatcher,
@@ -1264,6 +1265,12 @@ def rename_ambiguous_inputs(self, inputs):
12641265
model_inputs["input"] = inputs["input_ids"]
12651266
return model_inputs
12661267

1268+
def patch_model_for_export(
    self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
    """Build the patcher that is applied to ``model`` while it is exported.

    Args:
        model: The model about to be exported.
        model_kwargs: Optional extra keyword arguments, forwarded unchanged to the patcher.

    Returns:
        An ``InputEmbeddingPatcher`` created from this config, ``model`` and ``model_kwargs``.
    """
    # Per the original note: making the model traceable in 16-bit precision
    # overrides the embeddings' input signature, so the export goes through
    # InputEmbeddingPatcher to prevent that issue.
    return InputEmbeddingPatcher(config=self, model=model, model_kwargs=model_kwargs)
1273+
12671274

12681275
class LlavaConfigBehavior(str, enum.Enum):
12691276
LANGUAGE = "language"

optimum/exporters/openvino/model_patcher.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2991,3 +2991,24 @@ def __init__(
29912991
def __exit__(self, exc_type, exc_value, traceback):
29922992
super().__exit__(exc_type, exc_value, traceback)
29932993
self._model.forward = self._model.__orig_forward
2994+
2995+
2996+
class InputEmbeddingPatcher(ModelPatcher):
    """Patcher that gives the patched model a ``forward`` with a single ``input`` argument.

    On construction the model's current ``forward`` is stashed on the model and
    replaced by a thin wrapper that accepts exactly one argument named ``input``
    and passes it positionally to the stashed method. ``__exit__`` puts the
    stashed method back.
    """

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Union[Dict[str, Any], None] = None,
    ):
        """Install the ``input``-only forward wrapper, then initialise the base patcher.

        Args:
            config: Export config handed through to ``ModelPatcher``.
            model: Model whose ``forward`` is wrapped in place.
            model_kwargs: Extra keyword arguments handed through to ``ModelPatcher``;
                may be ``None``.
        """
        # Because this assignment is written inside the class body, Python mangles
        # the attribute name to ``_InputEmbeddingPatcher__orig_forward``. The nested
        # ``forward`` below and ``__exit__`` are mangled the same way, so all three
        # refer to the same attribute.
        model.__orig_forward = model.forward

        # The parameter is deliberately called ``input`` (shadowing the builtin):
        # the wrapper exists to expose that exact argument name.
        # NOTE(review): presumably this matches the "input" key used for the
        # renamed model inputs on the config side - confirm.
        def forward(self, input):
            return self.__orig_forward(input)

        model.forward = types.MethodType(forward, model)

        # The wrapper is installed before the base initialiser runs.
        # NOTE(review): presumably so ModelPatcher captures the wrapped forward -
        # confirm against ModelPatcher.__init__.
        try:
            super().__init__(config, model, model_kwargs)
        except Exception:
            # Do not leave the caller's model monkey-patched when the patcher
            # could not be constructed.
            model.forward = model.__orig_forward
            raise

    def __exit__(self, exc_type, exc_value, traceback):
        """Run the base patcher's exit logic, then restore the model's original ``forward``."""
        super().__exit__(exc_type, exc_value, traceback)
        self._model.forward = self._model.__orig_forward

0 commit comments

Comments
 (0)