1313# limitations under the License.
1414import inspect
1515from typing import Any , Callable , Dict , List , Optional , Union
16+
1617import PIL .Image
1718import torch
1819from packaging import version
5758 ```
5859"""
5960
61+
6062# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
6163def retrieve_latents (
6264 encoder_output : torch .Tensor , generator : Optional [torch .Generator ] = None , sample_mode : str = "sample"
@@ -70,6 +72,7 @@ def retrieve_latents(
7072 else :
7173 raise AttributeError ("Could not access latents of provided encoder_output" )
7274
75+
7376# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
7477def rescale_noise_cfg (noise_cfg , noise_pred_text , guidance_rescale = 0.0 ):
7578 """
@@ -579,94 +582,95 @@ def prepare_extra_step_kwargs(self, generator, eta):
579582
580583 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
581584 def check_inputs (
582- self ,
583- prompt ,
584- image ,
585- mask_image ,
586- height ,
587- width ,
588- strength ,
589- callback_steps ,
590- output_type ,
591- negative_prompt = None ,
592- prompt_embeds = None ,
593- negative_prompt_embeds = None ,
594- ip_adapter_image = None ,
595- ip_adapter_image_embeds = None ,
596- callback_on_step_end_tensor_inputs = None ,
597- padding_mask_crop = None ,
585+ self ,
586+ prompt ,
587+ image ,
588+ mask_image ,
589+ height ,
590+ width ,
591+ strength ,
592+ callback_steps ,
593+ output_type ,
594+ negative_prompt = None ,
595+ prompt_embeds = None ,
596+ negative_prompt_embeds = None ,
597+ ip_adapter_image = None ,
598+ ip_adapter_image_embeds = None ,
599+ callback_on_step_end_tensor_inputs = None ,
600+ padding_mask_crop = None ,
601+ ):
602+ if strength < 0 or strength > 1 :
603+ raise ValueError (f"The value of strength should in [0.0, 1.0] but is { strength } " )
604+
605+ if height % self .vae_scale_factor != 0 or width % self .vae_scale_factor != 0 :
606+ raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
607+
608+ if callback_steps is not None and (not isinstance (callback_steps , int ) or callback_steps <= 0 ):
609+ raise ValueError (
610+ f"`callback_steps` has to be a positive integer but is { callback_steps } of type"
611+ f" { type (callback_steps )} ."
612+ )
613+
614+ if callback_on_step_end_tensor_inputs is not None and not all (
615+ k in self ._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
598616 ):
599- if strength < 0 or strength > 1 :
600- raise ValueError (f"The value of strength should in [0.0, 1.0] but is { strength } " )
617+ raise ValueError (
618+ f"`callback_on_step_end_tensor_inputs` has to be in { self ._callback_tensor_inputs } , but found { [k for k in callback_on_step_end_tensor_inputs if k not in self ._callback_tensor_inputs ]} "
619+ )
601620
602- if height % self .vae_scale_factor != 0 or width % self .vae_scale_factor != 0 :
603- raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
621+ if prompt is not None and prompt_embeds is not None :
622+ raise ValueError (
623+ f"Cannot forward both `prompt`: { prompt } and `prompt_embeds`: { prompt_embeds } . Please make sure to"
624+ " only forward one of the two."
625+ )
626+ elif prompt is None and prompt_embeds is None :
627+ raise ValueError (
628+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
629+ )
630+ elif prompt is not None and (not isinstance (prompt , str ) and not isinstance (prompt , list )):
631+ raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type (prompt )} " )
604632
605- if callback_steps is not None and ( not isinstance ( callback_steps , int ) or callback_steps <= 0 ) :
606- raise ValueError (
607- f"`callback_steps` has to be a positive integer but is { callback_steps } of type "
608- f" { type ( callback_steps ) } ."
609- )
633+ if negative_prompt is not None and negative_prompt_embeds is not None :
634+ raise ValueError (
635+ f"Cannot forward both `negative_prompt`: { negative_prompt } and `negative_prompt_embeds`: "
636+ f" { negative_prompt_embeds } . Please make sure to only forward one of the two ."
637+ )
610638
611- if callback_on_step_end_tensor_inputs is not None and not all (
612- k in self ._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
613- ):
639+ if prompt_embeds is not None and negative_prompt_embeds is not None :
640+ if prompt_embeds .shape != negative_prompt_embeds .shape :
614641 raise ValueError (
615- f"`callback_on_step_end_tensor_inputs` has to be in { self ._callback_tensor_inputs } , but found { [k for k in callback_on_step_end_tensor_inputs if k not in self ._callback_tensor_inputs ]} "
642+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
643+ f" got: `prompt_embeds` { prompt_embeds .shape } != `negative_prompt_embeds`"
644+ f" { negative_prompt_embeds .shape } ."
616645 )
617-
618- if prompt is not None and prompt_embeds is not None :
646+ if padding_mask_crop is not None :
647+ if not isinstance ( image , PIL . Image . Image ) :
619648 raise ValueError (
620- f"Cannot forward both `prompt`: { prompt } and `prompt_embeds`: { prompt_embeds } . Please make sure to"
621- " only forward one of the two."
649+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" { type (image )} ."
622650 )
623- elif prompt is None and prompt_embeds is None :
651+ if not isinstance ( mask_image , PIL . Image . Image ) :
624652 raise ValueError (
625- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
653+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
654+ f" { type (mask_image )} ."
626655 )
627- elif prompt is not None and ( not isinstance ( prompt , str ) and not isinstance ( prompt , list )) :
628- raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type ( prompt ) } " )
656+ if output_type != "pil" :
657+ raise ValueError (f"The output type should be PIL when inpainting mask crop, but is" f" { output_type } . " )
629658
630- if negative_prompt is not None and negative_prompt_embeds is not None :
659+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None :
660+ raise ValueError (
661+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
662+ )
663+
664+ if ip_adapter_image_embeds is not None :
665+ if not isinstance (ip_adapter_image_embeds , list ):
631666 raise ValueError (
632- f"Cannot forward both `negative_prompt`: { negative_prompt } and `negative_prompt_embeds`:"
633- f" { negative_prompt_embeds } . Please make sure to only forward one of the two."
667+ f"`ip_adapter_image_embeds` has to be of type `list` but is { type (ip_adapter_image_embeds )} "
634668 )
635-
636- if prompt_embeds is not None and negative_prompt_embeds is not None :
637- if prompt_embeds .shape != negative_prompt_embeds .shape :
638- raise ValueError (
639- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
640- f" got: `prompt_embeds` { prompt_embeds .shape } != `negative_prompt_embeds`"
641- f" { negative_prompt_embeds .shape } ."
642- )
643- if padding_mask_crop is not None :
644- if not isinstance (image , PIL .Image .Image ):
645- raise ValueError (
646- f"The image should be a PIL image when inpainting mask crop, but is of type" f" { type (image )} ."
647- )
648- if not isinstance (mask_image , PIL .Image .Image ):
649- raise ValueError (
650- f"The mask image should be a PIL image when inpainting mask crop, but is of type"
651- f" { type (mask_image )} ."
652- )
653- if output_type != "pil" :
654- raise ValueError (f"The output type should be PIL when inpainting mask crop, but is" f" { output_type } ." )
655-
656- if ip_adapter_image is not None and ip_adapter_image_embeds is not None :
669+ elif ip_adapter_image_embeds [0 ].ndim not in [3 , 4 ]:
657670 raise ValueError (
658- "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined. "
671+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is { ip_adapter_image_embeds [ 0 ]. ndim } D "
659672 )
660673
661- if ip_adapter_image_embeds is not None :
662- if not isinstance (ip_adapter_image_embeds , list ):
663- raise ValueError (
664- f"`ip_adapter_image_embeds` has to be of type `list` but is { type (ip_adapter_image_embeds )} "
665- )
666- elif ip_adapter_image_embeds [0 ].ndim not in [3 , 4 ]:
667- raise ValueError (
668- f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is { ip_adapter_image_embeds [0 ].ndim } D"
669- )
670674 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
671675 def prepare_latents (
672676 self ,
@@ -730,7 +734,7 @@ def prepare_latents(
730734 outputs += (image_latents ,)
731735
732736 return outputs
733-
737+
734738 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
735739 def _encode_vae_image (self , image : torch .Tensor , generator : torch .Generator ):
736740 if isinstance (generator , list ):
@@ -746,7 +750,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
746750
747751 return image_latents
748752
749- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
753+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
750754 def prepare_mask_latents (
751755 self , mask , masked_image , batch_size , height , width , dtype , device , generator , do_classifier_free_guidance
752756 ):
@@ -788,8 +792,7 @@ def prepare_mask_latents(
788792 torch .cat ([masked_image_latents ] * 2 ) if do_classifier_free_guidance else masked_image_latents
789793 )
790794
791- # star
792-
795+ # star
793796
794797 # aligning device to prevent device errors when concating it with the latent model input
795798 masked_image_latents = masked_image_latents .to (device = device , dtype = dtype )
@@ -807,7 +810,6 @@ def get_timesteps(self, num_inference_steps, strength, device):
807810
808811 return timesteps , num_inference_steps - t_start
809812
810-
811813 # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
812814 def get_guidance_scale_embedding (
813815 self , w : torch .Tensor , embedding_dim : int = 512 , dtype : torch .dtype = torch .float32
@@ -1001,7 +1003,7 @@ def __call__(
10011003 height = height or self .unet .config .sample_size * self .vae_scale_factor
10021004 width = width or self .unet .config .sample_size * self .vae_scale_factor
10031005 # to deal with lora scaling and other possible forward hooks
1004-
1006+
10051007 # 1. Check inputs. Raise error if not correct
10061008 self .check_inputs (
10071009 prompt ,
@@ -1069,7 +1071,7 @@ def __call__(
10691071 f"After adjusting the num_inference_steps by strength parameter: { strength } , the number of pipeline"
10701072 f"steps is { num_inference_steps } which is < 1 and not appropriate for this pipeline."
10711073 )
1072-
1074+
10731075 latent_timestep = timesteps [:1 ].repeat (batch_size * num_images_per_prompt )
10741076 # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
10751077 is_strength_max = strength == 1.0
@@ -1198,8 +1200,6 @@ def __call__(
11981200 image_embeds = image_embeds .to (device )
11991201 ip_adapter_image_embeds [i ] = image_embeds
12001202
1201-
1202-
12031203 # 9.1 Add image embeds for IP-Adapter
12041204 added_cond_kwargs = (
12051205 {"image_embeds" : ip_adapter_image_embeds }
@@ -1217,8 +1217,7 @@ def __call__(
12171217
12181218 # 10. Denoising loop
12191219 num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
1220-
1221-
1220+
12221221 if self .do_perturbed_attention_guidance :
12231222 original_attn_proc = self .unet .attn_processors
12241223 self ._set_pag_attn_processor (
@@ -1281,7 +1280,6 @@ def __call__(
12811280
12821281 latents = (1 - init_mask ) * init_latents_proper + init_mask * latents
12831282
1284-
12851283 if callback_on_step_end is not None :
12861284 callback_kwargs = {}
12871285 for k in callback_on_step_end_tensor_inputs :
@@ -1297,7 +1295,7 @@ def __call__(
12971295 # call the callback, if provided
12981296 if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
12991297 progress_bar .update ()
1300-
1298+
13011299 if not output_type == "latent" :
13021300 condition_kwargs = {}
13031301 if isinstance (self .vae , AsymmetricAutoencoderKL ):
@@ -1330,7 +1328,6 @@ def __call__(
13301328 if self .do_perturbed_attention_guidance :
13311329 self .unet .set_attn_processor (original_attn_proc )
13321330
1333-
13341331 if not return_dict :
13351332 return (image , has_nsfw_concept )
13361333
0 commit comments