@@ -757,15 +757,9 @@ def check_inputs(
757757            for  images_  in  image :
758758                for  image_  in  images_ :
759759                    self .check_image (image_ , prompt , prompt_embeds )
760-         else :
761-             assert  False 
762760
763761        # Check `controlnet_conditioning_scale` 
764-         # TODO Update for https://github.com/huggingface/diffusers/pull/10723 
765-         if  isinstance (controlnet , ControlNetUnionModel ):
766-             if  not  isinstance (controlnet_conditioning_scale , float ):
767-                 raise  TypeError ("For single controlnet: `controlnet_conditioning_scale` must be type `float`." )
768-         elif  isinstance (controlnet , MultiControlNetUnionModel ):
762+         if  isinstance (controlnet , MultiControlNetUnionModel ):
769763            if  isinstance (controlnet_conditioning_scale , list ):
770764                if  any (isinstance (i , list ) for  i  in  controlnet_conditioning_scale ):
771765                    raise  ValueError ("A single batch of multiple conditionings is not supported at the moment." )
@@ -776,8 +770,6 @@ def check_inputs(
776770                    "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" 
777771                    " the same length as the number of controlnets" 
778772                )
779-         else :
780-             assert  False 
781773
782774        if  len (control_guidance_start ) !=  len (control_guidance_end ):
783775            raise  ValueError (
@@ -808,8 +800,6 @@ def check_inputs(
808800            for  _control_mode , _controlnet  in  zip (control_mode , self .controlnet .nets ):
809801                if  max (_control_mode ) >=  _controlnet .config .num_control_type :
810802                    raise  ValueError (f"control_mode: must be lower than { _controlnet .config .num_control_type }  )
811-         else :
812-             assert  False 
813803
814804        # Equal number of `image` and `control_mode` elements 
815805        if  isinstance (controlnet , ControlNetUnionModel ):
@@ -823,8 +813,6 @@ def check_inputs(
823813
824814            elif  sum (len (x ) for  x  in  image ) !=  sum (len (x ) for  x  in  control_mode ):
825815                raise  ValueError ("Expected len(control_image) == len(control_mode)" )
826-         else :
827-             assert  False 
828816
829817        if  ip_adapter_image  is  not None  and  ip_adapter_image_embeds  is  not None :
830818            raise  ValueError (
@@ -1201,28 +1189,33 @@ def __call__(
12011189
12021190        controlnet  =  self .controlnet ._orig_mod  if  is_compiled_module (self .controlnet ) else  self .controlnet 
12031191
1192+         if  not  isinstance (control_image , list ):
1193+             control_image  =  [control_image ]
1194+         else :
1195+             control_image  =  control_image .copy ()
1196+ 
1197+         if  not  isinstance (control_mode , list ):
1198+             control_mode  =  [control_mode ]
1199+ 
1200+         if  isinstance (controlnet , MultiControlNetUnionModel ):
1201+             control_image  =  [[item ] for  item  in  control_image ]
1202+             control_mode  =  [[item ] for  item  in  control_mode ]
1203+ 
12041204        # align format for control guidance 
12051205        if  not  isinstance (control_guidance_start , list ) and  isinstance (control_guidance_end , list ):
12061206            control_guidance_start  =  len (control_guidance_end ) *  [control_guidance_start ]
12071207        elif  not  isinstance (control_guidance_end , list ) and  isinstance (control_guidance_start , list ):
12081208            control_guidance_end  =  len (control_guidance_start ) *  [control_guidance_end ]
12091209        elif  not  isinstance (control_guidance_start , list ) and  not  isinstance (control_guidance_end , list ):
1210-             mult  =  len (controlnet .nets ) if  isinstance (controlnet , MultiControlNetUnionModel ) else  1 
1210+             mult  =  len (controlnet .nets ) if  isinstance (controlnet , MultiControlNetUnionModel ) else  len ( control_mode ) 
12111211            control_guidance_start , control_guidance_end  =  (
12121212                mult  *  [control_guidance_start ],
12131213                mult  *  [control_guidance_end ],
12141214            )
12151215
1216-         if  not  isinstance (control_image , list ):
1217-             control_image  =  [control_image ]
1218-         else :
1219-             control_image  =  control_image .copy ()
1220- 
1221-         if  not  isinstance (control_mode , list ):
1222-             control_mode  =  [control_mode ]
1223- 
1224-         if  isinstance (controlnet , MultiControlNetUnionModel ) and  isinstance (controlnet_conditioning_scale , float ):
1225-             controlnet_conditioning_scale  =  [controlnet_conditioning_scale ] *  len (controlnet .nets )
1216+         if  isinstance (controlnet_conditioning_scale , float ):
1217+             mult  =  len (controlnet .nets ) if  isinstance (controlnet , MultiControlNetUnionModel ) else  len (control_mode )
1218+             controlnet_conditioning_scale  =  [controlnet_conditioning_scale ] *  mult 
12261219
12271220        # 1. Check inputs 
12281221        self .check_inputs (
@@ -1357,9 +1350,6 @@ def __call__(
13571350            control_image  =  control_images 
13581351            height , width  =  control_image [0 ][0 ].shape [- 2 :]
13591352
1360-         else :
1361-             assert  False 
1362- 
13631353        # 5. Prepare timesteps 
13641354        timesteps , num_inference_steps  =  retrieve_timesteps (
13651355            self .scheduler , num_inference_steps , device , timesteps , sigmas 
@@ -1397,7 +1387,7 @@ def __call__(
13971387                1.0  -  float (i  /  len (timesteps ) <  s  or  (i  +  1 ) /  len (timesteps ) >  e )
13981388                for  s , e  in  zip (control_guidance_start , control_guidance_end )
13991389            ]
1400-             controlnet_keep .append (keeps [ 0 ]  if   isinstance ( controlnet ,  ControlNetUnionModel )  else   keeps )
1390+             controlnet_keep .append (keeps )
14011391
14021392        # 7.2 Prepare added time ids & embeddings 
14031393        original_size  =  original_size  or  (height , width )
0 commit comments