
Commit 4aadc6a

support sana text2image

1 parent 190ae87 commit 4aadc6a

File tree

2 files changed: +122 -1 lines changed

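With this patch, a Sana checkpoint routes through the OpenVINO exporter like the other supported diffusers pipelines. Below is a minimal sketch of driving the export from Python, assuming the usual main_export entry point; the checkpoint id is illustrative, not taken from this commit:

# Hedged usage sketch: the model id below is illustrative, not part of this commit.
from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="Efficient-Large-Model/Sana_600M_512px_diffusers",  # hypothetical id
    output="sana_ov",
    task="text-to-image",
)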

optimum/exporters/openvino/convert.py

Lines changed: 73 additions & 1 deletion
@@ -1023,17 +1023,89 @@ def get_diffusion_models_for_export_ext(
         is_flux = isinstance(pipeline, tuple(flux_pipes))
     else:
         is_flux = False
+
+    try:
+        from diffusers import SanaPipeline
+        is_sana = isinstance(pipeline, SanaPipeline)
+    except ImportError:
+        is_sana = False
 
-    if not is_sd3 and not is_flux:
+    if not any([is_sana, is_flux, is_sd3]):
         return None, get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
     if is_sd3:
         models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
+    elif is_sana:
+        models_for_export = get_sana_models_for_export(pipeline, exporter, int_dtype, float_dtype)
     else:
         models_for_export = get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype)
 
     return None, models_for_export
 
 
+def get_sana_models_for_export(pipeline, exporter, int_dtype, float_dtype):
+    DEFAULT_DUMMY_SHAPES["height"] = DEFAULT_DUMMY_SHAPES["height"] // 4
+    DEFAULT_DUMMY_SHAPES["width"] = DEFAULT_DUMMY_SHAPES["width"] // 4
+    models_for_export = {}
+    text_encoder = pipeline.text_encoder
+    text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=text_encoder,
+        exporter=exporter,
+        library_name="diffusers",
+        task="feature-extraction",
+        model_type="gemma2-text-encoder",
+    )
+    text_encoder_export_config = text_encoder_config_constructor(
+        pipeline.text_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+    )
+    models_for_export["text_encoder"] = (text_encoder, text_encoder_export_config)
+    transformer = pipeline.transformer
+    transformer.config.text_encoder_projection_dim = transformer.config.caption_channels
+    transformer.config.requires_aesthetics_score = False
+    transformer.config.time_cond_proj_dim = None
+    export_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=transformer,
+        exporter=exporter,
+        library_name="diffusers",
+        task="semantic-segmentation",
+        model_type="sana-transformer",
+    )
+    transformer_export_config = export_config_constructor(
+        pipeline.transformer.config, int_dtype=int_dtype, float_dtype=float_dtype
+    )
+    models_for_export["transformer"] = (transformer, transformer_export_config)
+    # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
+    vae_encoder = copy.deepcopy(pipeline.vae)
+    vae_encoder.forward = lambda sample: {"latent_parameters": vae_encoder.encode(x=sample)["latent_dist"].parameters}
+    vae_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=vae_encoder,
+        exporter=exporter,
+        library_name="diffusers",
+        task="semantic-segmentation",
+        model_type="vae-encoder",
+    )
+    vae_encoder_export_config = vae_config_constructor(
+        vae_encoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+    )
+    models_for_export["vae_encoder"] = (vae_encoder, vae_encoder_export_config)
+
+    # VAE Decoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L600
+    vae_decoder = copy.deepcopy(pipeline.vae)
+    vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
+    vae_config_constructor = TasksManager.get_exporter_config_constructor(
+        model=vae_decoder,
+        exporter=exporter,
+        library_name="diffusers",
+        task="semantic-segmentation",
+        model_type="vae-decoder",
+    )
+    vae_decoder_export_config = vae_config_constructor(
+        vae_decoder.config, int_dtype=int_dtype, float_dtype=float_dtype
+    )
+    models_for_export["vae_decoder"] = (vae_decoder, vae_decoder_export_config)
+
+    return models_for_export
+
+
 def get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype):
     models_for_export = {}
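The new get_sana_models_for_export mirrors the existing SD3 and Flux helpers: it shrinks the default dummy spatial shapes by a factor of 4 to match the latent resolution, then pairs each pipeline submodel (Gemma2 text encoder, Sana transformer, VAE encoder and decoder) with its export config. A hedged sketch of what a caller should get back, assuming a loaded SanaPipeline and an illustrative checkpoint id:

# Hedged sketch, not part of the commit: inspect the mapping the helper returns.
from diffusers import SanaPipeline
from optimum.exporters.openvino.convert import get_sana_models_for_export

pipeline = SanaPipeline.from_pretrained("Efficient-Large-Model/Sana_600M_512px_diffusers")  # illustrative id
models = get_sana_models_for_export(pipeline, exporter="openvino", int_dtype="int64", float_dtype="fp32")

# One (submodel, export_config) pair per exported component:
assert sorted(models) == ["text_encoder", "transformer", "vae_decoder", "vae_encoder"]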

optimum/exporters/openvino/model_configs.py

Lines changed: 49 additions & 0 deletions
@@ -54,6 +54,7 @@
     DummyVisionInputGenerator,
     FalconDummyPastKeyValuesGenerator,
     MistralDummyPastKeyValuesGenerator,
+    DummySeq2SeqDecoderTextInputGenerator,
 )
 from optimum.utils.normalized_config import NormalizedConfig, NormalizedTextConfig, NormalizedVisionConfig
 
@@ -129,6 +130,8 @@ def init_model_configs():
     if is_diffusers_available() and "fill" not in TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS:
         TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS["fill"] = "FluxFillPipeline"
         TasksManager._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS["fill"] = {"flux": "FluxFillPipeline"}
+        TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS["text-to-image"] = ("AutoPipelineForText2Image", "SanaPipeline")
+        TasksManager._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS["text-to-image"]["sana"] = "SanaPipeline"
 
     supported_model_types = [
         "_SUPPORTED_MODEL_TYPE",
@@ -1886,6 +1889,52 @@ def rename_ambiguous_inputs(self, inputs):
 class T5EncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
     pass
 
+@register_in_tasks_manager("gemma2-text-encoder", *["feature-extraction"], library_name="diffusers")
+class Gemma2TextEncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "batch_size", 1: "sequence_length"},
+        }
+
+
+class DummySeq2SeqDecoderTextWithEncMaskInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
+    SUPPORTED_INPUT_NAMES = (
+        "decoder_input_ids",
+        "decoder_attention_mask",
+        "encoder_outputs",
+        "encoder_hidden_states",
+        "encoder_attention_mask",
+    )
+
+
+class DummySanaTransformerVisionInputGenerator(DummyVisionInputGenerator):
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if input_name not in ["sample", "latent_sample"]:
+            return super().generate(input_name, framework, int_dtype, float_dtype)
+        return self.random_float_tensor(
+            shape=[self.batch_size, self.num_channels, self.height, self.width],
+            framework=framework,
+            dtype=float_dtype,
+        )
+
+@register_in_tasks_manager("sana-transformer", *["semantic-segmentation"], library_name="diffusers")
+class SanaTransformerOpenVINOConfig(UNetOpenVINOConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
+        image_size="sample_size",
+        num_channels="in_channels",
+        hidden_size="cross_attention_dim",
+        vocab_size="attention_head_dim",
+        allow_new=True,
+    )
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummySanaTransformerVisionInputGenerator, DummySeq2SeqDecoderTextWithEncMaskInputGenerator) + UNetOpenVINOConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:-1]
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        common_inputs["encoder_attention_mask"] = {0: "batch_size", 1: "sequence_length"}
+        return common_inputs
+
 
 class DummyFluxTransformerInputGenerator(DummyVisionInputGenerator):
     SUPPORTED_INPUT_NAMES = (
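With these registrations in place, the new model types resolve through TasksManager the same way the stock UNet and CLIP configs do, and the Sana transformer config advertises the extra encoder_attention_mask input on top of the inherited UNet inputs. A hedged sketch of that lookup, assuming `transformer` is the transformer module of a loaded SanaPipeline:

# Hedged sketch, not part of the commit: resolve the newly registered config.
from optimum.exporters.tasks import TasksManager

constructor = TasksManager.get_exporter_config_constructor(
    model=transformer,  # assumed: pipeline.transformer from a loaded SanaPipeline
    exporter="openvino",
    library_name="diffusers",
    task="semantic-segmentation",
    model_type="sana-transformer",
)
export_config = constructor(transformer.config, int_dtype="int64", float_dtype="fp32")
print("encoder_attention_mask" in export_config.inputs)  # expected: True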
