1616import inspect
1717from typing import Any , Callable , Dict , List , Optional , Tuple , Union
1818
19- import numpy as np
20- import PIL .Image
2119import torch
2220import torch .nn .functional as F
2321from transformers import (
4846from ...schedulers import KarrasDiffusionSchedulers
4947from ...utils import (
5048 USE_PEFT_BACKEND ,
51- deprecate ,
5249 logging ,
5350 replace_example_docstring ,
5451 scale_lora_layers ,
@@ -615,8 +612,7 @@ def check_inputs(
615612 self ,
616613 prompt ,
617614 prompt_2 ,
618- image ,
619- callback_steps ,
615+ image : PipelineImageInput ,
620616 negative_prompt = None ,
621617 negative_prompt_2 = None ,
622618 prompt_embeds = None ,
@@ -630,12 +626,6 @@ def check_inputs(
630626 control_guidance_end = 1.0 ,
631627 callback_on_step_end_tensor_inputs = None ,
632628 ):
633- if callback_steps is not None and (not isinstance (callback_steps , int ) or callback_steps <= 0 ):
634- raise ValueError (
635- f"`callback_steps` has to be a positive integer but is { callback_steps } of type"
636- f" { type (callback_steps )} ."
637- )
638-
639629 if callback_on_step_end_tensor_inputs is not None and not all (
640630 k in self ._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
641631 ):
@@ -767,43 +757,25 @@ def check_inputs(
767757 f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is { ip_adapter_image_embeds [0 ].ndim } D"
768758 )
769759
770- # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
771- def check_image (self , image , prompt , prompt_embeds ):
772- image_is_pil = isinstance (image , PIL .Image .Image )
773- image_is_tensor = isinstance (image , torch .Tensor )
774- image_is_np = isinstance (image , np .ndarray )
775- image_is_pil_list = isinstance (image , list ) and isinstance (image [0 ], PIL .Image .Image )
776- image_is_tensor_list = isinstance (image , list ) and isinstance (image [0 ], torch .Tensor )
777- image_is_np_list = isinstance (image , list ) and isinstance (image [0 ], np .ndarray )
778-
779- if (
780- not image_is_pil
781- and not image_is_tensor
782- and not image_is_np
783- and not image_is_pil_list
784- and not image_is_tensor_list
785- and not image_is_np_list
786- ):
787- raise TypeError (
788- f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is { type (image )} "
789- )
790-
791- if image_is_pil :
792- image_batch_size = 1
793- else :
794- image_batch_size = len (image )
795-
796- if prompt is not None and isinstance (prompt , str ):
797- prompt_batch_size = 1
798- elif prompt is not None and isinstance (prompt , list ):
799- prompt_batch_size = len (prompt )
800- elif prompt_embeds is not None :
801- prompt_batch_size = prompt_embeds .shape [0 ]
760+ def check_input (
761+ self ,
762+ image : Union [ControlNetUnionInput , ControlNetUnionInputProMax ],
763+ ):
764+ controlnet = self .controlnet ._orig_mod if is_compiled_module (self .controlnet ) else self .controlnet
802765
803- if image_batch_size != 1 and image_batch_size != prompt_batch_size :
766+ if not isinstance ( image , ( ControlNetUnionInput , ControlNetUnionInputProMax )) :
804767 raise ValueError (
805- f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: { image_batch_size } , prompt batch size: { prompt_batch_size } "
768+ "Expected type of ` image` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax` "
806769 )
770+ if len (image ) != controlnet .config .num_control_type :
771+ if isinstance (image , ControlNetUnionInput ):
772+ raise ValueError (
773+ f"Expected num_control_type { controlnet .config .num_control_type } , got { len (image )} . Try `ControlNetUnionInputProMax`."
774+ )
775+ elif isinstance (image , ControlNetUnionInputProMax ):
776+ raise ValueError (
777+ f"Expected num_control_type { controlnet .config .num_control_type } , got { len (image )} . Try `ControlNetUnionInput`."
778+ )
807779
808780 # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
809781 def prepare_image (
@@ -823,9 +795,11 @@ def prepare_image(
823795
824796 if image_batch_size == 1 :
825797 repeat_by = batch_size
826- else :
798+ elif image_batch_size == batch_size :
827799 # image batch size is the same as prompt batch size
828800 repeat_by = num_images_per_prompt
801+ else :
802+ raise ValueError (f"Expected image batch size == 1 or `batch_size`, got { image_batch_size } ." )
829803
830804 image = image .repeat_interleave (repeat_by , dim = 0 )
831805
@@ -964,7 +938,7 @@ def __call__(
964938 self ,
965939 prompt : Union [str , List [str ]] = None ,
966940 prompt_2 : Optional [Union [str , List [str ]]] = None ,
967- image_list : Union [ControlNetUnionInput , ControlNetUnionInputProMax ] = None ,
941+ image : Union [ControlNetUnionInput , ControlNetUnionInputProMax ] = None ,
968942 height : Optional [int ] = None ,
969943 width : Optional [int ] = None ,
970944 num_inference_steps : int = 50 ,
@@ -1002,7 +976,6 @@ def __call__(
1002976 Union [Callable [[int , int , Dict ], None ], PipelineCallback , MultiPipelineCallbacks ]
1003977 ] = None ,
1004978 callback_on_step_end_tensor_inputs : List [str ] = ["latents" ],
1005- ** kwargs ,
1006979 ):
1007980 r"""
1008981 The call function to the pipeline for generation.
@@ -1013,7 +986,7 @@ def __call__(
1013986 prompt_2 (`str` or `List[str]`, *optional*):
1014987 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1015988 used in both text-encoders.
1016- image_list (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
989+ image (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
1017990 In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
1018991 `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]`
1019992 or `List[List[PIL.Image.Image]]`):
@@ -1158,40 +1131,12 @@ def __call__(
11581131 otherwise a `tuple` is returned containing the output images.
11591132 """
11601133
1161- callback = kwargs .pop ("callback" , None )
1162- callback_steps = kwargs .pop ("callback_steps" , None )
1163-
1164- if callback is not None :
1165- deprecate (
1166- "callback" ,
1167- "1.0.0" ,
1168- "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`" ,
1169- )
1170- if callback_steps is not None :
1171- deprecate (
1172- "callback_steps" ,
1173- "1.0.0" ,
1174- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`" ,
1175- )
1176-
11771134 if isinstance (callback_on_step_end , (PipelineCallback , MultiPipelineCallbacks )):
11781135 callback_on_step_end_tensor_inputs = callback_on_step_end .tensor_inputs
11791136
11801137 controlnet = self .controlnet ._orig_mod if is_compiled_module (self .controlnet ) else self .controlnet
11811138
1182- if not isinstance (image_list , (ControlNetUnionInput , ControlNetUnionInputProMax )):
1183- raise ValueError (
1184- "Expected type of `image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
1185- )
1186- if len (image_list ) != controlnet .config .num_control_type :
1187- if isinstance (image_list , ControlNetUnionInput ):
1188- raise ValueError (
1189- f"Expected num_control_type { controlnet .config .num_control_type } , got { len (image_list )} . Try `ControlNetUnionInputProMax`."
1190- )
1191- elif isinstance (image_list , ControlNetUnionInputProMax ):
1192- raise ValueError (
1193- f"Expected num_control_type { controlnet .config .num_control_type } , got { len (image_list )} . Try `ControlNetUnionInput`."
1194- )
1139+ self .check_input (image )
11951140
11961141 # align format for control guidance
11971142 if not isinstance (control_guidance_start , list ) and isinstance (control_guidance_end , list ):
@@ -1201,13 +1146,12 @@ def __call__(
12011146
12021147 # 1. Check inputs. Raise error if not correct
12031148 control_type = []
1204- for image_type in image_list :
1205- if image_list [image_type ]:
1149+ for image_type in image :
1150+ if image [image_type ]:
12061151 self .check_inputs (
12071152 prompt ,
12081153 prompt_2 ,
1209- image_list [image_type ],
1210- callback_steps ,
1154+ image [image_type ],
12111155 negative_prompt ,
12121156 negative_prompt_2 ,
12131157 prompt_embeds ,
@@ -1282,10 +1226,10 @@ def __call__(
12821226 )
12831227
12841228 # 4. Prepare image
1285- for image_type in image_list :
1286- if image_list [image_type ]:
1229+ for image_type in image :
1230+ if image [image_type ]:
12871231 image = self .prepare_image (
1288- image = image_list [image_type ],
1232+ image = image [image_type ],
12891233 width = width ,
12901234 height = height ,
12911235 batch_size = batch_size * num_images_per_prompt ,
@@ -1296,7 +1240,7 @@ def __call__(
12961240 guess_mode = guess_mode ,
12971241 )
12981242 height , width = image .shape [- 2 :]
1299- image_list [image_type ] = image
1243+ image [image_type ] = image
13001244
13011245 # 5. Prepare timesteps
13021246 timesteps , num_inference_steps = retrieve_timesteps (
@@ -1337,9 +1281,9 @@ def __call__(
13371281 )
13381282
13391283 # 7.2 Prepare added time ids & embeddings
1340- for image_type in image_list :
1341- if isinstance (image_list [image_type ], torch .Tensor ):
1342- original_size = original_size or image_list [image_type ].shape [- 2 :]
1284+ for image_type in image :
1285+ if isinstance (image [image_type ], torch .Tensor ):
1286+ original_size = original_size or image [image_type ].shape [- 2 :]
13431287
13441288 target_size = target_size or (height , width )
13451289 add_text_embeds = pooled_prompt_embeds
@@ -1449,7 +1393,7 @@ def __call__(
14491393 control_model_input ,
14501394 t ,
14511395 encoder_hidden_states = controlnet_prompt_embeds ,
1452- controlnet_cond = image_list ,
1396+ controlnet_cond = image ,
14531397 control_type = control_type ,
14541398 conditioning_scale = cond_scale ,
14551399 guess_mode = guess_mode ,
@@ -1508,9 +1452,6 @@ def __call__(
15081452 # call the callback, if provided
15091453 if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
15101454 progress_bar .update ()
1511- if callback is not None and i % callback_steps == 0 :
1512- step_idx = i // getattr (self .scheduler , "order" , 1 )
1513- callback (step_idx , t , latents )
15141455
15151456 if not output_type == "latent" :
15161457 # make sure the VAE is in float32 mode, as it overflows in float16
0 commit comments