
Commit 78a74ce

Authored by eaidova, IlyasMoutawwakil, and echarlaix
Add OpenVINO sana support (#1106)
* support sana text2image
* add pipeline
* update tests
* add variant for model loading in from_transformers
* Update optimum/exporters/openvino/__main__.py
* provide missed params to data-aware cli
* apply review comments
* rename weights_variant to variant
* Apply suggestions from code review
* Apply suggestions from code review
* autocls with if else in tests
* Update tests/openvino/test_diffusion.py
* Apply suggestions from code review
  Co-authored-by: Ilyas Moutawwakil <[email protected]>
* Update optimum/intel/openvino/modeling_decoder.py
  Co-authored-by: Ella Charlaix <[email protected]>
* add tests for variant

---------

Co-authored-by: Ilyas Moutawwakil <[email protected]>
Co-authored-by: Ella Charlaix <[email protected]>
1 parent 3d7aba4 · commit 78a74ce

19 files changed: +339 −27 lines

optimum/commands/export/openvino.py

Lines changed: 13 additions & 0 deletions
@@ -105,6 +105,12 @@ def parse_args_openvino(parser: "ArgumentParser"):
             "This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it."
         ),
     )
+    optional_group.add_argument(
+        "--variant",
+        type=str,
+        default=None,
+        help=("If specified load weights from variant filename."),
+    )
     optional_group.add_argument(
         "--ratio",
         type=float,
@@ -415,6 +421,10 @@ def run(self):
                 from optimum.intel import OVFluxPipeline
 
                 model_cls = OVFluxPipeline
+            elif class_name == "SanaPipeline":
+                from optimum.intel import OVSanaPipeline
+
+                model_cls = OVSanaPipeline
             else:
                 raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.")
@@ -447,6 +457,8 @@ def run(self):
                 quantization_config=quantization_config,
                 stateful=not self.args.disable_stateful,
                 trust_remote_code=self.args.trust_remote_code,
+                variant=self.args.variant,
+                cache_dir=self.args.cache_dir,
             )
             model.save_pretrained(self.args.output)
 
@@ -468,5 +480,6 @@ def run(self):
                 stateful=not self.args.disable_stateful,
                 convert_tokenizer=not self.args.disable_convert_tokenizer,
                 library_name=library_name,
+                variant=self.args.variant,
                 # **input_shapes,
             )
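
For context, a minimal sketch of how the new `--variant` flag could be exercised. The model id and output directory are placeholders, and the comment shows a hypothetical CLI line; the programmatic call mirrors what the command forwards to `main_export`.

```python
# Hypothetical invocation (placeholder model id and output path):
#   optimum-cli export openvino --model <sana-model-id> --variant fp16 sana_ov
#
# Roughly equivalent programmatic call, assuming a diffusers checkpoint that
# actually ships fp16 variant weight files:
from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="<sana-model-id>",  # placeholder
    output="sana_ov",
    variant="fp16",  # forwarded from --variant; selects *.fp16.safetensors weights
)
```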

optimum/exporters/openvino/__main__.py

Lines changed: 5 additions & 0 deletions
@@ -122,6 +122,7 @@ def main_export(
     convert_tokenizer: bool = False,
     library_name: Optional[str] = None,
     model_loading_kwargs: Optional[Dict[str, Any]] = None,
+    variant: Optional[str] = None,
     **kwargs_shapes,
 ):
     """
@@ -237,6 +238,8 @@
     custom_architecture = False
     patch_16bit = False
     loading_kwargs = model_loading_kwargs or {}
+    if variant is not None:
+        loading_kwargs["variant"] = variant
     if library_name == "transformers":
         config = AutoConfig.from_pretrained(
             model_name_or_path,
@@ -347,6 +350,7 @@ class StoreAttr(object):
 
         GPTQQuantizer.post_init_model = post_init_model
     elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
+        _loading_kwargs = {} if variant is None else {"variant": variant}
         dtype = deduce_diffusers_dtype(
             model_name_or_path,
             revision=revision,
@@ -355,6 +359,7 @@ class StoreAttr(object):
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
+            **_loading_kwargs,
         )
         if dtype in [torch.float16, torch.bfloat16]:
             loading_kwargs["torch_dtype"] = dtype
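
For reference, passing `variant` through the loading kwargs defers to the standard diffusers weight-variant behavior sketched below; the model id is a placeholder and the example assumes the repository actually publishes variant weight files.

```python
# Minimal sketch of what the forwarded kwarg amounts to on the diffusers side:
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "<model-id>",    # placeholder
    variant="fp16",  # resolves files such as diffusion_pytorch_model.fp16.safetensors
)
```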

optimum/exporters/openvino/convert.py

Lines changed: 68 additions & 0 deletions
@@ -1016,6 +1016,7 @@ def get_diffusion_models_for_export_ext(
     is_sdxl = pipeline.__class__.__name__.startswith("StableDiffusionXL")
     is_sd3 = pipeline.__class__.__name__.startswith("StableDiffusion3")
     is_flux = pipeline.__class__.__name__.startswith("Flux")
+    is_sana = pipeline.__class__.__name__.startswith("Sana")
     is_sd = pipeline.__class__.__name__.startswith("StableDiffusion") and not is_sd3
     is_lcm = pipeline.__class__.__name__.startswith("LatentConsistencyModel")
 
@@ -1034,11 +1035,78 @@
         models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
     elif is_flux:
         models_for_export = get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype)
+    elif is_sana:
+        models_for_export = get_sana_models_for_export(pipeline, exporter, int_dtype, float_dtype)
     else:
         raise ValueError(f"Unsupported pipeline type `{pipeline.__class__.__name__}` provided")
     return None, models_for_export
 
 
+def get_sana_models_for_export(pipeline, exporter, int_dtype, float_dtype):
+    models_for_export = {}
+    text_encoder = pipeline.text_encoder
+    text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=text_encoder,
+        exporter=exporter,
+        library_name="diffusers",
+        task="feature-extraction",
+        model_type="gemma2-text-encoder",
+    )
+    text_encoder_export_config = text_encoder_config_constructor(
+        pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+    )
+    text_encoder_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
+    models_for_export["text_encoder"] = (text_encoder, text_encoder_export_config)
+    transformer = pipeline.transformer
+    transformer.config.text_encoder_projection_dim = transformer.config.caption_channels
+    transformer.config.requires_aesthetics_score = False
+    transformer.config.time_cond_proj_dim = None
+    export_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=transformer,
+        exporter=exporter,
+        library_name="diffusers",
+        task="semantic-segmentation",
+        model_type="sana-transformer",
+    )
+    transformer_export_config = export_config_constructor(
+        pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
+    )
+    models_for_export["transformer"] = (transformer, transformer_export_config)
+    # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
+    vae_encoder = copy.deepcopy(pipeline.vae)
+    vae_encoder.forward = lambda sample: {"latent": vae_encoder.encode(x=sample)["latent"]}
+    vae_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=vae_encoder,
+        exporter=exporter,
+        library_name="diffusers",
+        task="semantic-segmentation",
+        model_type="dcae-encoder",
+    )
+    vae_encoder_export_config = vae_config_constructor(
+        vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+    )
+    vae_encoder_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
+    models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)
+
+    # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
+    vae_decoder = copy.deepcopy(pipeline.vae)
+    vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
+    vae_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=vae_decoder,
+        exporter=exporter,
+        library_name="diffusers",
+        task="semantic-segmentation",
+        model_type="vae-decoder",
+    )
+    vae_decoder_export_config = vae_config_constructor(
+        vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+    )
+    vae_decoder_export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
+    models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)
+
+    return models_for_export
+
+
 def get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype):
     models_for_export = {}
 
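
As a rough illustration (not part of the commit), the mapping returned by `get_sana_models_for_export` pairs each Sana submodel with its export config; something along these lines could be used to inspect it, assuming `pipeline` is an already-loaded `SanaPipeline` and the dtype strings match the exporter defaults.

```python
# Sketch only: inspect the submodels selected for export (argument values assumed).
models_for_export = get_sana_models_for_export(
    pipeline, exporter="openvino", int_dtype="int64", float_dtype="fp32"
)
for name, (submodel, export_config) in models_for_export.items():
    # expected keys: text_encoder, transformer, vae_encoder, vae_decoder
    print(name, type(submodel).__name__, type(export_config).__name__)
```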

optimum/exporters/openvino/model_configs.py

Lines changed: 81 additions & 0 deletions
@@ -41,6 +41,7 @@
     PhiOnnxConfig,
     T5OnnxConfig,
     UNetOnnxConfig,
+    VaeEncoderOnnxConfig,
     VisionOnnxConfig,
     WhisperOnnxConfig,
 )
@@ -106,6 +107,7 @@
     Qwen2VLVisionEmbMergerPatcher,
     QwenModelPatcher,
     RotaryEmbPatcher,
+    SanaTextEncoderModelPatcher,
     StatefulSeq2SeqDecoderPatcher,
     UpdateCausalMaskModelPatcher,
     XverseModelPatcher,
@@ -134,6 +136,8 @@ def init_model_configs():
     if is_diffusers_available() and "fill" not in TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS:
         TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS["fill"] = "FluxFillPipeline"
         TasksManager._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS["fill"] = {"flux": "FluxFillPipeline"}
+        TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS["text-to-image"] = ("AutoPipelineForText2Image", "SanaPipeline")
+        TasksManager._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS["text-to-image"]["sana"] = "SanaPipeline"
 
     supported_model_types = [
         "_SUPPORTED_MODEL_TYPE",
@@ -1896,6 +1900,83 @@ class T5EncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
     pass
 
 
+@register_in_tasks_manager("gemma2-text-encoder", *["feature-extraction"], library_name="diffusers")
+class Gemma2TextEncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "batch_size", 1: "sequence_length"},
+        }
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        return SanaTextEncoderModelPatcher(self, model, model_kwargs)
+
+
+class DummySanaSeq2SeqDecoderTextWithEncMaskInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
+    SUPPORTED_INPUT_NAMES = (
+        "decoder_input_ids",
+        "decoder_attention_mask",
+        "encoder_outputs",
+        "encoder_hidden_states",
+        "encoder_attention_mask",
+    )
+
+
+class DummySanaTransformerVisionInputGenerator(DummyUnetVisionInputGenerator):
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedVisionConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
+        width: int = DEFAULT_DUMMY_SHAPES["width"] // 8,
+        height: int = DEFAULT_DUMMY_SHAPES["height"] // 8,
+        # Reduce img shape by 4 for FLUX to reduce memory usage on conversion
+        **kwargs,
+    ):
+        super().__init__(task, normalized_config, batch_size, num_channels, width=width, height=height, **kwargs)
+
+
+@register_in_tasks_manager("sana-transformer", *["semantic-segmentation"], library_name="diffusers")
+class SanaTransformerOpenVINOConfig(UNetOpenVINOConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
+        image_size="sample_size",
+        num_channels="in_channels",
+        hidden_size="caption_channels",
+        vocab_size="attention_head_dim",
+        allow_new=True,
+    )
+    DUMMY_INPUT_GENERATOR_CLASSES = (
+        DummySanaTransformerVisionInputGenerator,
+        DummySanaSeq2SeqDecoderTextWithEncMaskInputGenerator,
+    ) + UNetOpenVINOConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:-1]
+
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        common_inputs["encoder_attention_mask"] = {0: "batch_size", 1: "sequence_length"}
+        return common_inputs
+
+    def rename_ambiguous_inputs(self, inputs):
+        #  The input name in the model signature is `x`, hence the export input name is updated.
+        hidden_states = inputs.pop("sample", None)
+        if hidden_states is not None:
+            inputs["hidden_states"] = hidden_states
+        return inputs
+
+
+@register_in_tasks_manager("dcae-encoder", *["semantic-segmentation"], library_name="diffusers")
+class DcaeEncoderOpenVINOConfig(VaeEncoderOnnxConfig):
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "latent": {0: "batch_size", 2: "height_latent", 3: "width_latent"},
+        }
+
+
 class DummyFluxTransformerInputGenerator(DummyVisionInputGenerator):
     SUPPORTED_INPUT_NAMES = (
         "pixel_values",

optimum/exporters/openvino/model_patcher.py

Lines changed: 33 additions & 4 deletions
@@ -25,6 +25,7 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
 from transformers.utils import is_tf_available
 
+from optimum.exporters.onnx.base import OnnxConfig
 from optimum.exporters.onnx.model_patcher import (
     DecoderModelPatcher,
     ModelPatcher,
@@ -115,18 +116,20 @@ def patch_model_with_bettertransformer(model):
     return model
 
 
-def patch_update_causal_mask(model, transformers_version, inner_model_name="model", patch_fn=None):
+def patch_update_causal_mask(
+    model, transformers_version, inner_model_name="model", patch_fn=None, patch_extrnal_model=False
+):
     if is_transformers_version(">=", transformers_version):
-        inner_model = getattr(model, inner_model_name, None)
+        inner_model = getattr(model, inner_model_name, None) if not patch_extrnal_model else model
         if inner_model is not None:
             if hasattr(inner_model, "_update_causal_mask"):
                 inner_model._orig_update_causal_mask = inner_model._update_causal_mask
                 patch_fn = patch_fn or _llama_gemma_update_causal_mask
                 inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)
 
 
-def unpatch_update_causal_mask(model, inner_model_name="model"):
-    inner_model = getattr(model, inner_model_name, None)
+def unpatch_update_causal_mask(model, inner_model_name="model", patch_extrnal_model=False):
+    inner_model = getattr(model, inner_model_name, None) if not patch_extrnal_model else model
     if inner_model is not None and hasattr(inner_model, "._orig_update_causal_mask"):
         inner_model._update_causal_mask = inner_model._orig_update_causal_mask
 
@@ -3872,3 +3875,29 @@ def patched_forward(*args, **kwargs):
             model.forward = patched_forward
 
         super().__init__(config, model, model_kwargs)
+
+
+class SanaTextEncoderModelPatcher(ModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        patch_update_causal_mask(self._model, "4.39.0", None, patch_extrnal_model=True)
+
+        if self._model.config._attn_implementation != "sdpa":
+            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+            self._model.config._attn_implementation = "sdpa"
+            if is_transformers_version("<", "4.47.0"):
+                from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
+
+                sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
+                for layer in self._model.layers:
+                    layer.self_attn._orig_forward = layer.self_attn.forward
+                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, None, True)
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+            for layer in self._model.layers:
+                if hasattr(layer.self_attn, "_orig_forward"):
+                    layer.self_attn.forward = layer.self_attn._orig_forward
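
For orientation, a hedged usage sketch: the export config added above returns this patcher from `patch_model_for_export`, and entering it keeps the Gemma2 text encoder on SDPA attention with the patched causal-mask helper only for the duration of tracing. The variable names here are illustrative, not taken from the commit.

```python
# Sketch: how the patcher is typically applied around conversion (names illustrative).
patcher = text_encoder_export_config.patch_model_for_export(text_encoder, model_kwargs={})
with patcher:
    pass  # trace / convert the text encoder here while the patches are active
# on exit, the original attention implementation and _update_causal_mask are restored
```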

optimum/exporters/openvino/utils.py

Lines changed: 9 additions & 3 deletions
@@ -257,9 +257,15 @@ def deduce_diffusers_dtype(model_name_or_path, **loading_kwargs):
             model_part_name = "unet"
         if model_part_name:
             directory = path / model_part_name
-            safetensors_files = [
-                filename for filename in directory.glob("*.safetensors") if len(filename.suffixes) == 1
-            ]
+
+            pattern = "*.safetensors"
+            if "variant" in loading_kwargs:
+                variant = loading_kwargs["variant"]
+                pattern = f"*.{variant}.safetensors"
+                safetensors_files = list(directory.glob(pattern))
+            else:
+                # filter out variant files
+                safetensors_files = [filename for filename in directory.glob(pattern) if len(filename.suffixes) == 1]
             safetensors_file = None
             if len(safetensors_files) > 0:
                 safetensors_file = safetensors_files.pop(0)
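
A standalone sketch of the file-selection rule above, useful for seeing why variant files need special handling: a variant weight file carries two suffixes (e.g. `model.fp16.safetensors`), so the non-variant branch keeps only files with a single suffix.

```python
from pathlib import Path
from typing import Optional


def pick_safetensors(directory: Path, variant: Optional[str] = None) -> list:
    # with a variant, match only *.{variant}.safetensors files
    if variant is not None:
        return list(directory.glob(f"*.{variant}.safetensors"))
    # otherwise keep plain *.safetensors and drop variant files (two suffixes)
    return [f for f in directory.glob("*.safetensors") if len(f.suffixes) == 1]
```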

optimum/intel/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -127,6 +127,7 @@
         "OVFluxImg2ImgPipeline",
         "OVFluxInpaintPipeline",
         "OVFluxFillPipeline",
+        "OVSanaPipeline",
         "OVPipelineForImage2Image",
         "OVPipelineForText2Image",
         "OVPipelineForInpainting",
@@ -150,6 +151,7 @@
         "OVFluxImg2ImgPipeline",
         "OVFluxInpaintPipeline",
         "OVFluxFillPipeline",
+        "OVSanaPipeline",
         "OVPipelineForImage2Image",
         "OVPipelineForText2Image",
         "OVPipelineForInpainting",
@@ -303,6 +305,7 @@
         OVPipelineForImage2Image,
         OVPipelineForInpainting,
         OVPipelineForText2Image,
+        OVSanaPipeline,
         OVStableDiffusion3Img2ImgPipeline,
         OVStableDiffusion3InpaintPipeline,
         OVStableDiffusion3Pipeline,
@@ -321,6 +324,7 @@
         OVPipelineForImage2Image,
         OVPipelineForInpainting,
         OVPipelineForText2Image,
+        OVSanaPipeline,
         OVStableDiffusion3Img2ImgPipeline,
         OVStableDiffusion3InpaintPipeline,
         OVStableDiffusion3Pipeline,

optimum/intel/openvino/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -91,6 +91,7 @@
     OVPipelineForImage2Image,
     OVPipelineForInpainting,
     OVPipelineForText2Image,
+    OVSanaPipeline,
     OVStableDiffusion3Img2ImgPipeline,
     OVStableDiffusion3InpaintPipeline,
     OVStableDiffusion3Pipeline,
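
To close the loop, a minimal usage sketch of the newly exported `OVSanaPipeline`; the model id is a placeholder and the generation arguments are ordinary diffusers-style parameters rather than anything specific to this commit.

```python
from optimum.intel import OVSanaPipeline

# export on the fly from a diffusers checkpoint (placeholder model id)
pipe = OVSanaPipeline.from_pretrained("<sana-model-id>", export=True)
image = pipe("a watercolor painting of a lighthouse", num_inference_steps=20).images[0]
image.save("lighthouse.png")
```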
