@@ -761,23 +761,26 @@ def __call__(
761761 max_sequence_length = max_sequence_length ,
762762 lora_scale = lora_scale ,
763763 )
764- import math
765- def time_shift (mu : float , sigma : float , t : torch .Tensor ):
766- return math .exp (mu ) / (math .exp (mu ) + (1 / t - 1 ) ** sigma )
767764
768-
769- image_seq_len = (int (height ) // self .vae_scale_factor ) * (int (width ) // self .vae_scale_factor )
770- def get_lin_function (
771- x1 : float = 256 , y1 : float = 0.5 , x2 : float = 4096 , y2 : float = 1.15
772- ) -> Callable [[float ], float ]:
773- m = (y2 - y1 ) / (x2 - x1 )
774- b = y1 - m * x1
775- return lambda x : m * x + b
776-
777- mu = get_lin_function ()(image_seq_len )
778- timesteps = torch .linspace (0 , 1 , num_inference_steps + 1 )
779- timesteps = time_shift (mu , 1.0 , timesteps ).to ("cuda" , torch .bfloat16 )
780765 # 4.Prepare timesteps
766+ sigmas = np .linspace (1.0 , 1 / num_inference_steps , num_inference_steps )
767+ image_seq_len = (int (height ) // self .vae_scale_factor ) * (int (width ) // self .vae_scale_factor )
768+ mu = calculate_shift (
769+ image_seq_len ,
770+ self .scheduler .config .base_image_seq_len ,
771+ self .scheduler .config .max_image_seq_len ,
772+ self .scheduler .config .base_shift ,
773+ self .scheduler .config .max_shift ,
774+ )
775+ timesteps , num_inference_steps = retrieve_timesteps (
776+ self .scheduler ,
777+ num_inference_steps ,
778+ device ,
779+ timesteps ,
780+ sigmas ,
781+ mu = mu ,
782+ )
783+ timesteps , num_inference_steps = self .get_timesteps (num_inference_steps , strength , device )
781784
782785 if num_inference_steps < 1 :
783786 raise ValueError (
@@ -788,9 +791,10 @@ def get_lin_function(
788791
789792 # 5. Prepare latent variables
790793 num_channels_latents = self .transformer .config .in_channels // 4
794+
791795 latents , latent_image_ids = self .prepare_latents (
792796 init_image ,
793- timesteps ,
797+ latent_timestep ,
794798 batch_size * num_images_per_prompt ,
795799 num_channels_latents ,
796800 height ,
@@ -815,15 +819,13 @@ def get_lin_function(
815819 y1 = randn_tensor (latents .shape , generator = generator , device = device , dtype = latents .dtype )
816820 # 6. Denoising loop
817821 with self .progress_bar (total = num_inference_steps ) as progress_bar :
818- for i in range ( num_inference_steps - stop_step ):
822+ for i , t in enumerate ( timesteps ):
819823 if self .interrupt :
820824 continue
821- t = torch .tensor ([timesteps [i ]], device = latents .device , dtype = latents .dtype )
822- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
823825 timestep = t .expand (latents .shape [0 ]).to (latents .dtype )
824826 noise_pred = self .transformer (
825827 hidden_states = latents ,
826- timestep = timestep ,
828+ timestep = timestep / 1000 ,
827829 guidance = guidance ,
828830 pooled_projections = pooled_prompt_embeds ,
829831 encoder_hidden_states = prompt_embeds ,
@@ -837,10 +839,7 @@ def get_lin_function(
837839 conditional_vector_field = (y1 - latents )/ (1 - timestep )
838840 controlled_vector_field = unconditional_vector_field + controller_guidance * (conditional_vector_field - unconditional_vector_field )
839841
840- # Get the corresponding sigma values
841- sigma = timesteps [i ]
842- sigma_next = timesteps [i + 1 ]
843- latents = latents + (sigma_next - sigma ) * controlled_vector_field
842+ latents = self .scheduler .step (controlled_vector_field , t , latents , return_dict = False )[0 ]
844843
845844 if XLA_AVAILABLE :
846845 xm .mark_step ()
0 commit comments