Skip to content

Commit 64aee6e

Browse files
committed
calculate_shift
1 parent b3c6787 commit 64aee6e

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from ...utils import (
3535
USE_PEFT_BACKEND,
3636
is_torch_xla_available,
37+
deprecate,
3738
logging,
3839
replace_example_docstring,
3940
scale_lora_layers,
@@ -73,12 +74,24 @@
7374

7475
def calculate_shift(
7576
image_seq_len,
76-
scheduler,
77+
base_seq_len: Optional[int] = 256,
78+
max_seq_len: Optional[int] = 4096,
79+
base_shift: Optional[float] = 0.5,
80+
max_shift: Optional[float] = 1.16,
81+
scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
7782
):
78-
base_seq_len = scheduler.config.get("base_image_seq_len", 256)
79-
max_seq_len = scheduler.config.get("max_image_seq_len", 4096)
80-
base_shift = scheduler.config.get("base_shift", 0.5)
81-
max_shift = scheduler.config.get("max_shift", 1.16)
83+
if base_seq_len or max_seq_len or base_shift or max_shift or scheduler is None:
84+
deprecation_message = "Pass `scheduler` to `calculate_shift`."
85+
deprecate(
86+
"calculate_shift scheduler",
87+
"1.0.0",
88+
deprecation_message,
89+
standard_warn=False,
90+
)
91+
base_seq_len = base_seq_len or scheduler.config.get("base_image_seq_len", 256)
92+
max_seq_len = max_seq_len or scheduler.config.get("max_image_seq_len", 4096)
93+
base_shift = base_shift or scheduler.config.get("base_shift", 0.5)
94+
max_shift = max_shift or scheduler.config.get("max_shift", 1.16)
8295
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
8396
b = base_shift - m * base_seq_len
8497
mu = image_seq_len * m + b

0 commit comments

Comments
 (0)