@@ -103,6 +103,31 @@ def __init__(
103103        self .vae_scale_factor  =  2  **  (len (self .vae .config .block_out_channels ) -  1 ) if  getattr (self , "vae" , None ) else  8 
104104        self .image_processor  =  VaeImageProcessor (vae_scale_factor = self .vae_scale_factor )
105105
106+     @staticmethod  
107+     def  calculate_shift (
108+         image_seq_len ,
109+         base_seq_len : int  =  256 ,
110+         max_seq_len : int  =  4096 ,
111+         base_shift : float  =  0.5 ,
112+         max_shift : float  =  1.15 ,
113+     ):
114+         """Calculate shift parameter based on image dimensions. 
115+          
116+         Args: 
117+             image_seq_len: Length of the image sequence (height/vae_factor/2 * width/vae_factor/2) 
118+             base_seq_len: Base sequence length for interpolation 
119+             max_seq_len: Maximum sequence length for interpolation 
120+             base_shift: Base shift value 
121+             max_shift: Maximum shift value 
122+              
123+         Returns: 
124+             Calculated shift parameter (mu) 
125+         """ 
126+         m  =  (max_shift  -  base_shift ) /  (max_seq_len  -  base_seq_len )
127+         b  =  base_shift  -  m  *  base_seq_len 
128+         mu  =  image_seq_len  *  m  +  b 
129+         return  mu 
130+ 
106131    def  check_inputs (
107132        self ,
108133        prompt ,
@@ -305,41 +330,8 @@ def encode_prompt(
305330
306331        return  prompt_embeds , prompt_attention_mask , negative_prompt_embeds , negative_prompt_attention_mask 
307332
308-     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents 
309-     def  prepare_latents (
310-         self ,
311-         batch_size ,
312-         num_channels_latents ,
313-         height ,
314-         width ,
315-         dtype ,
316-         device ,
317-         generator ,
318-         latents = None ,
319-     ):
320-         if  latents  is  not None :
321-             return  latents .to (device = device , dtype = dtype )
322- 
323-         shape  =  (
324-             batch_size ,
325-             num_channels_latents ,
326-             int (height ) //  self .vae_scale_factor ,
327-             int (width ) //  self .vae_scale_factor ,
328-         )
329- 
330-         if  isinstance (generator , list ) and  len (generator ) !=  batch_size :
331-             raise  ValueError (
332-                 f"You have passed a list of generators of length { len (generator )}  
333-                 f" size of { batch_size }  
334-             )
335- 
336-         latents  =  randn_tensor (shape , generator = generator , device = device , dtype = dtype )
337- 
338-         return  latents 
339- 
340333    def  get_timesteps (self , num_inference_steps , strength , device ):
341334        # Set timesteps using the full range initially 
342-         self .scheduler .set_timesteps (num_inference_steps , device = device )
343335        timesteps  =  self .scheduler .timesteps .to (device = device )
344336
345337        if  len (timesteps ) !=  num_inference_steps :
@@ -349,18 +341,29 @@ def get_timesteps(self, num_inference_steps, strength, device):
349341        init_timestep  =  min (num_inference_steps  *  strength , num_inference_steps )
350342
351343        t_start  =  int (max (num_inference_steps  -  init_timestep , 0 ))
352-         timesteps  =  self .scheduler .timesteps [t_start :]
344+         timesteps  =  self .scheduler .timesteps [t_start  *  self .scheduler .order :]
345+         
346+         # Set begin index if scheduler supports it 
347+         if  hasattr (self .scheduler , "set_begin_index" ):
348+             self .scheduler .set_begin_index (t_start  *  self .scheduler .order )
353349
354350        return  timesteps , num_inference_steps  -  t_start 
355351
356-     def  prepare_img2img_latents (
352+     def  prepare_latents (
357353        self , image , timestep , batch_size , num_images_per_prompt , dtype , device , generator = None 
358354    ):
359355        if  not  isinstance (image , (torch .Tensor , PIL .Image .Image , list )):
360356            raise  ValueError (
361357                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is { type (image )}  
362358            )
363359
360+         # Check for latents_mean and latents_std in the VAE config 
361+         latents_mean  =  latents_std  =  None 
362+         if  hasattr (self .vae .config , "latents_mean" ) and  self .vae .config .latents_mean  is  not None :
363+             latents_mean  =  torch .tensor (self .vae .config .latents_mean ).view (1 , 4 , 1 , 1 )
364+         if  hasattr (self .vae .config , "latents_std" ) and  self .vae .config .latents_std  is  not None :
365+             latents_std  =  torch .tensor (self .vae .config .latents_std ).view (1 , 4 , 1 , 1 )
366+ 
364367        image  =  image .to (device = device , dtype = dtype )
365368
366369        batch_size  =  batch_size  *  num_images_per_prompt 
@@ -404,26 +407,30 @@ def prepare_img2img_latents(
404407            if  isinstance (generator , list ):
405408                sample  =  torch .cat (
406409                    [
407-                         torch . randn (
410+                         randn_tensor (
408411                            (1 , * mean .shape [1 :]),
409412                            generator = generator [i ],
410-                             device = generator [ i ] .device   if   hasattr ( generator [ i ],  "device" )  else   "cpu" ,
413+                             device = mean .device ,
411414                            dtype = mean .dtype ,
412-                         ). to ( mean . device ) 
415+                         )
413416                        for  i  in  range (batch_size )
414417                    ]
415418                )
416419            else :
417420                # Single generator - use its device if it has one 
418-                 generator_device  =  getattr (generator , "device" , "cpu" ) if  generator  is  not None  else  "cpu" 
419-                 noise  =  torch .randn (mean .shape , generator = generator , device = generator_device , dtype = mean .dtype )
420-                 sample  =  noise .to (mean .device )
421+                 sample  =  randn_tensor (mean .shape , generator = generator , device = mean .device , dtype = mean .dtype )
421422
422423            # Compute latents 
423424            latents  =  mean  +  std  *  sample 
424425
425-             # Scale latents 
426-             latents  =  latents  *  self .vae .config .scaling_factor 
426+             # Apply standardization if VAE has mean and std defined in config 
427+             if  latents_mean  is  not None  and  latents_std  is  not None :
428+                 latents_mean  =  latents_mean .to (device = device , dtype = dtype )
429+                 latents_std  =  latents_std .to (device = device , dtype = dtype )
430+                 latents  =  (latents  -  latents_mean ) *  self .vae .config .scaling_factor  /  latents_std 
431+             else :
432+                 # Scale latents 
433+                 latents  =  latents  *  self .vae .config .scaling_factor 
427434
428435            # get the original timestep using init_timestep 
429436            init_timestep  =  timestep  # Use the passed timestep directly 
@@ -433,21 +440,20 @@ def prepare_img2img_latents(
433440            if  isinstance (generator , list ):
434441                noise  =  torch .cat (
435442                    [
436-                         torch . randn (
443+                         randn_tensor (
437444                            (1 , * latents .shape [1 :]),
438445                            generator = generator [i ],
439-                             device = generator [ i ] .device   if   hasattr ( generator [ i ],  "device" )  else   "cpu" ,
446+                             device = latents .device ,
440447                            dtype = latents .dtype ,
441-                         ). to ( latents . device ) 
448+                         )
442449                        for  i  in  range (batch_size )
443450                    ]
444451                )
445452            else :
446453                # Single generator - use its device if it has one 
447-                 generator_device  =  getattr (generator , "device" , "cpu" ) if  generator  is  not None  else  "cpu" 
448-                 noise  =  torch .randn (
449-                     latents .shape , generator = generator , device = generator_device , dtype = latents .dtype 
450-                 ).to (latents .device )
454+                 noise  =  randn_tensor (
455+                     latents .shape , generator = generator , device = latents .device , dtype = latents .dtype 
456+                 )
451457
452458            latents  =  self .scheduler .scale_noise (latents , init_timestep , noise )
453459
@@ -654,13 +660,29 @@ def __call__(
654660            prompt_embeds  =  torch .cat ([negative_prompt_embeds , prompt_embeds ], dim = 0 )
655661
656662        # 5. Prepare timesteps 
663+         # Calculate shift parameter based on image dimensions 
664+         image_seq_len  =  (int (height ) //  self .vae_scale_factor  //  2 ) *  (int (width ) //  self .vae_scale_factor  //  2 )
665+         
666+         # Calculate mu (shift parameter) based on image dimensions 
667+         mu  =  self .calculate_shift (
668+             image_seq_len ,
669+             self .scheduler .config .get ("base_image_seq_len" , 256 ),
670+             self .scheduler .config .get ("max_image_seq_len" , 4096 ),
671+             self .scheduler .config .get ("base_shift" , 0.5 ),
672+             self .scheduler .config .get ("max_shift" , 1.15 ),
673+         )
674+         
675+         # Set timesteps with shift parameter 
676+         self .scheduler .set_timesteps (num_inference_steps , device = device , mu = mu )
677+         
678+         # Now adjust for strength 
657679        timesteps , num_inference_steps  =  self .get_timesteps (
658680            num_inference_steps , strength , device 
659681        )
660682        latent_timestep  =  timesteps [:1 ].repeat (batch_size  *  num_images_per_prompt ) # Get the first timestep(s) for initial noise 
661683
662684        # 6. Prepare latent variables 
663-         latents  =  self .prepare_img2img_latents (
685+         latents  =  self .prepare_latents (
664686            image ,
665687            latent_timestep ,
666688            batch_size ,
0 commit comments