@@ -667,7 +667,7 @@ def check_inputs(
                 raise ValueError(
                     f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )
-
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
     def prepare_latents(
         self,
         batch_size,
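
For context: a `# Copied from ...` marker is the diffusers repo's mechanism for keeping duplicated methods in sync. A check script (`utils/check_copies.py`, run via `make fix-copies`) compares the marked method against the referenced source and re-syncs or flags it on drift. A minimal sketch of the idea, assuming an `inspect`-based comparison (the real checker works on raw source text and is more involved; `check_copied_from` and `body` are hypothetical names for illustration):

```python
import inspect


def check_copied_from(copy_fn, source_fn) -> bool:
    """Return True if copy_fn's body matches source_fn's body.

    Hypothetical helper for illustration only; diffusers' real check
    lives in utils/check_copies.py and operates on raw source text.
    """

    def body(fn):
        lines = inspect.getsource(fn).splitlines()
        # Drop the `def` line and normalize indentation so methods copied
        # between classes still compare equal.
        return [line.strip() for line in lines[1:] if line.strip()]

    return body(copy_fn) == body(source_fn)
```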
@@ -731,6 +731,7 @@ def prepare_latents(

         return outputs

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
             image_latents = [
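
`_encode_vae_image` maps the (masked) input image into VAE latent space. When a list of generators is passed, it encodes sample by sample so each batch item draws from its own RNG stream. A sketch of that pattern, assuming a `vae` with diffusers' `AutoencoderKL` interface (`encode(...).latent_dist.sample(...)` and `config.scaling_factor`):

```python
import torch


def encode_vae_image(vae, image: torch.Tensor, generator):
    # Sketch only; mirrors the per-sample encoding pattern shown in the hunk.
    if isinstance(generator, list):
        # One generator per batch item: encode slice by slice, then re-batch.
        latents = [
            vae.encode(image[i : i + 1]).latent_dist.sample(generator[i])
            for i in range(image.shape[0])
        ]
        latents = torch.cat(latents, dim=0)
    else:
        latents = vae.encode(image).latent_dist.sample(generator)
    # Scale into the range the UNet was trained on.
    return vae.config.scaling_factor * latents
```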
@@ -745,6 +746,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):

         return image_latents

+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
     def prepare_mask_latents(
         self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
     ):
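
`prepare_mask_latents` has to bring the pixel-space mask down to the latent grid before it can be concatenated with the latents. A minimal sketch of that resizing step, assuming the usual Stable Diffusion `vae_scale_factor` of 8 and a 4D `(N, 1, H, W)` mask:

```python
import torch
import torch.nn.functional as F


def resize_mask_to_latents(
    mask: torch.Tensor, height: int, width: int, vae_scale_factor: int = 8
) -> torch.Tensor:
    # Nearest-neighbor interpolation (F.interpolate's default) keeps the
    # mask binary while shrinking it to the latent resolution.
    return F.interpolate(
        mask, size=(height // vae_scale_factor, width // vae_scale_factor)
    )


mask = (torch.rand(2, 1, 512, 512) > 0.5).float()
print(resize_mask_to_latents(mask, 512, 512).shape)  # torch.Size([2, 1, 64, 64])
```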
@@ -786,6 +788,9 @@ def prepare_mask_latents(
             torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
         )

+        # star
+
+
         # aligning device to prevent device errors when concatenating it with the latent model input
         masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
         return mask, masked_image_latents
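
The explicit `.to(device=device, dtype=dtype)` matters because these tensors are later concatenated channel-wise with the latent model input, and `torch.cat` requires matching device and dtype. A small self-contained sketch of how the 9-channel inpainting UNet input is assembled (shapes are illustrative: 4 latent channels + 1 mask channel + 4 masked-image channels):

```python
import torch

latents = torch.randn(2, 4, 64, 64)
mask = torch.rand(2, 1, 64, 64)
masked_image_latents = torch.randn(2, 4, 64, 64)

# Mismatched device/dtype here is a common source of runtime errors,
# hence the pipeline's explicit .to(...) before concatenation.
masked_image_latents = masked_image_latents.to(latents.device, latents.dtype)
unet_input = torch.cat([latents, mask, masked_image_latents], dim=1)
print(unet_input.shape)  # torch.Size([2, 9, 64, 64])
```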
@@ -996,23 +1001,8 @@ def __call__(
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
         # to deal with lora scaling and other possible forward hooks
-
+
         # 1. Check inputs. Raise error if not correct
-        # prompt,
-        # image,
-        # mask_image,
-        # height,
-        # width,
-        # strength,
-        # callback_steps,
-        # output_type,
-        # negative_prompt=None,
-        # prompt_embeds=None,
-        # negative_prompt_embeds=None,
-        # ip_adapter_image=None,
-        # ip_adapter_image_embeds=None,
-        # callback_on_step_end_tensor_inputs=None,
-        # padding_mask_crop=None,
         self.check_inputs(
             prompt,
             image,
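
The default-resolution lines at the top of this hunk derive `height` and `width` from the UNet's latent sample size times the VAE downsampling factor. A quick worked example with the usual Stable Diffusion 1.5 values (assumed here, not read from a checkpoint):

```python
# Assumed SD 1.5 values: unet.config.sample_size == 64 and
# vae_scale_factor == 2 ** (len(vae.config.block_out_channels) - 1) == 8.
sample_size = 64
vae_scale_factor = 8

height = None  # caller did not pass a resolution
height = height or sample_size * vae_scale_factor
print(height)  # 512
```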
@@ -1066,7 +1056,7 @@ def __call__(
             clip_skip=self.clip_skip,
         )

-        # 4. set timesteps
+        # 4. set timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler, num_inference_steps, device, timesteps, sigmas
         )
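
`retrieve_timesteps` is a module-level helper in diffusers pipelines that forwards to `scheduler.set_timesteps` and lets callers pass either a step count, explicit `timesteps`, or `sigmas` (the latter two only for schedulers whose `set_timesteps` accepts them). A small usage sketch of the underlying call:

```python
from diffusers import DDIMScheduler

# Drive the scheduler directly with a step count, as retrieve_timesteps
# does when neither custom timesteps nor sigmas are supplied.
scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=30)
print(scheduler.timesteps.shape)  # torch.Size([30])
print(scheduler.timesteps[:3])    # the noisiest timesteps come first
```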
@@ -1098,7 +1088,7 @@ def __call__(
         )
         init_image = init_image.to(dtype=torch.float32)

-        # 6. Prepare latent variables
+        # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         num_channels_unet = self.unet.config.in_channels
         return_image_latents = num_channels_unet == 4
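
The `num_channels_unet == 4` test distinguishes the two UNet variants this pipeline supports: dedicated Stable Diffusion inpainting checkpoints ship a 9-channel UNet (mask conditioning built in), while a plain text-to-image UNet has 4 input channels and the pipeline must blend the known-region latents back in itself, which is why it then needs the image latents returned. A sketch of the channel budget (standard SD inpainting layout assumed):

```python
# Input channels of the dedicated inpainting UNet:
noisy_latents = 4         # regular denoising input
mask = 1                  # inpaint mask at latent resolution
masked_image_latents = 4  # VAE encoding of the masked image
assert noisy_latents + mask + masked_image_latents == 9

# A 4-channel UNet has no mask conditioning, so the pipeline instead
# re-imposes the known region on the latents after each scheduler step.
```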
@@ -1171,7 +1161,7 @@ def __call__(
             raise ValueError(
                 f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
             )
-        # 8.1 Prepare extra step kwargs.
+        # 9. Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # For classifier free guidance, we need to do two forward passes.
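
In practice the two classifier-free-guidance passes are batched into one: the latents are duplicated, the UNet runs once on the doubled batch, and the output is split and recombined with the guidance scale. A minimal sketch of the standard recombination:

```python
import torch

guidance_scale = 7.5
# UNet output on the doubled batch (2 images duplicated to 4 for CFG;
# shapes are illustrative).
noise_pred = torch.randn(4, 4, 64, 64)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
    noise_pred_text - noise_pred_uncond
)
print(noise_pred.shape)  # torch.Size([2, 4, 64, 64])
```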
@@ -1210,22 +1200,22 @@ def __call__(



-        # 6.1 Add image embeds for IP-Adapter
+        # 9.1 Add image embeds for IP-Adapter
         added_cond_kwargs = (
             {"image_embeds": ip_adapter_image_embeds}
             if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
             else None
         )

-        # 6.2 Optionally get Guidance Scale Embedding
+        # 9.2 Optionally get Guidance Scale Embedding
         timestep_cond = None
         if self.unet.config.time_cond_proj_dim is not None:
             guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
             timestep_cond = self.get_guidance_scale_embedding(
                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
             ).to(device=device, dtype=latents.dtype)

-        # 7. Denoising loop
+        # 10. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

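
`time_cond_proj_dim` is set on distilled UNets (e.g. LCM-style checkpoints) that take the guidance scale itself as an extra conditioning signal, so guidance is baked into a single forward pass instead of the doubled CFG batch. A sketch mirroring the sinusoidal-embedding pattern of `get_guidance_scale_embedding` (exact constants and the odd-dimension zero-padding of the real implementation are assumed/omitted):

```python
import torch


def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512):
    # w holds guidance_scale - 1 per sample, as in the hunk above.
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim) * -emb)
    emb = w[:, None] * emb[None, :]
    # Concatenate sin/cos halves into the final embedding.
    return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)


print(guidance_scale_embedding(torch.tensor([6.5]), 256).shape)  # (1, 256)
```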