@@ -757,15 +757,9 @@ def check_inputs(
757757 for images_ in image :
758758 for image_ in images_ :
759759 self .check_image (image_ , prompt , prompt_embeds )
760- else :
761- assert False
762760
763761 # Check `controlnet_conditioning_scale`
764- # TODO Update for https://github.com/huggingface/diffusers/pull/10723
765- if isinstance (controlnet , ControlNetUnionModel ):
766- if not isinstance (controlnet_conditioning_scale , float ):
767- raise TypeError ("For single controlnet: `controlnet_conditioning_scale` must be type `float`." )
768- elif isinstance (controlnet , MultiControlNetUnionModel ):
762+ if isinstance (controlnet , MultiControlNetUnionModel ):
769763 if isinstance (controlnet_conditioning_scale , list ):
770764 if any (isinstance (i , list ) for i in controlnet_conditioning_scale ):
771765 raise ValueError ("A single batch of multiple conditionings is not supported at the moment." )
@@ -776,8 +770,6 @@ def check_inputs(
776770 "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
777771 " the same length as the number of controlnets"
778772 )
779- else :
780- assert False
781773
782774 if len (control_guidance_start ) != len (control_guidance_end ):
783775 raise ValueError (
@@ -808,8 +800,6 @@ def check_inputs(
808800 for _control_mode , _controlnet in zip (control_mode , self .controlnet .nets ):
809801 if max (_control_mode ) >= _controlnet .config .num_control_type :
810802 raise ValueError (f"control_mode: must be lower than { _controlnet .config .num_control_type } ." )
811- else :
812- assert False
813803
814804 # Equal number of `image` and `control_mode` elements
815805 if isinstance (controlnet , ControlNetUnionModel ):
@@ -823,8 +813,6 @@ def check_inputs(
823813
824814 elif sum (len (x ) for x in image ) != sum (len (x ) for x in control_mode ):
825815 raise ValueError ("Expected len(control_image) == len(control_mode)" )
826- else :
827- assert False
828816
829817 if ip_adapter_image is not None and ip_adapter_image_embeds is not None :
830818 raise ValueError (
@@ -1201,28 +1189,33 @@ def __call__(
12011189
12021190 controlnet = self .controlnet ._orig_mod if is_compiled_module (self .controlnet ) else self .controlnet
12031191
1192+ if not isinstance (control_image , list ):
1193+ control_image = [control_image ]
1194+ else :
1195+ control_image = control_image .copy ()
1196+
1197+ if not isinstance (control_mode , list ):
1198+ control_mode = [control_mode ]
1199+
1200+ if isinstance (controlnet , MultiControlNetUnionModel ):
1201+ control_image = [[item ] for item in control_image ]
1202+ control_mode = [[item ] for item in control_mode ]
1203+
12041204 # align format for control guidance
12051205 if not isinstance (control_guidance_start , list ) and isinstance (control_guidance_end , list ):
12061206 control_guidance_start = len (control_guidance_end ) * [control_guidance_start ]
12071207 elif not isinstance (control_guidance_end , list ) and isinstance (control_guidance_start , list ):
12081208 control_guidance_end = len (control_guidance_start ) * [control_guidance_end ]
12091209 elif not isinstance (control_guidance_start , list ) and not isinstance (control_guidance_end , list ):
1210- mult = len (controlnet .nets ) if isinstance (controlnet , MultiControlNetUnionModel ) else 1
1210+ mult = len (controlnet .nets ) if isinstance (controlnet , MultiControlNetUnionModel ) else len ( control_mode )
12111211 control_guidance_start , control_guidance_end = (
12121212 mult * [control_guidance_start ],
12131213 mult * [control_guidance_end ],
12141214 )
12151215
1216- if not isinstance (control_image , list ):
1217- control_image = [control_image ]
1218- else :
1219- control_image = control_image .copy ()
1220-
1221- if not isinstance (control_mode , list ):
1222- control_mode = [control_mode ]
1223-
1224- if isinstance (controlnet , MultiControlNetUnionModel ) and isinstance (controlnet_conditioning_scale , float ):
1225- controlnet_conditioning_scale = [controlnet_conditioning_scale ] * len (controlnet .nets )
1216+ if isinstance (controlnet_conditioning_scale , float ):
1217+ mult = len (controlnet .nets ) if isinstance (controlnet , MultiControlNetUnionModel ) else len (control_mode )
1218+ controlnet_conditioning_scale = [controlnet_conditioning_scale ] * mult
12261219
12271220 # 1. Check inputs
12281221 self .check_inputs (
@@ -1357,9 +1350,6 @@ def __call__(
13571350 control_image = control_images
13581351 height , width = control_image [0 ][0 ].shape [- 2 :]
13591352
1360- else :
1361- assert False
1362-
13631353 # 5. Prepare timesteps
13641354 timesteps , num_inference_steps = retrieve_timesteps (
13651355 self .scheduler , num_inference_steps , device , timesteps , sigmas
@@ -1397,7 +1387,7 @@ def __call__(
13971387 1.0 - float (i / len (timesteps ) < s or (i + 1 ) / len (timesteps ) > e )
13981388 for s , e in zip (control_guidance_start , control_guidance_end )
13991389 ]
1400- controlnet_keep .append (keeps [ 0 ] if isinstance ( controlnet , ControlNetUnionModel ) else keeps )
1390+ controlnet_keep .append (keeps )
14011391
14021392 # 7.2 Prepare added time ids & embeddings
14031393 original_size = original_size or (height , width )
0 commit comments