Skip to content

Commit 99c76ff

Browse files
Add SD3.5-medium quantization support in ModelOpt Diffusers example (#444)
Signed-off-by: vipandya <[email protected]>
1 parent 2fd67cc commit 99c76ff

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

examples/diffusers/quantization/onnx_utils/export.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@
7373
"pooled_projections": {0: "batch_size"},
7474
"sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
7575
},
76+
"sd3.5-medium": {
77+
"hidden_states": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
78+
"timestep": {0: "steps"},
79+
"encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
80+
"pooled_projections": {0: "batch_size"},
81+
"out_hidden_states": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
82+
},
7683
"flux-dev": {
7784
"hidden_states": {0: "batch_size", 1: "latent_dim"},
7885
"encoder_hidden_states": {0: "batch_size"},
@@ -290,6 +297,8 @@ def update_dynamic_axes(model_id, dynamic_axes):
290297
dynamic_axes["out.0"] = dynamic_axes.pop("latent")
291298
elif model_id == "sd3-medium":
292299
dynamic_axes["out.0"] = dynamic_axes.pop("sample")
300+
elif model_id == "sd3.5-medium":
301+
dynamic_axes["out.0"] = dynamic_axes.pop("out_hidden_states")
293302

294303

295304
def _create_dynamic_shapes(dynamic_shapes):
@@ -313,7 +322,7 @@ def generate_dummy_inputs_and_dynamic_axes_and_shapes(model_id, backbone):
313322
dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sdxl(
314323
backbone, min_bs=2, opt_bs=16
315324
)
316-
elif model_id == "sd3-medium":
325+
elif model_id in ["sd3-medium", "sd3.5-medium"]:
317326
dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sd3(
318327
backbone, min_bs=2, opt_bs=16
319328
)
@@ -343,14 +352,16 @@ def get_io_shapes(model_id, onnx_load_path, dynamic_shapes):
343352
output_name = "latent"
344353
elif model_id in ["sd3-medium"]:
345354
output_name = "sample"
355+
elif model_id in ["sd3.5-medium"]:
356+
output_name = "out_hidden_states"
346357
elif model_id in ["flux-dev", "flux-schnell"]:
347358
output_name = "output"
348359
else:
349360
raise NotImplementedError(f"Unsupported model_id: {model_id}")
350361

351362
if model_id in ["sdxl-1.0", "sdxl-turbo"]:
352363
io_shapes = {output_name: dynamic_shapes["dynamic_shapes"]["minShapes"]["sample"]}
353-
elif model_id in ["sd3-medium"]:
364+
elif model_id in ["sd3-medium", "sd3.5-medium"]:
354365
io_shapes = {output_name: dynamic_shapes["dynamic_shapes"]["minShapes"]["hidden_states"]}
355366
elif model_id in ["flux-dev", "flux-schnell"]:
356367
io_shapes = {}
@@ -406,6 +417,9 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision):
406417
elif model_name == "sd3-medium":
407418
input_names = ["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"]
408419
output_names = ["sample"]
420+
elif model_name == "sd3.5-medium":
421+
input_names = ["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"]
422+
output_names = ["out_hidden_states"]
409423
elif model_name in ["flux-dev", "flux-schnell"]:
410424
input_names = [
411425
"hidden_states",

examples/diffusers/quantization/quantize.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import argparse
1717
import logging
1818
import sys
19+
import time as time
1920
from collections.abc import Callable
2021
from dataclasses import dataclass
2122
from enum import Enum
@@ -59,6 +60,7 @@ class ModelType(str, Enum):
5960
SDXL_BASE = "sdxl-1.0"
6061
SDXL_TURBO = "sdxl-turbo"
6162
SD3_MEDIUM = "sd3-medium"
63+
SD35_MEDIUM = "sd3.5-medium"
6264
FLUX_DEV = "flux-dev"
6365
FLUX_SCHNELL = "flux-schnell"
6466
LTX_VIDEO_DEV = "ltx-video-dev"
@@ -114,6 +116,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
114116
ModelType.SDXL_BASE: filter_func_default,
115117
ModelType.SDXL_TURBO: filter_func_default,
116118
ModelType.SD3_MEDIUM: filter_func_default,
119+
ModelType.SD35_MEDIUM: filter_func_default,
117120
ModelType.LTX_VIDEO_DEV: filter_func_ltx_video,
118121
}
119122

@@ -125,6 +128,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
125128
ModelType.SDXL_BASE: "stabilityai/stable-diffusion-xl-base-1.0",
126129
ModelType.SDXL_TURBO: "stabilityai/sdxl-turbo",
127130
ModelType.SD3_MEDIUM: "stabilityai/stable-diffusion-3-medium-diffusers",
131+
ModelType.SD35_MEDIUM: "stabilityai/stable-diffusion-3.5-medium",
128132
ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev",
129133
ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell",
130134
ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev",
@@ -230,6 +234,7 @@ def uses_transformer(self) -> bool:
230234
"""Check if model uses transformer backbone (vs UNet)."""
231235
return self.model_type in [
232236
ModelType.SD3_MEDIUM,
237+
ModelType.SD35_MEDIUM,
233238
ModelType.FLUX_DEV,
234239
ModelType.FLUX_SCHNELL,
235240
ModelType.LTX_VIDEO_DEV,
@@ -326,7 +331,7 @@ def create_pipeline_from(
326331
model_id = (
327332
MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path
328333
)
329-
if model_type == ModelType.SD3_MEDIUM:
334+
if model_type in [ModelType.SD3_MEDIUM, ModelType.SD35_MEDIUM]:
330335
pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
331336
elif model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]:
332337
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
@@ -357,7 +362,7 @@ def create_pipeline(self) -> DiffusionPipeline:
357362
self.logger.info(f"Data type: {self.config.model_dtype.value}")
358363

359364
try:
360-
if self.config.model_type == ModelType.SD3_MEDIUM:
365+
if self.config.model_type in [ModelType.SD3_MEDIUM, ModelType.SD35_MEDIUM]:
361366
self.pipe = StableDiffusion3Pipeline.from_pretrained(
362367
self.config.model_path, torch_dtype=self.config.torch_dtype
363368
)
@@ -864,6 +869,8 @@ def main() -> None:
864869
parser = create_argument_parser()
865870
args = parser.parse_args()
866871

872+
s = time.time()
873+
867874
logger = setup_logging(args.verbose)
868875
logger.info("Starting Enhanced Diffusion Model Quantization")
869876

@@ -939,9 +946,11 @@ def forward_loop(mod):
939946
backbone,
940947
model_config.model_type,
941948
quant_config.format,
942-
quantize_mha=QuantizationConfig.quantize_mha,
949+
quantize_mha=quant_config.quantize_mha,
950+
)
951+
logger.info(
952+
f"Quantization process completed successfully! Time taken = {time.time() - s} seconds"
943953
)
944-
logger.info("Quantization process completed successfully!")
945954

946955
except Exception as e:
947956
logger.error(f"Quantization failed: {e}", exc_info=True)

0 commit comments

Comments
 (0)