@@ -325,6 +325,102 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
         return pipeline, state


+class StableDiffusionXLVAEEncoderStep(PipelineBlock):
+    expected_components = ["vae"]
+    expected_auxiliaries = ["image_processor"]
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            ("image", None),
+            ("generator", None),
+            ("height", None),
+            ("width", None),
+            ("device", None),
+            ("dtype", None),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return ["batch_size"]
+
+    @property
+    def intermediates_outputs(self) -> List[str]:
+        return ["image_latents"]
+
+    def __init__(self, vae=None):
+        super().__init__(vae=vae)
+        self.image_processor = VaeImageProcessor()
+        self.auxiliaries["image_processor"] = self.image_processor
+
+    @torch.no_grad()
+    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
+        image = state.get_input("image")
+        generator = state.get_input("generator")
+        height = state.get_input("height")
+        width = state.get_input("width")
+        device = state.get_input("device")
+        dtype = state.get_input("dtype")
+
+        batch_size = state.get_intermediate("batch_size")
+
+        if device is None:
+            device = pipeline._execution_device
+        if dtype is None:
+            dtype = pipeline.vae.dtype
+
+        image = pipeline.image_processor.preprocess(image, height=height, width=width)
+        image = image.to(device=device, dtype=dtype)
+
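+        # some VAE checkpoints expose per-channel latent statistics in their config;
+        # pick them up here so the encoded latents can be normalized further below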
+        latents_mean = latents_std = None
+        if hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None:
+            latents_mean = torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1)
+        if hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None:
+            latents_std = torch.tensor(pipeline.vae.config.latents_std).view(1, 4, 1, 1)
+
+        # make sure the VAE is in float32 mode, as it overflows in float16
+        if pipeline.vae.config.force_upcast:
+            image = image.float()
+            pipeline.vae.to(dtype=torch.float32)
+
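+        # a generator list must match the batch one-to-one, since each entry seeds one sample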
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        elif isinstance(generator, list):
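+            # replicate a smaller image batch across the requested batch size when it divides evenly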
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to {batch_size} text prompts."
+                )
+
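+            # encode each image slice with its own generator so per-sample noise stays reproducible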
+            init_latents = [
+                retrieve_latents(pipeline.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
+            ]
+            init_latents = torch.cat(init_latents, dim=0)
+        else:
+            init_latents = retrieve_latents(pipeline.vae.encode(image), generator=generator)
+
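+        # cast the VAE back to the requested dtype now that encoding is done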
+        if pipeline.vae.config.force_upcast:
+            pipeline.vae.to(dtype)
+
+        init_latents = init_latents.to(dtype)
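+        # normalize with the checkpoint's latent statistics when available,
+        # otherwise apply the usual scaling factor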
+        if latents_mean is not None and latents_std is not None:
+            latents_mean = latents_mean.to(device=device, dtype=dtype)
+            latents_std = latents_std.to(device=device, dtype=dtype)
+            init_latents = (init_latents - latents_mean) * pipeline.vae.config.scaling_factor / latents_std
+        else:
+            init_latents = pipeline.vae.config.scaling_factor * init_latents
+
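+        # hand the encoded latents to downstream blocks under the "image_latents" key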
+        state.add_intermediate("image_latents", init_latents)
+
+        return pipeline, state
+
+
 class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
     expected_components = ["scheduler"]

@@ -498,9 +594,9 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState:
         denoising_start = state.get_input("denoising_start")

         batch_size = state.get_intermediate("batch_size")
-        prompt_embeds = state.get_intermediate("prompt_embeds", None)
+        prompt_embeds = state.get_intermediate("prompt_embeds")
         # image to image only
-        latent_timestep = state.get_intermediate("latent_timestep", None)
+        latent_timestep = state.get_intermediate("latent_timestep")

         if dtype is None and prompt_embeds is not None:
             dtype = prompt_embeds.dtype
@@ -1872,12 +1968,6 @@ def prepare_latents_img2img(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )

-        latents_mean = latents_std = None
-        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
-            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
-
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
@@ -1891,6 +1981,11 @@ def prepare_latents_img2img(
             init_latents = image

         else:
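+            # the latent normalization statistics are only needed when the image is actually encoded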
+            latents_mean = latents_std = None
+            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
             # make sure the VAE is in float32 mode, as it overflows in float16
             if self.vae.config.force_upcast:
                 image = image.float()