|
91 | 91 | from optimum.intel.openvino.configuration import OVConfig |
92 | 92 |
|
93 | 93 |
|
94 | | -def _save_model(model, path: str, ov_config: Optional["OVConfig"] = None, library_name: Optional[str] = None): |
| 94 | +def _set_runtime_options( |
| 95 | + models_and_export_configs: Dict[ |
| 96 | + str, |
| 97 | + Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin", "DiffusionPipeline"], "OnnxConfig"], |
| 98 | + ], |
| 99 | + task: str, |
| 100 | +): |
| 101 | + for model_name in models_and_export_configs.keys(): |
| 102 | + _, sub_export_config = models_and_export_configs[model_name] |
| 103 | + if "vae_" in model_name or "text-generation" in task: |
| 104 | + sub_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"} |
| 105 | + |
| 106 | + |
def _save_model(
    model,
    path: str,
    ov_config: Optional["OVConfig"] = None,
    library_name: Optional[str] = None,
    config: Optional["OnnxConfig"] = None,
):
    """
    Serialize an OpenVINO model to disk, stamping provenance and runtime metadata.

    Args:
        model: the ``openvino.Model`` to save.
        path: destination path of the OpenVINO IR (``.xml``; weights go to the
            sibling ``.bin``).
        ov_config: optional export configuration; when its ``dtype`` is
            ``"fp16"``, weights are compressed to fp16 on save.
        library_name: source library name, recorded in the model's rt_info by
            ``_add_version_info_to_model``.
        config: optional export config of the submodel; when it carries a
            ``runtime_options`` attribute (see ``_set_runtime_options``), those
            options are recorded under ``rt_info["runtime_options"]``.
    """
    # NOTE(review): annotation fixed to Optional["OnnxConfig"] — the previous
    # bare `OnnxConfig = None` was non-Optional with a None default and used an
    # unquoted name, inconsistent with the Optional["OVConfig"] style above.
    compress_to_fp16 = ov_config is not None and ov_config.dtype == "fp16"
    model = _add_version_info_to_model(model, library_name)

    # hasattr guard also covers config=None (hasattr(None, ...) is False).
    if hasattr(config, "runtime_options"):
        model = _add_runtime_options_to_rt_info(model, config.runtime_options)
    save_model(model, path, compress_to_fp16)
98 | 120 |
|
99 | 121 |
|
@@ -213,6 +235,7 @@ def export_tensorflow( |
213 | 235 | output.parent / output, |
214 | 236 | ov_config=ov_config, |
215 | 237 | library_name=library_name, |
| 238 | + config=config, |
216 | 239 | ) |
217 | 240 | del ov_model |
218 | 241 | return input_names, output_names, True |
@@ -276,6 +299,7 @@ def export_pytorch_via_onnx( |
276 | 299 | output.parent / OV_XML_FILE_NAME if output.suffix != ".xml" else output, |
277 | 300 | ov_config=ov_config, |
278 | 301 | library_name=library_name, |
| 302 | + config=config, |
279 | 303 | ) |
280 | 304 | del ov_model |
281 | 305 | return input_names, output_names, True |
@@ -450,6 +474,7 @@ def ts_patched_forward(*args, **kwargs): |
450 | 474 | output, |
451 | 475 | ov_config=ov_config, |
452 | 476 | library_name=library_name, |
| 477 | + config=config, |
453 | 478 | ) |
454 | 479 | clear_class_registry() |
455 | 480 | del ov_model |
@@ -718,6 +743,8 @@ def export_from_model( |
718 | 743 |
|
719 | 744 | model.save_config(output) |
720 | 745 |
|
| 746 | + _set_runtime_options(models_and_export_configs, task) |
| 747 | + |
721 | 748 | export_models( |
722 | 749 | models_and_export_configs=models_and_export_configs, |
723 | 750 | output_dir=output, |
@@ -792,6 +819,19 @@ def export_tokenizer( |
792 | 819 | save_model(model, output / file_name.format(suffix)) |
793 | 820 |
|
794 | 821 |
|
| 822 | +def _add_runtime_options_to_rt_info(model: Model, options: Dict): |
| 823 | + """ |
| 824 | + Add runtime optinos |
| 825 | + """ |
| 826 | + try: |
| 827 | + for name, value in options.items(): |
| 828 | + model.set_rt_info(value, ["runtime_options", name]) |
| 829 | + except Exception: |
| 830 | + pass |
| 831 | + |
| 832 | + return model |
| 833 | + |
| 834 | + |
795 | 835 | def _add_version_info_to_model(model: Model, library_name: Optional[str] = None): |
796 | 836 | """ |
797 | 837 | Add dependency versions to OpenVINO model |
|
0 commit comments