Skip to content

Commit c454b00

Browse files
authored
Expand activation scaling to other submodels of diffusion pipelines (#1039)
* Expand activation scaling to other submodels * Applied comments
1 parent f6a7b83 commit c454b00

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

optimum/exporters/openvino/convert.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,11 @@ def _set_runtime_options(
9898
Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin", "DiffusionPipeline"], "OnnxConfig"],
9999
],
100100
task: str,
101+
library_name: str,
101102
):
102103
for model_name in models_and_export_configs.keys():
103104
_, sub_export_config = models_and_export_configs[model_name]
104-
if "vae_" in model_name or "text-generation" in task:
105+
if "diffusers" in library_name or "text-generation" in task:
105106
sub_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
106107

107108

@@ -754,7 +755,7 @@ def export_from_model(
754755

755756
model.save_config(output)
756757

757-
_set_runtime_options(models_and_export_configs, task)
758+
_set_runtime_options(models_and_export_configs, task, library_name)
758759

759760
export_models(
760761
models_and_export_configs=models_and_export_configs,

tests/openvino/test_export.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,26 @@ def _openvino_export(
140140
self.assertTrue(
141141
ov_model.vae_decoder.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
142142
)
143+
if hasattr(ov_model, "text_encoder") and ov_model.text_encoder:
144+
self.assertTrue(
145+
ov_model.text_encoder.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
146+
)
147+
if hasattr(ov_model, "text_encoder_2") and ov_model.text_encoder_2:
148+
self.assertTrue(
149+
ov_model.text_encoder_2.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
150+
)
151+
if hasattr(ov_model, "text_encoder_3") and ov_model.text_encoder_3:
152+
self.assertTrue(
153+
ov_model.text_encoder_3.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
154+
)
155+
if hasattr(ov_model, "unet") and ov_model.unet:
156+
self.assertTrue(
157+
ov_model.unet.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
158+
)
159+
if hasattr(ov_model, "transformer") and ov_model.transformer:
160+
self.assertTrue(
161+
ov_model.transformer.model.has_rt_info(["runtime_options", "ACTIVATIONS_SCALE_FACTOR"])
162+
)
143163

144164
@parameterized.expand(SUPPORTED_ARCHITECTURES)
145165
def test_export(self, model_type: str):

0 commit comments

Comments
 (0)