diff --git a/examples/diffusers/quantization/onnx_utils/export.py b/examples/diffusers/quantization/onnx_utils/export.py index f7d325b4f..a0abad37a 100644 --- a/examples/diffusers/quantization/onnx_utils/export.py +++ b/examples/diffusers/quantization/onnx_utils/export.py @@ -73,6 +73,13 @@ "pooled_projections": {0: "batch_size"}, "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, }, + "sd3.5-medium": { + "hidden_states": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, + "timestep": {0: "steps"}, + "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, + "pooled_projections": {0: "batch_size"}, + "out_hidden_states": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, + }, "flux-dev": { "hidden_states": {0: "batch_size", 1: "latent_dim"}, "encoder_hidden_states": {0: "batch_size"}, @@ -290,6 +297,8 @@ def update_dynamic_axes(model_id, dynamic_axes): dynamic_axes["out.0"] = dynamic_axes.pop("latent") elif model_id == "sd3-medium": dynamic_axes["out.0"] = dynamic_axes.pop("sample") + elif model_id == "sd3.5-medium": + dynamic_axes["out.0"] = dynamic_axes.pop("out_hidden_states") def _create_dynamic_shapes(dynamic_shapes): @@ -313,7 +322,7 @@ def generate_dummy_inputs_and_dynamic_axes_and_shapes(model_id, backbone): dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sdxl( backbone, min_bs=2, opt_bs=16 ) - elif model_id == "sd3-medium": + elif model_id in ["sd3-medium", "sd3.5-medium"]: dummy_input, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_sd3( backbone, min_bs=2, opt_bs=16 ) @@ -343,6 +352,8 @@ def get_io_shapes(model_id, onnx_load_path, dynamic_shapes): output_name = "latent" elif model_id in ["sd3-medium"]: output_name = "sample" + elif model_id in ["sd3.5-medium"]: + output_name = "out_hidden_states" elif model_id in ["flux-dev", "flux-schnell"]: output_name = "output" else: @@ -350,7 +361,7 @@ def get_io_shapes(model_id, onnx_load_path, dynamic_shapes): if model_id in ["sdxl-1.0", "sdxl-turbo"]: io_shapes = {output_name: dynamic_shapes["dynamic_shapes"]["minShapes"]["sample"]} - elif model_id in ["sd3-medium"]: + elif model_id in ["sd3-medium", "sd3.5-medium"]: io_shapes = {output_name: dynamic_shapes["dynamic_shapes"]["minShapes"]["hidden_states"]} elif model_id in ["flux-dev", "flux-schnell"]: io_shapes = {} @@ -406,6 +417,9 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision): elif model_name == "sd3-medium": input_names = ["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"] output_names = ["sample"] + elif model_name == "sd3.5-medium": + input_names = ["hidden_states", "encoder_hidden_states", "pooled_projections", "timestep"] + output_names = ["out_hidden_states"] elif model_name in ["flux-dev", "flux-schnell"]: input_names = [ "hidden_states", diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index f94a4a1ad..fb1fd13a1 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -16,6 +16,7 @@ import argparse import logging import sys +import time as time from collections.abc import Callable from dataclasses import dataclass from enum import Enum @@ -59,6 +60,7 @@ class ModelType(str, Enum): SDXL_BASE = "sdxl-1.0" SDXL_TURBO = "sdxl-turbo" SD3_MEDIUM = "sd3-medium" + SD35_MEDIUM = "sd3.5-medium" FLUX_DEV = "flux-dev" FLUX_SCHNELL = "flux-schnell" LTX_VIDEO_DEV = "ltx-video-dev" @@ -114,6 +116,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.SDXL_BASE: filter_func_default, ModelType.SDXL_TURBO: filter_func_default, ModelType.SD3_MEDIUM: filter_func_default, + ModelType.SD35_MEDIUM: filter_func_default, ModelType.LTX_VIDEO_DEV: filter_func_ltx_video, } @@ -125,6 +128,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.SDXL_BASE: "stabilityai/stable-diffusion-xl-base-1.0", ModelType.SDXL_TURBO: "stabilityai/sdxl-turbo", ModelType.SD3_MEDIUM: "stabilityai/stable-diffusion-3-medium-diffusers", + ModelType.SD35_MEDIUM: "stabilityai/stable-diffusion-3.5-medium", ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev", ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell", ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev", @@ -230,6 +234,7 @@ def uses_transformer(self) -> bool: """Check if model uses transformer backbone (vs UNet).""" return self.model_type in [ ModelType.SD3_MEDIUM, + ModelType.SD35_MEDIUM, ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL, ModelType.LTX_VIDEO_DEV, @@ -326,7 +331,7 @@ def create_pipeline_from( model_id = ( MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path ) - if model_type == ModelType.SD3_MEDIUM: + if model_type in [ModelType.SD3_MEDIUM, ModelType.SD35_MEDIUM]: pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch_dtype) elif model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]: pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch_dtype) @@ -357,7 +362,7 @@ def create_pipeline(self) -> DiffusionPipeline: self.logger.info(f"Data type: {self.config.model_dtype.value}") try: - if self.config.model_type == ModelType.SD3_MEDIUM: + if self.config.model_type in [ModelType.SD3_MEDIUM, ModelType.SD35_MEDIUM]: self.pipe = StableDiffusion3Pipeline.from_pretrained( self.config.model_path, torch_dtype=self.config.torch_dtype ) @@ -864,6 +869,8 @@ def main() -> None: parser = create_argument_parser() args = parser.parse_args() + s = time.time() + logger = setup_logging(args.verbose) logger.info("Starting Enhanced Diffusion Model Quantization") @@ -939,9 +946,11 @@ def forward_loop(mod): backbone, model_config.model_type, quant_config.format, - quantize_mha=QuantizationConfig.quantize_mha, + quantize_mha=quant_config.quantize_mha, + ) + logger.info( + f"Quantization process completed successfully! Time taken = {time.time() - s} seconds" ) - logger.info("Quantization process completed successfully!") except Exception as e: logger.error(f"Quantization failed: {e}", exc_info=True)