@@ -350,115 +350,66 @@ def get_timesteps(self, num_inference_steps, strength, device):
350350        return  timesteps , num_inference_steps  -  t_start 
351351
def prepare_latents(
    self,
    image,
    timestep,
    batch_size,
    num_images_per_prompt,
    dtype,
    device,
    generator=None,
):
    """Build the noised starting latents for image-to-image sampling.

    The input is either used directly as latents (when its second dimension
    is 4) or encoded with ``self.vae``; in the latter case the mean of the
    returned latent distribution is taken (no sampling) and multiplied by
    ``self.vae.config.scaling_factor``. The result is tiled to the effective
    batch size, fresh noise is drawn with ``randn_tensor``, and both are
    handed to ``self.scheduler.scale_noise``.

    Args:
        image: Input batch. The type check admits ``torch.Tensor``,
            ``PIL.Image.Image`` and ``list``, but the body immediately calls
            ``image.to(...)`` and reads ``image.shape``.
            NOTE(review): non-tensor inputs therefore appear to need
            preprocessing by the caller -- confirm.
        timestep: Python number or tensor passed on to ``scale_noise``. A
            number is wrapped into a one-element tensor; the value is then
            expanded to one entry per latent.
        batch_size: Prompt batch size, before multiplication by
            ``num_images_per_prompt``.
        num_images_per_prompt: Multiplier giving the effective batch size.
        dtype: dtype the input image is cast to.
        device: Device the input image is moved to.
        generator: Optional generator, or list of generators, forwarded to
            ``randn_tensor``. A list must contain exactly one generator per
            effective batch element.

    Returns:
        The value returned by ``self.scheduler.scale_noise`` for the prepared
        latents, timestep and noise.

    Raises:
        ValueError: If ``image`` has an unsupported type, if a generator list
            does not match the effective batch size, or if the image batch is
            empty or cannot be tiled evenly to the effective batch size.
    """
    if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
        raise ValueError(
            f"`image` must be `torch.Tensor`, `PIL.Image.Image` or list, got {type(image)}"
        )

    image = image.to(device=device, dtype=dtype)
    batch_size = batch_size * num_images_per_prompt

    # Fail early with a clear message instead of an opaque indexing error
    # (or silently unused generators) further down in the noise sampling.
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if image.shape[1] == 4:
        # Input is treated as already being in latent space.
        latents_0 = image
    else:
        # Encode in float32 for numerical stability, then put the VAE back in
        # its original dtype -- even if encoding raises.
        orig_dtype = self.vae.dtype
        needs_upcast = orig_dtype != torch.float32
        if needs_upcast:
            self.vae.to(dtype=torch.float32)
        try:
            latent_dist = self.vae.encode(image.to(dtype=torch.float32)).latent_dist
            # Use the distribution mean rather than a sample so the encoding
            # step is deterministic.
            latents_0 = latent_dist.mean
        finally:
            if needs_upcast:
                self.vae.to(dtype=orig_dtype)

        latents_0 = latents_0 * self.vae.config.scaling_factor

    # Tile the latents so there is one per effective batch element.
    num_latents = latents_0.shape[0]
    if num_latents != batch_size:
        # `num_latents == 0` is checked first to avoid a ZeroDivisionError.
        if num_latents == 0 or batch_size % num_latents != 0:
            raise ValueError(
                f"Cannot duplicate image batch of size {num_latents} "
                f"to effective batch size {batch_size}."
            )
        repeats = batch_size // num_latents
        latents_0 = latents_0.repeat(repeats, 1, 1, 1)

    noise = randn_tensor(
        latents_0.shape,
        generator=generator,
        device=latents_0.device,
        dtype=latents_0.dtype,
    )

    # Make sure `timestep` is a tensor with one entry per latent.
    if isinstance(timestep, (int, float)):
        timestep = torch.tensor([timestep], device=latents_0.device, dtype=latents_0.dtype)
    timestep = timestep.expand(latents_0.shape[0])

    latents = self.scheduler.scale_noise(latents_0, timestep, noise)
    return latents
461411
412+ 
462413    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae 
463414    def  upcast_vae (self ):
464415        dtype  =  self .vae .dtype 
0 commit comments