4040 AttnProcessor2_0 ,
4141 XFormersAttnProcessor ,
4242)
43- from ...models .controlnets import ControlNetUnionInput , ControlNetUnionInputProMax
4443from ...models .lora import adjust_lora_scale_text_encoder
4544from ...schedulers import KarrasDiffusionSchedulers
4645from ...utils import (
@@ -82,7 +81,6 @@ def retrieve_latents(
8281 Examples:
8382 ```py
8483 from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
85- from diffusers.models.controlnets import ControlNetUnionInputProMax
8684 from diffusers.utils import load_image
8785 import torch
8886 import numpy as np
@@ -114,11 +112,8 @@ def retrieve_latents(
114112 mask_np = np.array(mask)
115113 controlnet_img_np[mask_np > 0] = 0
116114 controlnet_img = Image.fromarray(controlnet_img_np)
117- union_input = ControlNetUnionInputProMax(
118- repaint=controlnet_img,
119- )
120115 # generate image
121- image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input ).images[0]
116+ image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7] ).images[0]
122117 image.save("inpaint.png")
123118 ```
124119"""
@@ -1130,7 +1125,7 @@ def __call__(
11301125 prompt_2 : Optional [Union [str , List [str ]]] = None ,
11311126 image : PipelineImageInput = None ,
11321127 mask_image : PipelineImageInput = None ,
1133- control_image_list : Union [ ControlNetUnionInput , ControlNetUnionInputProMax ] = None ,
1128+ control_image : PipelineImageInput = None ,
11341129 height : Optional [int ] = None ,
11351130 width : Optional [int ] = None ,
11361131 padding_mask_crop : Optional [int ] = None ,
@@ -1158,6 +1153,7 @@ def __call__(
11581153 guess_mode : bool = False ,
11591154 control_guidance_start : Union [float , List [float ]] = 0.0 ,
11601155 control_guidance_end : Union [float , List [float ]] = 1.0 ,
1156+ control_mode : Optional [Union [int , List [int ]]] = None ,
11611157 guidance_rescale : float = 0.0 ,
11621158 original_size : Tuple [int , int ] = None ,
11631159 crops_coords_top_left : Tuple [int , int ] = (0 , 0 ),
@@ -1345,20 +1341,6 @@ def __call__(
13451341
13461342 controlnet = self .controlnet ._orig_mod if is_compiled_module (self .controlnet ) else self .controlnet
13471343
1348- if not isinstance (control_image_list , (ControlNetUnionInput , ControlNetUnionInputProMax )):
1349- raise ValueError (
1350- "Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
1351- )
1352- if len (control_image_list ) != controlnet .config .num_control_type :
1353- if isinstance (control_image_list , ControlNetUnionInput ):
1354- raise ValueError (
1355- f"Expected num_control_type { controlnet .config .num_control_type } , got { len (control_image_list )} . Try `ControlNetUnionInputProMax`."
1356- )
1357- elif isinstance (control_image_list , ControlNetUnionInputProMax ):
1358- raise ValueError (
1359- f"Expected num_control_type { controlnet .config .num_control_type } , got { len (control_image_list )} . Try `ControlNetUnionInput`."
1360- )
1361-
13621344 # align format for control guidance
13631345 if not isinstance (control_guidance_start , list ) and isinstance (control_guidance_end , list ):
13641346 control_guidance_start = len (control_guidance_end ) * [control_guidance_start ]
@@ -1375,36 +1357,44 @@ def __call__(
13751357 elif not isinstance (control_guidance_end , list ) and isinstance (control_guidance_start , list ):
13761358 control_guidance_end = len (control_guidance_start ) * [control_guidance_end ]
13771359
1360+ if not isinstance (control_image , list ):
1361+ control_image = [control_image ]
1362+
1363+ if not isinstance (control_mode , list ):
1364+ control_mode = [control_mode ]
1365+
1366+ if len (control_image ) != len (control_mode ):
1367+ raise ValueError ("Expected len(control_image) == len(control_type)" )
1368+
1369+ num_control_type = controlnet .config .num_control_type
1370+
13781371 # 1. Check inputs
1379- control_type = []
1380- for image_type in control_image_list :
1381- if control_image_list [image_type ]:
1382- self .check_inputs (
1383- prompt ,
1384- prompt_2 ,
1385- control_image_list [image_type ],
1386- mask_image ,
1387- strength ,
1388- num_inference_steps ,
1389- callback_steps ,
1390- output_type ,
1391- negative_prompt ,
1392- negative_prompt_2 ,
1393- prompt_embeds ,
1394- negative_prompt_embeds ,
1395- ip_adapter_image ,
1396- ip_adapter_image_embeds ,
1397- pooled_prompt_embeds ,
1398- negative_pooled_prompt_embeds ,
1399- controlnet_conditioning_scale ,
1400- control_guidance_start ,
1401- control_guidance_end ,
1402- callback_on_step_end_tensor_inputs ,
1403- padding_mask_crop ,
1404- )
1405- control_type .append (1 )
1406- else :
1407- control_type .append (0 )
1372+ control_type = [0 for _ in range (num_control_type )]
1373+ for _image , control_idx in zip (control_image , control_mode ):
1374+ control_type [control_idx ] = 1
1375+ self .check_inputs (
1376+ prompt ,
1377+ prompt_2 ,
1378+ _image ,
1379+ mask_image ,
1380+ strength ,
1381+ num_inference_steps ,
1382+ callback_steps ,
1383+ output_type ,
1384+ negative_prompt ,
1385+ negative_prompt_2 ,
1386+ prompt_embeds ,
1387+ negative_prompt_embeds ,
1388+ ip_adapter_image ,
1389+ ip_adapter_image_embeds ,
1390+ pooled_prompt_embeds ,
1391+ negative_pooled_prompt_embeds ,
1392+ controlnet_conditioning_scale ,
1393+ control_guidance_start ,
1394+ control_guidance_end ,
1395+ callback_on_step_end_tensor_inputs ,
1396+ padding_mask_crop ,
1397+ )
14081398
14091399 control_type = torch .Tensor (control_type )
14101400
@@ -1499,23 +1489,21 @@ def denoising_value_valid(dnv):
14991489 init_image = init_image .to (dtype = torch .float32 )
15001490
15011491 # 5.2 Prepare control images
1502- for image_type in control_image_list :
1503- if control_image_list [image_type ]:
1504- control_image = self .prepare_control_image (
1505- image = control_image_list [image_type ],
1506- width = width ,
1507- height = height ,
1508- batch_size = batch_size * num_images_per_prompt ,
1509- num_images_per_prompt = num_images_per_prompt ,
1510- device = device ,
1511- dtype = controlnet .dtype ,
1512- crops_coords = crops_coords ,
1513- resize_mode = resize_mode ,
1514- do_classifier_free_guidance = self .do_classifier_free_guidance ,
1515- guess_mode = guess_mode ,
1516- )
1517- height , width = control_image .shape [- 2 :]
1518- control_image_list [image_type ] = control_image
1492+ for idx , _ in enumerate (control_image ):
1493+ control_image [idx ] = self .prepare_control_image (
1494+ image = control_image [idx ],
1495+ width = width ,
1496+ height = height ,
1497+ batch_size = batch_size * num_images_per_prompt ,
1498+ num_images_per_prompt = num_images_per_prompt ,
1499+ device = device ,
1500+ dtype = controlnet .dtype ,
1501+ crops_coords = crops_coords ,
1502+ resize_mode = resize_mode ,
1503+ do_classifier_free_guidance = self .do_classifier_free_guidance ,
1504+ guess_mode = guess_mode ,
1505+ )
1506+ height , width = control_image [idx ].shape [- 2 :]
15191507
15201508 # 5.3 Prepare mask
15211509 mask = self .mask_processor .preprocess (
@@ -1589,6 +1577,9 @@ def denoising_value_valid(dnv):
15891577
15901578 original_size = original_size or (height , width )
15911579 target_size = target_size or (height , width )
1580+ for _image in control_image :
1581+ if isinstance (_image , torch .Tensor ):
1582+ original_size = original_size or _image .shape [- 2 :]
15921583
15931584 # 10. Prepare added time ids & embeddings
15941585 add_text_embeds = pooled_prompt_embeds
@@ -1693,8 +1684,9 @@ def denoising_value_valid(dnv):
16931684 control_model_input ,
16941685 t ,
16951686 encoder_hidden_states = controlnet_prompt_embeds ,
1696- controlnet_cond = control_image_list ,
1687+ controlnet_cond = control_image ,
16971688 control_type = control_type ,
1689+ control_type_idx = control_mode ,
16981690 conditioning_scale = cond_scale ,
16991691 guess_mode = guess_mode ,
17001692 added_cond_kwargs = controlnet_added_cond_kwargs ,
0 commit comments