6868"""
6969
7070
71+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
72+ def calculate_shift (
73+ image_seq_len ,
74+ base_seq_len : int = 256 ,
75+ max_seq_len : int = 4096 ,
76+ base_shift : float = 0.5 ,
77+ max_shift : float = 1.16 ,
78+ ):
79+ m = (max_shift - base_shift ) / (max_seq_len - base_seq_len )
80+ b = base_shift - m * base_seq_len
81+ mu = image_seq_len * m + b
82+ return mu
83+
84+
7185# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
7286def retrieve_timesteps (
7387 scheduler ,
@@ -882,12 +896,7 @@ def __call__(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
-        # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
-
-        # 5. Prepare latent variables
+        # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
@@ -900,6 +909,27 @@ def __call__(
             latents,
         )
 
+        # 5. Prepare timesteps
+        mu = None
+        if self.scheduler.config.use_dynamic_shifting:
+            image_seq_len = latents.shape[1]
+            mu = calculate_shift(
+                image_seq_len,
+                self.scheduler.config.base_image_seq_len,
+                self.scheduler.config.max_image_seq_len,
+                self.scheduler.config.base_shift,
+                self.scheduler.config.max_shift,
+            )
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            mu=mu,
+        )
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
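
For context (an illustrative sketch, not part of this diff): with the default arguments shown above, `calculate_shift` linearly maps the image token count onto a shift value `mu`, which is only forwarded to the scheduler via `retrieve_timesteps(..., mu=mu)` when `use_dynamic_shifting` is set in the scheduler config. Checking the interpolation endpoints by hand:

    # Illustrative calls only; they do not appear in the diff.
    calculate_shift(256)   # ~= 0.5  -> base_shift when image_seq_len == base_seq_len
    calculate_shift(2176)  # ~= 0.83 -> midway between base_shift and max_shift
    calculate_shift(4096)  # ~= 1.16 -> max_shift when image_seq_len == max_seq_len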