Skip to content

Commit 66806e3

Browse files
authored
dont assume scheduler has optional config params
1 parent 1b202c5 commit 66806e3

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,12 @@
7373

7474
def calculate_shift(
7575
image_seq_len,
76-
base_seq_len: int = 256,
77-
max_seq_len: int = 4096,
78-
base_shift: float = 0.5,
79-
max_shift: float = 1.16,
76+
scheduler,
8077
):
78+
base_seq_len = scheduler.config.get('base_image_seq_len', 256)
79+
max_seq_len = scheduler.config.get('max_image_seq_len', 4096)
80+
base_shift = scheduler.config.get('base_shift', 0.5)
81+
max_shift = scheduler.config.get('max_shift', 1.16)
8182
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
8283
b = base_shift - m * base_seq_len
8384
mu = image_seq_len * m + b
@@ -824,10 +825,7 @@ def __call__(
824825
image_seq_len = latents.shape[1]
825826
mu = calculate_shift(
826827
image_seq_len,
827-
self.scheduler.config.base_image_seq_len,
828-
self.scheduler.config.max_image_seq_len,
829-
self.scheduler.config.base_shift,
830-
self.scheduler.config.max_shift,
828+
self.scheduler,
831829
)
832830
timesteps, num_inference_steps = retrieve_timesteps(
833831
self.scheduler,

0 commit comments

Comments
 (0)