@@ -550,8 +550,7 @@ def prepare_latents(
 
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         import numpy as np
-        sigma = timestep[0]
-        latents = sigma * noise + (1.0 - sigma) * image_latents
+        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         np.save("reference_image_latent.npy", latents.detach().cpu().float().numpy())
         return latents, latent_image_ids
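Note on the hunk above: the removed lines built the noisy latents by hand as sigma * noise + (1 - sigma) * image_latents, which is the same flow-matching forward process that FlowMatchEulerDiscreteScheduler.scale_noise implements (the scheduler looks up the sigma for the given timestep internally). A minimal sketch of the equivalence, with an illustrative helper name:

    import torch

    def scale_noise_sketch(sample: torch.Tensor, noise: torch.Tensor, sigma: float) -> torch.Tensor:
        # Flow-matching forward process: linear interpolation between the clean
        # latents and Gaussian noise, weighted by the sigma of the chosen timestep.
        # Equivalent to the removed "sigma * noise + (1.0 - sigma) * image_latents".
        return sigma * noise + (1.0 - sigma) * sample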
@@ -761,23 +760,26 @@ def __call__(
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
-        import math
-        def time_shift(mu: float, sigma: float, t: torch.Tensor):
-            return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
 
-
-        image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
-        def get_lin_function(
-            x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
-        ) -> Callable[[float], float]:
-            m = (y2 - y1) / (x2 - x1)
-            b = y1 - m * x1
-            return lambda x: m * x + b
-
-        mu = get_lin_function()(image_seq_len)
-        timesteps = torch.linspace(0, 1, num_inference_steps + 1)
-        timesteps = time_shift(mu, 1.0, timesteps).to("cuda", torch.bfloat16)
         # 4.Prepare timesteps
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+        mu = calculate_shift(
+            image_seq_len,
+            self.scheduler.config.base_image_seq_len,
+            self.scheduler.config.max_image_seq_len,
+            self.scheduler.config.base_shift,
+            self.scheduler.config.max_shift,
+        )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            timesteps,
+            sigmas,
+            mu=mu,
+        )
+        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
 
         if num_inference_steps < 1:
             raise ValueError(
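For context, calculate_shift is the helper the Flux pipelines use to pick the schedule shift mu from the image token count; it is the same linear map the removed get_lin_function computed inline. A sketch using the defaults the removed code carried (the real call above pulls these values from the scheduler config):

    def calculate_shift_sketch(
        image_seq_len: int,
        base_seq_len: int = 256,
        max_seq_len: int = 4096,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
    ) -> float:
        # Linearly interpolate the shift between base_shift and max_shift
        # according to how many image tokens the latent grid produces.
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        return image_seq_len * m + b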
@@ -788,9 +790,10 @@ def get_lin_function(
 
         # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
+
         latents, latent_image_ids = self.prepare_latents(
             init_image,
-            timesteps,
+            latent_timestep,
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
@@ -815,15 +818,13 @@ def get_lin_function(
         y1 = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i in range(num_inference_steps - stop_step):
+            for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
-                t = torch.tensor([timesteps[i]], device=latents.device, dtype=latents.dtype)
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
                 noise_pred = self.transformer(
                     hidden_states=latents,
-                    timestep=timestep,
+                    timestep=timestep / 1000,
                     guidance=guidance,
                     pooled_projections=pooled_prompt_embeds,
                     encoder_hidden_states=prompt_embeds,
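The timestep / 1000 division exists because FlowMatchEulerDiscreteScheduler reports timesteps as sigma * num_train_timesteps (1000 by default), while the Flux transformer's time embedding expects values in [0, 1]. A small illustration, assuming the default num_train_timesteps:

    import torch

    num_train_timesteps = 1000  # assumed scheduler default (config.num_train_timesteps)
    sigmas = torch.linspace(1.0, 0.25, 4)
    timesteps = sigmas * num_train_timesteps                     # what retrieve_timesteps hands back
    timestep_for_transformer = timesteps / num_train_timesteps   # back to [0, 1] for the time embedding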
@@ -837,10 +838,7 @@ def get_lin_function(
                 conditional_vector_field = (y1 - latents) / (1 - timestep)
                 controlled_vector_field = unconditional_vector_field + controller_guidance * (conditional_vector_field - unconditional_vector_field)
 
-                # Get the corresponding sigma values
-                sigma = timesteps[i]
-                sigma_next = timesteps[i + 1]
-                latents = latents + (sigma_next - sigma) * controlled_vector_field
+                latents = self.scheduler.step(controlled_vector_field, t, latents, return_dict=False)[0]
 
                 if XLA_AVAILABLE:
                     xm.mark_step()
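Finally, the scheduler.step call replaces the manual Euler update the removed lines performed; for this flow-matching scheduler a single step moves the latents along the vector field by the gap between consecutive sigmas. A sketch of that update (the real call handles the sigma bookkeeping internally):

    import torch

    def euler_step_sketch(latents: torch.Tensor, vector_field: torch.Tensor,
                          sigma: float, sigma_next: float) -> torch.Tensor:
        # One explicit Euler step of the flow-matching ODE, matching the removed
        # "latents + (sigma_next - sigma) * controlled_vector_field" line.
        return latents + (sigma_next - sigma) * vector_field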