Skip to content

Commit 7160506

Browse files
committed
fix
1 parent b7581a7 commit 7160506

File tree

1 file changed

+7
-25
lines changed

1 file changed

+7
-25
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -757,15 +757,9 @@ def check_inputs(
757757
for images_ in image:
758758
for image_ in images_:
759759
self.check_image(image_, prompt, prompt_embeds)
760-
else:
761-
assert False
762760

763761
# Check `controlnet_conditioning_scale`
764-
# TODO Update for https://github.com/huggingface/diffusers/pull/10723
765-
if isinstance(controlnet, ControlNetUnionModel):
766-
if not isinstance(controlnet_conditioning_scale, float):
767-
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
768-
elif isinstance(controlnet, MultiControlNetUnionModel):
762+
if isinstance(controlnet, MultiControlNetUnionModel):
769763
if isinstance(controlnet_conditioning_scale, list):
770764
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
771765
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
@@ -776,8 +770,6 @@ def check_inputs(
776770
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
777771
" the same length as the number of controlnets"
778772
)
779-
else:
780-
assert False
781773

782774
if len(control_guidance_start) != len(control_guidance_end):
783775
raise ValueError(
@@ -808,8 +800,6 @@ def check_inputs(
808800
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
809801
if max(_control_mode) >= _controlnet.config.num_control_type:
810802
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
811-
else:
812-
assert False
813803

814804
# Equal number of `image` and `control_mode` elements
815805
if isinstance(controlnet, ControlNetUnionModel):
@@ -823,8 +813,6 @@ def check_inputs(
823813

824814
elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
825815
raise ValueError("Expected len(control_image) == len(control_mode)")
826-
else:
827-
assert False
828816

829817
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
830818
raise ValueError(
@@ -1201,6 +1189,11 @@ def __call__(
12011189

12021190
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
12031191

1192+
if not isinstance(control_image, list):
1193+
control_image = [control_image]
1194+
else:
1195+
control_image = control_image.copy()
1196+
12041197
if not isinstance(control_mode, list):
12051198
control_mode = [control_mode]
12061199

@@ -1216,15 +1209,7 @@ def __call__(
12161209
mult * [control_guidance_end],
12171210
)
12181211

1219-
if not isinstance(control_image, list):
1220-
control_image = [control_image]
1221-
else:
1222-
control_image = control_image.copy()
1223-
1224-
if not isinstance(control_mode, list):
1225-
control_mode = [control_mode]
1226-
1227-
if isinstance(controlnet, MultiControlNetUnionModel) and isinstance(controlnet_conditioning_scale, float):
1212+
if isinstance(controlnet_conditioning_scale, float):
12281213
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
12291214
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
12301215

@@ -1361,9 +1346,6 @@ def __call__(
13611346
control_image = control_images
13621347
height, width = control_image[0][0].shape[-2:]
13631348

1364-
else:
1365-
assert False
1366-
13671349
# 5. Prepare timesteps
13681350
timesteps, num_inference_steps = retrieve_timesteps(
13691351
self.scheduler, num_inference_steps, device, timesteps, sigmas

0 commit comments

Comments
 (0)