Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions examples/diffusers/quantization/onnx_utils/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@
"pooled_projections": {0: "batch_size"},
"sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
},
"sd3.5-medium": {
"hidden_states": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
"timestep": {0: "steps"},
"encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
"pooled_projections": {0: "batch_size"},
"out_hidden_states": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
},
"flux-dev": {
"hidden_states": {0: "batch_size", 1: "latent_dim"},
"encoder_hidden_states": {0: "batch_size"},
Expand Down Expand Up @@ -290,6 +297,8 @@ def update_dynamic_axes(model_id, dynamic_axes):
dynamic_axes["out.0"] = dynamic_axes.pop("latent")
elif model_id == "sd3-medium":
dynamic_axes["out.0"] = dynamic_axes.pop("sample")
elif model_id == "sd3.5-medium":
dynamic_axes["out.0"] = dynamic_axes.pop("out_hidden_states")


def _create_dynamic_shapes(dynamic_shapes):
Expand All @@ -313,7 +322,7 @@ def generate_dummy_inputs_and_dynamic_axes_and_shapes(model_id, backbone):
dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sdxl(
backbone, min_bs=2, opt_bs=16
)
- elif model_id == "sd3-medium":
+ elif model_id in ["sd3-medium", "sd3.5-medium"]:
dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sd3(
backbone, min_bs=2, opt_bs=16
)
Expand Down Expand Up @@ -343,14 +352,16 @@ def get_io_shapes(model_id, onnx_load_path, dynamic_shapes):
output_name = "latent"
elif model_id in ["sd3-medium"]:
output_name = "sample"
elif model_id in ["sd3.5-medium"]:
output_name = "out_hidden_states"
elif model_id in ["flux-dev", "flux-schnell"]:
output_name = "output"
else:
raise NotImplementedError(f"Unsupported model_id: {model_id}")

if model_id in ["sdxl-1.0", "sdxl-turbo"]:
io_shapes = {output_name: dynamic_shapes["dynamic_shapes"]["minShapes"]["sample"]}
- elif model_id in ["sd3-medium"]:
+ elif model_id in ["sd3-medium", "sd3.5-medium"]:
io_shapes = {output_name: dynamic_shapes["dynamic_shapes"]["minShapes"]["hidden_states"]}
elif model_id in ["flux-dev", "flux-schnell"]:
io_shapes = {}
Expand Down Expand Up @@ -406,6 +417,9 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision):
elif model_name == "sd3-medium":
input_names = ["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"]
output_names = ["sample"]
elif model_name == "sd3.5-medium":
input_names = ["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"]
output_names = ["out_hidden_states"]
elif model_name in ["flux-dev", "flux-schnell"]:
input_names = [
"hidden_states",
Expand Down
17 changes: 13 additions & 4 deletions examples/diffusers/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import argparse
import logging
import sys
import time as time
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -59,6 +60,7 @@ class ModelType(str, Enum):
SDXL_BASE = "sdxl-1.0"
SDXL_TURBO = "sdxl-turbo"
SD3_MEDIUM = "sd3-medium"
SD35_MEDIUM = "sd3.5-medium"
FLUX_DEV = "flux-dev"
FLUX_SCHNELL = "flux-schnell"
LTX_VIDEO_DEV = "ltx-video-dev"
Expand Down Expand Up @@ -114,6 +116,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.SDXL_BASE: filter_func_default,
ModelType.SDXL_TURBO: filter_func_default,
ModelType.SD3_MEDIUM: filter_func_default,
ModelType.SD35_MEDIUM: filter_func_default,
ModelType.LTX_VIDEO_DEV: filter_func_ltx_video,
}

Expand All @@ -125,6 +128,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.SDXL_BASE: "stabilityai/stable-diffusion-xl-base-1.0",
ModelType.SDXL_TURBO: "stabilityai/sdxl-turbo",
ModelType.SD3_MEDIUM: "stabilityai/stable-diffusion-3-medium-diffusers",
ModelType.SD35_MEDIUM: "stabilityai/stable-diffusion-3.5-medium",
ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev",
ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell",
ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev",
Expand Down Expand Up @@ -230,6 +234,7 @@ def uses_transformer(self) -> bool:
"""Check if model uses transformer backbone (vs UNet)."""
return self.model_type in [
ModelType.SD3_MEDIUM,
ModelType.SD35_MEDIUM,
ModelType.FLUX_DEV,
ModelType.FLUX_SCHNELL,
ModelType.LTX_VIDEO_DEV,
Expand Down Expand Up @@ -326,7 +331,7 @@ def create_pipeline_from(
model_id = (
MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path
)
- if model_type == ModelType.SD3_MEDIUM:
+ if model_type in [ModelType.SD3_MEDIUM, ModelType.SD35_MEDIUM]:
pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
elif model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]:
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
Expand Down Expand Up @@ -357,7 +362,7 @@ def create_pipeline(self) -> DiffusionPipeline:
self.logger.info(f"Data type: {self.config.model_dtype.value}")

try:
- if self.config.model_type == ModelType.SD3_MEDIUM:
+ if self.config.model_type in [ModelType.SD3_MEDIUM, ModelType.SD35_MEDIUM]:
self.pipe = StableDiffusion3Pipeline.from_pretrained(
self.config.model_path, torch_dtype=self.config.torch_dtype
)
Expand Down Expand Up @@ -864,6 +869,8 @@ def main() -> None:
parser = create_argument_parser()
args = parser.parse_args()

s = time.time()

logger = setup_logging(args.verbose)
logger.info("Starting Enhanced Diffusion Model Quantization")

Expand Down Expand Up @@ -939,9 +946,11 @@ def forward_loop(mod):
backbone,
model_config.model_type,
quant_config.format,
- quantize_mha=QuantizationConfig.quantize_mha,
+ quantize_mha=quant_config.quantize_mha,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching the bug here!

)
+ logger.info(
+ f"Quantization process completed successfully! Time taken = {time.time() - s} seconds"
+ )
- logger.info("Quantization process completed successfully!")

except Exception as e:
logger.error(f"Quantization failed: {e}", exc_info=True)
Expand Down