Commit 24ff10e

Merge pull request #15 from huggingface/final-update
final updates: schedule, default guidance_scale
2 parents 29b02b8 + fc1bd89

1 file changed: +17 -20 lines

src/diffusers/pipelines/flux2/pipeline_flux2.py

Lines changed: 17 additions & 20 deletions
@@ -79,18 +79,18 @@ def format_text_input(prompts: List[str], system_message: str = None):
     ]
 
 
-# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
-def calculate_shift(
-    image_seq_len,
-    base_seq_len: int = 256,
-    max_seq_len: int = 4096,
-    base_shift: float = 0.5,
-    max_shift: float = 1.15,
-):
-    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
-    b = base_shift - m * base_seq_len
-    mu = image_seq_len * m + b
-    return mu
+
+def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
+    a1, b1 = 0.00020573, 1.85733333
+    a2, b2 = 0.00016927, 0.45666666
+
+    m_200 = a2 * image_seq_len + b2
+    m_30 = a1 * image_seq_len + b1
+
+    a = (m_200 - m_30) / 170.0
+    b = m_200 - 200.0 * a
+    mu = a * num_steps + b
+    return float(mu)
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
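To get a feel for the new schedule, the following standalone sketch (not part of the commit) evaluates both the removed calculate_shift and the new compute_empirical_mu. The choice image_seq_len = 4096 is only illustrative; the real value depends on the latent resolution.

# Standalone sketch: compare the removed linear shift with the new empirical schedule.
# image_seq_len = 4096 is an illustrative value, not taken from the commit.

def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    a1, b1 = 0.00020573, 1.85733333
    a2, b2 = 0.00016927, 0.45666666
    m_200 = a2 * image_seq_len + b2
    m_30 = a1 * image_seq_len + b1
    a = (m_200 - m_30) / 170.0
    b = m_200 - 200.0 * a
    return float(a * num_steps + b)

image_seq_len = 4096
print(f"old mu: {calculate_shift(image_seq_len):.3f}")  # ~1.150, independent of step count
for num_steps in (30, 50, 200):
    print(f"new mu @ {num_steps} steps: {compute_empirical_mu(image_seq_len, num_steps):.3f}")
# ~2.700 at 30 steps, ~2.518 at 50 steps, ~1.150 at 200 steps.

At 200 steps the two schedules roughly coincide (a2 and b2 appear to reproduce the old base_shift/max_shift line), while lower step counts get a noticeably larger mu.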
@@ -608,7 +608,7 @@ def __call__(
         width: Optional[int] = None,
         num_inference_steps: int = 50,
         sigmas: Optional[List[float]] = None,
-        guidance_scale: Optional[float] = 2.5,
+        guidance_scale: Optional[float] = 4.0,
         num_images_per_prompt: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
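The default guidance scale moves from 2.5 to 4.0. Callers who relied on the old default would need to pass it explicitly, roughly as in the sketch below; the Flux2Pipeline class name and the checkpoint path are assumptions inferred from the file path, not stated in this diff.

# Illustrative only: class name and checkpoint path are assumptions, not from the diff.
import torch
from diffusers import Flux2Pipeline

pipe = Flux2Pipeline.from_pretrained("<flux2-checkpoint>", torch_dtype=torch.bfloat16)

# Omitting guidance_scale now uses the new default of 4.0; pass it explicitly
# to keep the previous behaviour of 2.5.
image = pipe(
    "a photo of a forest at dawn",
    num_inference_steps=50,
    guidance_scale=2.5,
).images[0]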
@@ -783,13 +783,10 @@ def __call__(
         if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
             sigmas = None
         image_seq_len = latents.shape[1]
-        mu = calculate_shift(
-            image_seq_len,
-            self.scheduler.config.get("base_image_seq_len", 256),
-            self.scheduler.config.get("max_image_seq_len", 4096),
-            self.scheduler.config.get("base_shift", 0.5),
-            self.scheduler.config.get("max_shift", 1.15),
-        )
+        mu = compute_empirical_mu(
+            image_seq_len=image_seq_len,
+            num_steps=num_inference_steps,
+        )
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
             num_inference_steps,
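With this call-site change, mu depends on the token count and the requested number of steps rather than on the scheduler's base_shift/max_shift config. The value of mu matters because it is forwarded through retrieve_timesteps into the scheduler's sigma schedule; assuming the exponential time shift used by diffusers' flow-match schedulers (an assumption, not code from this commit), its effect can be sketched as:

# Rough sketch of how mu reshapes the sigma schedule (assumed formulation).
import math

def shift_sigma(sigma: float, mu: float) -> float:
    # Exponential time shift: larger mu pushes sigmas toward 1.0 (more time at high noise).
    return math.exp(mu) / (math.exp(mu) + (1 / sigma - 1))

sigmas = [0.9, 0.5, 0.1]
for mu in (1.15, 2.52):  # roughly the old schedule vs. the new one at 50 steps, image_seq_len 4096
    print(mu, [round(shift_sigma(s, mu), 3) for s in sigmas])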
