@@ -79,18 +79,18 @@ def format_text_input(prompts: List[str], system_message: str = None):
7979 ]
8080
8181
82- # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
83- def calculate_shift (
84- image_seq_len ,
85- base_seq_len : int = 256 ,
86- max_seq_len : int = 4096 ,
87- base_shift : float = 0.5 ,
88- max_shift : float = 1.15 ,
89- ):
90- m = (max_shift - base_shift ) / ( max_seq_len - base_seq_len )
91- b = base_shift - m * base_seq_len
92- mu = image_seq_len * m + b
93- return mu
82+
83+ def compute_empirical_mu ( image_seq_len : int , num_steps : int ) -> float :
84+ a1 , b1 = 0.00020573 , 1.85733333
85+ a2 , b2 = 0.00016927 , 0.45666666
86+
87+ m_200 = a2 * image_seq_len + b2
88+ m_30 = a1 * image_seq_len + b1
89+
90+ a = (m_200 - m_30 ) / 170.0
91+ b = m_200 - 200.0 * a
92+ mu = a * num_steps + b
93+ return float ( mu )
9494
9595
9696# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
@@ -608,7 +608,7 @@ def __call__(
608608 width : Optional [int ] = None ,
609609 num_inference_steps : int = 50 ,
610610 sigmas : Optional [List [float ]] = None ,
611- guidance_scale : Optional [float ] = 2.5 ,
611+ guidance_scale : Optional [float ] = 4.0 ,
612612 num_images_per_prompt : int = 1 ,
613613 generator : Optional [Union [torch .Generator , List [torch .Generator ]]] = None ,
614614 latents : Optional [torch .Tensor ] = None ,
@@ -783,13 +783,10 @@ def __call__(
783783 if hasattr (self .scheduler .config , "use_flow_sigmas" ) and self .scheduler .config .use_flow_sigmas :
784784 sigmas = None
785785 image_seq_len = latents .shape [1 ]
786- mu = calculate_shift (
787- image_seq_len ,
788- self .scheduler .config .get ("base_image_seq_len" , 256 ),
789- self .scheduler .config .get ("max_image_seq_len" , 4096 ),
790- self .scheduler .config .get ("base_shift" , 0.5 ),
791- self .scheduler .config .get ("max_shift" , 1.15 ),
792- )
786+ mu = compute_empirical_mu (
787+ image_seq_len = image_seq_len ,
788+ num_steps = num_inference_steps ,
789+ )
793790 timesteps , num_inference_steps = retrieve_timesteps (
794791 self .scheduler ,
795792 num_inference_steps ,
0 commit comments