6868"""
6969
7070
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly interpolate the flow-match timestep shift ``mu`` for a given image sequence length.

    The shift grows linearly from ``base_shift`` (at ``base_seq_len`` tokens) to
    ``max_shift`` (at ``max_seq_len`` tokens); lengths outside that range are
    extrapolated along the same line.

    Args:
        image_seq_len: Number of image tokens for the current resolution.
        base_seq_len: Sequence length mapped to ``base_shift``.
        max_seq_len: Sequence length mapped to ``max_shift``.
        base_shift: Shift value at ``base_seq_len``.
        max_shift: Shift value at ``max_seq_len``.

    Returns:
        The interpolated shift ``mu`` to pass to the scheduler.
    """
    # Line through (base_seq_len, base_shift) and (max_seq_len, max_shift):
    # slope first, then the intercept, then evaluate at image_seq_len.
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
83+
84+
7185# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
7286def retrieve_timesteps (
7387 scheduler ,
@@ -884,18 +898,7 @@ def __call__(
884898 prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ], dim = 0 )
885899 pooled_prompt_embeds = torch .cat ([negative_pooled_prompt_embeds , pooled_prompt_embeds ], dim = 0 )
886900
887- # 4. Prepare timesteps
888- timesteps , num_inference_steps = retrieve_timesteps (
889- self .scheduler ,
890- num_inference_steps ,
891- device ,
892- sigmas = sigmas ,
893- mu = mu ,
894- )
895- num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
896- self ._num_timesteps = len (timesteps )
897-
898- # 5. Prepare latent variables
901+ # 4. Prepare latent variables
899902 num_channels_latents = self .transformer .config .in_channels
900903 latents = self .prepare_latents (
901904 batch_size * num_images_per_prompt ,
@@ -908,6 +911,29 @@ def __call__(
908911 latents ,
909912 )
910913
914+ # 5. Prepare timesteps
915+ if self .scheduler .config .use_dynamic_shifting and mu is None :
916+ _ , _ , height , width = latents .shape
917+ image_seq_len = (height // self .transformer .config .patch_size ) * (
918+ width // self .transformer .config .patch_size
919+ )
920+ mu = calculate_shift (
921+ image_seq_len ,
922+ self .scheduler .config .base_image_seq_len ,
923+ self .scheduler .config .max_image_seq_len ,
924+ self .scheduler .config .base_shift ,
925+ self .scheduler .config .max_shift ,
926+ )
927+ timesteps , num_inference_steps = retrieve_timesteps (
928+ self .scheduler ,
929+ num_inference_steps ,
930+ device ,
931+ sigmas = sigmas ,
932+ mu = mu ,
933+ )
934+ num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
935+ self ._num_timesteps = len (timesteps )
936+
911937 # 6. Denoising loop
912938 with self .progress_bar (total = num_inference_steps ) as progress_bar :
913939 for i , t in enumerate (timesteps ):
0 commit comments