Skip to content

Commit 713bf6b

Browse files
committed
Check controlnet is a ControlNetUnionModel
1 parent c63899e commit 713bf6b

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def __init__(
239239
tokenizer: CLIPTokenizer,
240240
tokenizer_2: CLIPTokenizer,
241241
unet: UNet2DConditionModel,
242-
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
242+
controlnet: ControlNetUnionModel,
243243
scheduler: KarrasDiffusionSchedulers,
244244
requires_aesthetics_score: bool = False,
245245
force_zeros_for_empty_prompt: bool = True,
@@ -249,6 +249,9 @@ def __init__(
249249
):
250250
super().__init__()
251251

252+
if not isinstance(controlnet, ControlNetUnionModel):
253+
raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
254+
252255
if isinstance(controlnet, (list, tuple)):
253256
controlnet = MultiControlNetModel(controlnet)
254257

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,10 +209,8 @@ class StableDiffusionXLControlNetUnionPipeline(
209209
A `CLIPTokenizer` to tokenize text.
210210
unet ([`UNet2DConditionModel`]):
211211
A `UNet2DConditionModel` to denoise the encoded image latents.
212-
controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
213-
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
214-
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
215-
additional conditioning.
212+
controlnet ([`ControlNetUnionModel`]`):
213+
Provides additional conditioning to the `unet` during the denoising process.
216214
scheduler ([`SchedulerMixin`]):
217215
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
218216
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -254,7 +252,7 @@ def __init__(
254252
tokenizer: CLIPTokenizer,
255253
tokenizer_2: CLIPTokenizer,
256254
unet: UNet2DConditionModel,
257-
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
255+
controlnet: ControlNetUnionModel,
258256
scheduler: KarrasDiffusionSchedulers,
259257
force_zeros_for_empty_prompt: bool = True,
260258
add_watermarker: Optional[bool] = None,
@@ -263,6 +261,9 @@ def __init__(
263261
):
264262
super().__init__()
265263

264+
if not isinstance(controlnet, ControlNetUnionModel):
265+
raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
266+
266267
if isinstance(controlnet, (list, tuple)):
267268
controlnet = MultiControlNetModel(controlnet)
268269

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
197197
Second Tokenizer of class
198198
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
199199
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
200-
controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
201-
Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
202-
as a list, the outputs from each ControlNet are added together to create one combined additional
203-
conditioning.
200+
controlnet ([`ControlNetUnionModel`]):
201+
Provides additional conditioning to the unet during the denoising process.
204202
scheduler ([`SchedulerMixin`]):
205203
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
206204
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -245,7 +243,7 @@ def __init__(
245243
tokenizer: CLIPTokenizer,
246244
tokenizer_2: CLIPTokenizer,
247245
unet: UNet2DConditionModel,
248-
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
246+
controlnet: ControlNetUnionModel,
249247
scheduler: KarrasDiffusionSchedulers,
250248
requires_aesthetics_score: bool = False,
251249
force_zeros_for_empty_prompt: bool = True,
@@ -255,6 +253,9 @@ def __init__(
255253
):
256254
super().__init__()
257255

256+
if not isinstance(controlnet, ControlNetUnionModel):
257+
raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
258+
258259
if isinstance(controlnet, (list, tuple)):
259260
controlnet = MultiControlNetModel(controlnet)
260261

0 commit comments

Comments
 (0)