@@ -504,9 +504,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
504504 else :
505505 image_latents = retrieve_latents (self .vae .encode (image ), generator = generator )
506506
507- image_latents = (
508- image_latents - self .vae .config .shift_factor
509- ) * self .vae .config .scaling_factor
507+ image_latents = (image_latents - self .vae .config .shift_factor ) * self .vae .config .scaling_factor
510508
511509 return image_latents
512510
@@ -877,9 +875,7 @@ def __call__(
877875
878876 # 3. Prepare prompt embeddings
879877 lora_scale = (
880- self .joint_attention_kwargs .get ("scale" , None )
881- if self .joint_attention_kwargs is not None
882- else None
878+ self .joint_attention_kwargs .get ("scale" , None ) if self .joint_attention_kwargs is not None else None
883879 )
884880 (
885881 prompt_embeds ,
@@ -897,14 +893,8 @@ def __call__(
897893 )
898894
899895 # 4. Prepare timesteps
900- sigmas = (
901- np .linspace (1.0 , 1 / num_inference_steps , num_inference_steps )
902- if sigmas is None
903- else sigmas
904- )
905- image_seq_len = (int (height ) // self .vae_scale_factor // 2 ) * (
906- int (width ) // self .vae_scale_factor // 2
907- )
896+ sigmas = np .linspace (1.0 , 1 / num_inference_steps , num_inference_steps ) if sigmas is None else sigmas
897+ image_seq_len = (int (height ) // self .vae_scale_factor // 2 ) * (int (width ) // self .vae_scale_factor // 2 )
908898 mu = calculate_shift (
909899 image_seq_len ,
910900 self .scheduler .config .base_image_seq_len ,
@@ -967,7 +957,6 @@ def __call__(
967957 )
968958 masked_image_latents = torch .cat ((masked_image_latents , mask ), dim = - 1 )
969959
970-
971960 num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
972961 self ._num_timesteps = len (timesteps )
973962
0 commit comments