@@ -504,9 +504,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
504504        else :
505505            image_latents  =  retrieve_latents (self .vae .encode (image ), generator = generator )
506506
507-         image_latents  =  (
508-             image_latents  -  self .vae .config .shift_factor 
509-         ) *  self .vae .config .scaling_factor 
507+         image_latents  =  (image_latents  -  self .vae .config .shift_factor ) *  self .vae .config .scaling_factor 
510508
511509        return  image_latents 
512510
@@ -877,9 +875,7 @@ def __call__(
877875
878876        # 3. Prepare prompt embeddings 
879877        lora_scale  =  (
880-             self .joint_attention_kwargs .get ("scale" , None )
881-             if  self .joint_attention_kwargs  is  not   None 
882-             else  None 
878+             self .joint_attention_kwargs .get ("scale" , None ) if  self .joint_attention_kwargs  is  not   None  else  None 
883879        )
884880        (
885881            prompt_embeds ,
@@ -897,14 +893,8 @@ def __call__(
897893        )
898894
899895        # 4. Prepare timesteps 
900-         sigmas  =  (
901-             np .linspace (1.0 , 1  /  num_inference_steps , num_inference_steps )
902-             if  sigmas  is  None 
903-             else  sigmas 
904-         )
905-         image_seq_len  =  (int (height ) //  self .vae_scale_factor  //  2 ) *  (
906-             int (width ) //  self .vae_scale_factor  //  2 
907-         )
896+         sigmas  =  np .linspace (1.0 , 1  /  num_inference_steps , num_inference_steps ) if  sigmas  is  None  else  sigmas 
897+         image_seq_len  =  (int (height ) //  self .vae_scale_factor  //  2 ) *  (int (width ) //  self .vae_scale_factor  //  2 )
908898        mu  =  calculate_shift (
909899            image_seq_len ,
910900            self .scheduler .config .base_image_seq_len ,
@@ -967,7 +957,6 @@ def __call__(
967957            )
968958            masked_image_latents  =  torch .cat ((masked_image_latents , mask ), dim = - 1 )
969959
970- 
971960        num_warmup_steps  =  max (len (timesteps ) -  num_inference_steps  *  self .scheduler .order , 0 )
972961        self ._num_timesteps  =  len (timesteps )
973962
0 commit comments