
Commit 823f4c3

up (#17)
1 parent 4ef0285 commit 823f4c3

File tree: 1 file changed (+11 −14 lines changed)

src/diffusers/pipelines/flux2/pipeline_flux2.py

Lines changed: 11 additions & 14 deletions
@@ -23,11 +23,7 @@
 from ...loaders import Flux2LoraLoaderMixin
 from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import (
-    is_torch_xla_available,
-    logging,
-    replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .image_processor import Flux2ImageProcessor
@@ -79,17 +75,21 @@ def format_text_input(prompts: List[str], system_message: str = None):
     ]


-
 def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
-    a1, b1 = 0.00020573, 1.85733333
+    a1, b1 = 8.73809524e-05, 1.89833333
     a2, b2 = 0.00016927, 0.45666666

+    if image_seq_len > 4300:
+        mu = a2 * image_seq_len + b2
+        return float(mu)
+
     m_200 = a2 * image_seq_len + b2
-    m_30 = a1 * image_seq_len + b1
+    m_10 = a1 * image_seq_len + b1

-    a = (m_200 - m_30) / 170.0
+    a = (m_200 - m_10) / 190.0
     b = m_200 - 200.0 * a
     mu = a * num_steps + b
+
     return float(mu)
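
For reference, here is the updated estimator as a standalone, commented sketch. It restates the hunk above: the (a1, b1) line is the empirical fit for mu at 10 steps, (a2, b2) the fit at 200 steps, and intermediate step counts are handled by linear interpolation between those anchors. The 4096-token example and the step counts in the demo loop are illustrative assumptions, not values from the commit.

def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    # Empirical fits for mu as a function of sequence length,
    # measured at 10 steps (a1, b1) and 200 steps (a2, b2).
    a1, b1 = 8.73809524e-05, 1.89833333
    a2, b2 = 0.00016927, 0.45666666

    # For long sequences, use the 200-step fit directly.
    if image_seq_len > 4300:
        return float(a2 * image_seq_len + b2)

    # Linearly interpolate in num_steps between the 10-step and 200-step anchors.
    m_200 = a2 * image_seq_len + b2
    m_10 = a1 * image_seq_len + b1
    a = (m_200 - m_10) / 190.0
    b = m_200 - 200.0 * a
    return float(a * num_steps + b)

# Illustrative only: a 4096-token latent at a few common step counts.
for steps in (10, 28, 50):
    print(steps, round(compute_empirical_mu(4096, steps), 4))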

@@ -171,7 +171,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
     r"""
     The Flux2 pipeline for text-to-image generation.

-    Reference: TODO
+    Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2)

     Args:
         transformer ([`Flux2Transformer2DModel`]):
@@ -783,10 +783,7 @@ def __call__(
         if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
             sigmas = None
             image_seq_len = latents.shape[1]
-            mu = compute_empirical_mu(
-                image_seq_len=image_seq_len,
-                num_steps= num_inference_steps,
-            )
+            mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
             timesteps, num_inference_steps = retrieve_timesteps(
                 self.scheduler,
                 num_inference_steps,
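
For context, mu is a dynamic-shift parameter that the pipeline hands to the scheduler when the sigma schedule is built. A minimal sketch of the usual flow-matching time shift is below; the exact formula and the forwarding of mu through retrieve_timesteps are assumptions based on how other diffusers flow-match pipelines behave, not something shown in this hunk.

import numpy as np

def shift_sigmas(sigmas: np.ndarray, mu: float) -> np.ndarray:
    # Exponential time shift: larger mu pushes sigmas toward 1.0, i.e. more of
    # the schedule is spent at high noise, which is why mu grows with image_seq_len.
    return np.exp(mu) / (np.exp(mu) + (1.0 / sigmas - 1.0))

sigmas = np.linspace(1.0, 1.0 / 28, 28)  # illustrative 28-step linear schedule
print(shift_sigmas(sigmas, mu=2.25)[:3])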
