@@ -544,6 +544,82 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
         return pipeline, state


+class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
+    expected_components = ["vae", "scheduler"]
+    model_name = "stable-diffusion-xl"
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            ("height", None),
+            ("width", None),
+            ("generator", None),
+            ("latents", None),
+            ("num_images_per_prompt", 1),
+            ("device", None),
+            ("dtype", None),
+            ("image", None),
+            ("denoising_start", None),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return ["batch_size", "latent_timestep", "prompt_embeds"]
+
+    @property
+    def intermediates_outputs(self) -> List[str]:
+        return ["latents"]
+
+    def __init__(self):
+        super().__init__()
+        self.auxiliaries["image_processor"] = VaeImageProcessor()
+        self.components["vae"] = None
+        self.components["scheduler"] = None
+
+    @torch.no_grad()
+    def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState:
+        latents = state.get_input("latents")
+        num_images_per_prompt = state.get_input("num_images_per_prompt")
+        generator = state.get_input("generator")
+        device = state.get_input("device")
+        dtype = state.get_input("dtype")
+
+        # image to image only
+        image = state.get_input("image")
+        denoising_start = state.get_input("denoising_start")
+
+        batch_size = state.get_intermediate("batch_size")
+        prompt_embeds = state.get_intermediate("prompt_embeds")
+        # image to image only
+        latent_timestep = state.get_intermediate("latent_timestep")
+
+        # fall back to the prompt embeddings' dtype, then the VAE's, when none is passed
+        if dtype is None and prompt_embeds is not None:
+            dtype = prompt_embeds.dtype
+        elif dtype is None:
+            dtype = pipeline.vae.dtype
+
+        if device is None:
+            device = pipeline._execution_device
+
+        image = pipeline.image_processor.preprocess(image)
+        # only add noise when denoising starts from scratch (no refiner hand-off)
+        add_noise = True if denoising_start is None else False
+        if latents is None:
+            latents = pipeline.prepare_latents_img2img(
+                image,
+                latent_timestep,
+                batch_size,
+                num_images_per_prompt,
+                dtype,
+                device,
+                generator,
+                add_noise,
+            )
+
+        state.add_intermediate("latents", latents)
+
+        return pipeline, state
+
+
 class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
     expected_components = ["vae", "scheduler"]
     model_name = "stable-diffusion-xl"
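The new block above only talks to the rest of the pipeline through `PipelineState`, so exercising it in isolation is mostly a matter of seeding that state. A minimal sketch follows; the `add_input` setter, the placeholder values, and the `sdxl_pipeline` object are assumptions for illustration only, since the diff only shows the `get_input` / `get_intermediate` / `add_intermediate` accessors and the `(pipeline, state)` call convention.

import torch

# Hedged usage sketch, not part of the diff. `add_input`, `init_image`, `prompt_embeds`,
# `latent_timestep`, and `sdxl_pipeline` are assumed placeholders.
state = PipelineState()
state.add_input("image", init_image)                                   # e.g. a PIL image loaded earlier
state.add_input("generator", torch.Generator("cuda").manual_seed(0))
state.add_input("num_images_per_prompt", 1)
state.add_input("denoising_start", None)                               # None -> noise is added to the image latents
state.add_intermediate("batch_size", 1)
state.add_intermediate("prompt_embeds", prompt_embeds)                 # from the text-encoder block
state.add_intermediate("latent_timestep", latent_timestep)             # from the set-timesteps block

prepare_latents = StableDiffusionXLInpaintPrepareLatentsStep()
pipeline, state = prepare_latents(sdxl_pipeline, state)                # modular SDXL pipeline with vae/scheduler loaded
latents = state.get_intermediate("latents")                            # consumed by the denoising-loop block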
@@ -2026,6 +2102,100 @@ def prepare_latents_img2img(

         return latents

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents
+    def prepare_latents_inpaint(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+        image=None,
+        timestep=None,
+        is_strength_max=True,
+        add_noise=True,
+        return_noise=False,
+        return_image_latents=False,
+    ):
+        shape = (
+            batch_size,
+            num_channels_latents,
+            int(height) // self.vae_scale_factor,
+            int(width) // self.vae_scale_factor,
+        )
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if (image is None or timestep is None) and not is_strength_max:
+            raise ValueError(
+                "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+                "However, either the image or the noise timestep has not been provided."
+            )
+
+        if image.shape[1] == 4:
+            image_latents = image.to(device=device, dtype=dtype)
+            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+        elif return_image_latents or (latents is None and not is_strength_max):
+            image = image.to(device=device, dtype=dtype)
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+        if latents is None and add_noise:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # if strength is 1. then initialise the latents to noise, else initial to image + noise
+            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+            # if pure noise then scale the initial latents by the Scheduler's init sigma
+            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+        elif add_noise:
+            noise = latents.to(device)
+            latents = noise * self.scheduler.init_noise_sigma
+        else:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = image_latents.to(device)
+
+        outputs = (latents,)
+
+        if return_noise:
+            outputs += (noise,)
+
+        if return_image_latents:
+            outputs += (image_latents,)
+
+        return outputs
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        dtype = image.dtype
+        if self.vae.config.force_upcast:
+            image = image.float()
+            self.vae.to(dtype=torch.float32)
+
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        if self.vae.config.force_upcast:
+            self.vae.to(dtype)
+
+        image_latents = image_latents.to(dtype)
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
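For reference, the latent-initialisation logic added in `prepare_latents_inpaint` collapses to three cases once user-supplied `latents` are set aside. The helper below is an illustrative restatement only, not part of the change; `scheduler` stands for any scheduler exposing the `add_noise` and `init_noise_sigma` attributes used above.

def init_inpaint_latents(noise, image_latents, timestep, scheduler, is_strength_max, add_noise):
    # Illustrative only: mirrors the branches of prepare_latents_inpaint above,
    # ignoring the case where the caller passes pre-made latents.
    if add_noise and is_strength_max:
        # strength == 1.0: start from pure noise, scaled by the scheduler's initial sigma
        return noise * scheduler.init_noise_sigma
    if add_noise:
        # strength < 1.0: noise the encoded image down to the starting timestep
        return scheduler.add_noise(image_latents, noise, timestep)
    # denoising_start was given (refiner-style hand-off): keep the clean image latents
    return image_latents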