4040    AttnProcessor2_0 ,
4141    XFormersAttnProcessor ,
4242)
43- from  ...models .controlnets  import  ControlNetUnionInput , ControlNetUnionInputProMax 
4443from  ...models .lora  import  adjust_lora_scale_text_encoder 
4544from  ...schedulers  import  KarrasDiffusionSchedulers 
4645from  ...utils  import  (
@@ -82,7 +81,6 @@ def retrieve_latents(
8281    Examples: 
8382        ```py 
8483        from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL 
85-         from diffusers.models.controlnets import ControlNetUnionInputProMax 
8684        from diffusers.utils import load_image 
8785        import torch 
8886        import numpy as np 
@@ -114,11 +112,8 @@ def retrieve_latents(
114112        mask_np = np.array(mask) 
115113        controlnet_img_np[mask_np > 0] = 0 
116114        controlnet_img = Image.fromarray(controlnet_img_np) 
117-         union_input = ControlNetUnionInputProMax( 
118-             repaint=controlnet_img, 
119-         ) 
120115        # generate image 
121-         image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input ).images[0] 
116+         image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7] ).images[0] 
122117        image.save("inpaint.png") 
123118        ``` 
124119""" 
@@ -1130,7 +1125,7 @@ def __call__(
11301125        prompt_2 : Optional [Union [str , List [str ]]] =  None ,
11311126        image : PipelineImageInput  =  None ,
11321127        mask_image : PipelineImageInput  =  None ,
1133-         control_image_list :  Union [ ControlNetUnionInput ,  ControlNetUnionInputProMax ]  =  None ,
1128+         control_image :  PipelineImageInput  =  None ,
11341129        height : Optional [int ] =  None ,
11351130        width : Optional [int ] =  None ,
11361131        padding_mask_crop : Optional [int ] =  None ,
@@ -1158,6 +1153,7 @@ def __call__(
11581153        guess_mode : bool  =  False ,
11591154        control_guidance_start : Union [float , List [float ]] =  0.0 ,
11601155        control_guidance_end : Union [float , List [float ]] =  1.0 ,
1156+         control_mode : Optional [Union [int , List [int ]]] =  None ,
11611157        guidance_rescale : float  =  0.0 ,
11621158        original_size : Tuple [int , int ] =  None ,
11631159        crops_coords_top_left : Tuple [int , int ] =  (0 , 0 ),
@@ -1345,20 +1341,6 @@ def __call__(
13451341
13461342        controlnet  =  self .controlnet ._orig_mod  if  is_compiled_module (self .controlnet ) else  self .controlnet 
13471343
1348-         if  not  isinstance (control_image_list , (ControlNetUnionInput , ControlNetUnionInputProMax )):
1349-             raise  ValueError (
1350-                 "Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`" 
1351-             )
1352-         if  len (control_image_list ) !=  controlnet .config .num_control_type :
1353-             if  isinstance (control_image_list , ControlNetUnionInput ):
1354-                 raise  ValueError (
1355-                     f"Expected num_control_type { controlnet .config .num_control_type } { len (control_image_list )}  
1356-                 )
1357-             elif  isinstance (control_image_list , ControlNetUnionInputProMax ):
1358-                 raise  ValueError (
1359-                     f"Expected num_control_type { controlnet .config .num_control_type } { len (control_image_list )}  
1360-                 )
1361- 
13621344        # align format for control guidance 
13631345        if  not  isinstance (control_guidance_start , list ) and  isinstance (control_guidance_end , list ):
13641346            control_guidance_start  =  len (control_guidance_end ) *  [control_guidance_start ]
@@ -1375,36 +1357,44 @@ def __call__(
13751357        elif  not  isinstance (control_guidance_end , list ) and  isinstance (control_guidance_start , list ):
13761358            control_guidance_end  =  len (control_guidance_start ) *  [control_guidance_end ]
13771359
1360+         if  not  isinstance (control_image , list ):
1361+             control_image  =  [control_image ]
1362+ 
1363+         if  not  isinstance (control_mode , list ):
1364+             control_mode  =  [control_mode ]
1365+ 
1366+         if  len (control_image ) !=  len (control_mode ):
1367+             raise  ValueError ("Expected len(control_image) == len(control_type)" )
1368+ 
1369+         num_control_type  =  controlnet .config .num_control_type 
1370+ 
13781371        # 1. Check inputs 
1379-         control_type  =  []
1380-         for  image_type  in  control_image_list :
1381-             if  control_image_list [image_type ]:
1382-                 self .check_inputs (
1383-                     prompt ,
1384-                     prompt_2 ,
1385-                     control_image_list [image_type ],
1386-                     mask_image ,
1387-                     strength ,
1388-                     num_inference_steps ,
1389-                     callback_steps ,
1390-                     output_type ,
1391-                     negative_prompt ,
1392-                     negative_prompt_2 ,
1393-                     prompt_embeds ,
1394-                     negative_prompt_embeds ,
1395-                     ip_adapter_image ,
1396-                     ip_adapter_image_embeds ,
1397-                     pooled_prompt_embeds ,
1398-                     negative_pooled_prompt_embeds ,
1399-                     controlnet_conditioning_scale ,
1400-                     control_guidance_start ,
1401-                     control_guidance_end ,
1402-                     callback_on_step_end_tensor_inputs ,
1403-                     padding_mask_crop ,
1404-                 )
1405-                 control_type .append (1 )
1406-             else :
1407-                 control_type .append (0 )
1372+         control_type  =  [0  for  _  in  range (num_control_type )]
1373+         for  _image , control_idx  in  zip (control_image , control_mode ):
1374+             control_type [control_idx ] =  1 
1375+             self .check_inputs (
1376+                 prompt ,
1377+                 prompt_2 ,
1378+                 _image ,
1379+                 mask_image ,
1380+                 strength ,
1381+                 num_inference_steps ,
1382+                 callback_steps ,
1383+                 output_type ,
1384+                 negative_prompt ,
1385+                 negative_prompt_2 ,
1386+                 prompt_embeds ,
1387+                 negative_prompt_embeds ,
1388+                 ip_adapter_image ,
1389+                 ip_adapter_image_embeds ,
1390+                 pooled_prompt_embeds ,
1391+                 negative_pooled_prompt_embeds ,
1392+                 controlnet_conditioning_scale ,
1393+                 control_guidance_start ,
1394+                 control_guidance_end ,
1395+                 callback_on_step_end_tensor_inputs ,
1396+                 padding_mask_crop ,
1397+             )
14081398
14091399        control_type  =  torch .Tensor (control_type )
14101400
@@ -1499,23 +1489,21 @@ def denoising_value_valid(dnv):
14991489        init_image  =  init_image .to (dtype = torch .float32 )
15001490
15011491        # 5.2 Prepare control images 
1502-         for  image_type  in  control_image_list :
1503-             if  control_image_list [image_type ]:
1504-                 control_image  =  self .prepare_control_image (
1505-                     image = control_image_list [image_type ],
1506-                     width = width ,
1507-                     height = height ,
1508-                     batch_size = batch_size  *  num_images_per_prompt ,
1509-                     num_images_per_prompt = num_images_per_prompt ,
1510-                     device = device ,
1511-                     dtype = controlnet .dtype ,
1512-                     crops_coords = crops_coords ,
1513-                     resize_mode = resize_mode ,
1514-                     do_classifier_free_guidance = self .do_classifier_free_guidance ,
1515-                     guess_mode = guess_mode ,
1516-                 )
1517-                 height , width  =  control_image .shape [- 2 :]
1518-                 control_image_list [image_type ] =  control_image 
1492+         for  idx , _  in  enumerate (control_image ):
1493+             control_image [idx ] =  self .prepare_control_image (
1494+                 image = control_image [idx ],
1495+                 width = width ,
1496+                 height = height ,
1497+                 batch_size = batch_size  *  num_images_per_prompt ,
1498+                 num_images_per_prompt = num_images_per_prompt ,
1499+                 device = device ,
1500+                 dtype = controlnet .dtype ,
1501+                 crops_coords = crops_coords ,
1502+                 resize_mode = resize_mode ,
1503+                 do_classifier_free_guidance = self .do_classifier_free_guidance ,
1504+                 guess_mode = guess_mode ,
1505+             )
1506+             height , width  =  control_image [idx ].shape [- 2 :]
15191507
15201508        # 5.3 Prepare mask 
15211509        mask  =  self .mask_processor .preprocess (
@@ -1589,6 +1577,9 @@ def denoising_value_valid(dnv):
15891577
15901578        original_size  =  original_size  or  (height , width )
15911579        target_size  =  target_size  or  (height , width )
1580+         for  _image  in  control_image :
1581+             if  isinstance (_image , torch .Tensor ):
1582+                 original_size  =  original_size  or  _image .shape [- 2 :]
15921583
15931584        # 10. Prepare added time ids & embeddings 
15941585        add_text_embeds  =  pooled_prompt_embeds 
@@ -1693,8 +1684,9 @@ def denoising_value_valid(dnv):
16931684                    control_model_input ,
16941685                    t ,
16951686                    encoder_hidden_states = controlnet_prompt_embeds ,
1696-                     controlnet_cond = control_image_list ,
1687+                     controlnet_cond = control_image ,
16971688                    control_type = control_type ,
1689+                     control_type_idx = control_mode ,
16981690                    conditioning_scale = cond_scale ,
16991691                    guess_mode = guess_mode ,
17001692                    added_cond_kwargs = controlnet_added_cond_kwargs ,
0 commit comments