From 1fa92ca26f453e0047233786ae2a066b8b5d3b92 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 14 Dec 2024 00:30:24 +0000 Subject: [PATCH] Add ControlNetUnion to AutoPipeline from_pretrained --- src/diffusers/pipelines/auto_pipeline.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 1d6686e64271..a0f95fe6cdc1 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -18,6 +18,7 @@ from huggingface_hub.utils import validate_hf_hub_args from ..configuration_utils import ConfigMixin +from ..models.controlnets import ControlNetUnionModel from ..utils import is_sentencepiece_available from .aura_flow import AuraFlowPipeline from .cogview3 import CogView3PlusPipeline @@ -28,6 +29,9 @@ StableDiffusionXLControlNetImg2ImgPipeline, StableDiffusionXLControlNetInpaintPipeline, StableDiffusionXLControlNetPipeline, + StableDiffusionXLControlNetUnionImg2ImgPipeline, + StableDiffusionXLControlNetUnionInpaintPipeline, + StableDiffusionXLControlNetUnionPipeline, ) from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .flux import ( @@ -108,6 +112,7 @@ ("kandinsky3", Kandinsky3Pipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline), + ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline), ("wuerstchen", WuerstchenCombinedPipeline), ("cascade", StableCascadeCombinedPipeline), ("lcm", LatentConsistencyModelPipeline), @@ -139,6 +144,7 @@ ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline), ("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline), + ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline), ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), @@ -158,6 +164,7 @@ ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline), ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline), + ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline), ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline), ("flux", FluxInpaintPipeline), ("flux-controlnet", FluxControlNetInpaintPipeline), @@ -396,7 +403,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): orig_class_name = config["_class_name"] if "controlnet" in kwargs: - orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") + if isinstance(kwargs["controlnet"], ControlNetUnionModel): + orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetUnionPipeline") + else: + orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: @@ -688,7 +698,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline" if "controlnet" in kwargs: - orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) + if isinstance(kwargs["controlnet"], ControlNetUnionModel): + orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace) + else: + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: @@ -985,7 +998,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline" if "controlnet" in kwargs: - orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) + if isinstance(kwargs["controlnet"], ControlNetUnionModel): + orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace) + else: + orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace) if "enable_pag" in kwargs: enable_pag = kwargs.pop("enable_pag") if enable_pag: