Skip to content

Commit c63899e

Browse files
committed
Check num_control_type against ControlNetUnionInput, ControlNetUnionInputProMax
1 parent c788d8e commit c63899e

File tree

4 files changed

+56
-0
lines changed

4 files changed

+56
-0
lines changed

src/diffusers/models/controlnet_union.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,20 @@ def forward(
725725
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
726726
returned where the first element is the sample tensor.
727727
"""
728+
if not isinstance(controlnet_cond, (ControlNetUnionInput, ControlNetUnionInputProMax)):
729+
raise ValueError(
730+
"Expected type of `controlnet_cond` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
731+
)
732+
if len(controlnet_cond) != self.config.num_control_type:
733+
if isinstance(controlnet_cond, ControlNetUnionInput):
734+
raise ValueError(
735+
f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInputProMax`."
736+
)
737+
elif isinstance(controlnet_cond, ControlNetUnionInputProMax):
738+
raise ValueError(
739+
f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInput`."
740+
)
741+
728742
# check channel order
729743
channel_order = self.config.controlnet_conditioning_channel_order
730744

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,6 +1400,20 @@ def __call__(
14001400

14011401
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
14021402

1403+
if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
1404+
raise ValueError(
1405+
"Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
1406+
)
1407+
if len(control_image_list) != controlnet.config.num_control_type:
1408+
if isinstance(control_image_list, ControlNetUnionInput):
1409+
raise ValueError(
1410+
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
1411+
)
1412+
elif isinstance(control_image_list, ControlNetUnionInputProMax):
1413+
raise ValueError(
1414+
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
1415+
)
1416+
14031417
# align format for control guidance
14041418
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
14051419
control_guidance_start = len(control_guidance_end) * [control_guidance_start]

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,20 @@ def __call__(
12321232

12331233
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
12341234

1235+
if not isinstance(image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
1236+
raise ValueError(
1237+
"Expected type of `image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
1238+
)
1239+
if len(image_list) != controlnet.config.num_control_type:
1240+
if isinstance(image_list, ControlNetUnionInput):
1241+
raise ValueError(
1242+
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image_list)}. Try `ControlNetUnionInputProMax`."
1243+
)
1244+
elif isinstance(image_list, ControlNetUnionInputProMax):
1245+
raise ValueError(
1246+
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image_list)}. Try `ControlNetUnionInput`."
1247+
)
1248+
12351249
# align format for control guidance
12361250
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
12371251
control_guidance_start = len(control_guidance_end) * [control_guidance_start]

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,20 @@ def __call__(
13231323

13241324
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
13251325

1326+
if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
1327+
raise ValueError(
1328+
"Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
1329+
)
1330+
if len(control_image_list) != controlnet.config.num_control_type:
1331+
if isinstance(control_image_list, ControlNetUnionInput):
1332+
raise ValueError(
1333+
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
1334+
)
1335+
elif isinstance(control_image_list, ControlNetUnionInputProMax):
1336+
raise ValueError(
1337+
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
1338+
)
1339+
13261340
# align format for control guidance
13271341
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
13281342
control_guidance_start = len(control_guidance_end) * [control_guidance_start]

0 commit comments

Comments
 (0)