diff --git a/examples/diffusers/quantization/diffusion_trt.py b/examples/diffusers/quantization/diffusion_trt.py index 4db12e9c2..1a0ec852d 100644 --- a/examples/diffusers/quantization/diffusion_trt.py +++ b/examples/diffusers/quantization/diffusion_trt.py @@ -105,7 +105,11 @@ def main(): image_name = args.save_image_as if args.save_image_as else f"{args.model}.png" - pipe = PipelineManager.create_pipeline_from(MODEL_ID[args.model], dtype_map[args.model_dtype]) + pipe = PipelineManager.create_pipeline_from( + MODEL_ID[args.model], + dtype_map[args.model_dtype], + override_model_path=args.override_model_path, + ) # Save the backbone of the pipeline and move it to the GPU add_embedding = None diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 81c59392d..f94a4a1ad 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -309,7 +309,9 @@ def __init__(self, config: ModelConfig, logger: logging.Logger): @staticmethod def create_pipeline_from( - model_type: ModelType, torch_dtype: torch.dtype = torch.bfloat16 + model_type: ModelType, + torch_dtype: torch.dtype = torch.bfloat16, + override_model_path: str | None = None, ) -> DiffusionPipeline: """ Create and return an appropriate pipeline based on configuration. @@ -321,7 +323,9 @@ def create_pipeline_from( ValueError: If model type is unsupported """ try: - model_id = MODEL_REGISTRY[model_type] + model_id = ( + MODEL_REGISTRY[model_type] if override_model_path is None else override_model_path + ) if model_type == ModelType.SD3_MEDIUM: pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch_dtype) elif model_type in [ModelType.FLUX_DEV, ModelType.FLUX_SCHNELL]: diff --git a/examples/diffusers/quantization/requirements.txt b/examples/diffusers/quantization/requirements.txt index 9c9a60b86..52921fe74 100644 --- a/examples/diffusers/quantization/requirements.txt +++ b/examples/diffusers/quantization/requirements.txt @@ -1,4 +1,5 @@ cuda-python +diffusers<=0.34.0 nvtx onnx_graphsurgeon opencv-python>=4.8.1.78,<4.12.0.88