Skip to content

Commit f0c6d97

Browse files
vladmandichlky
andauthored
flux: make scheduler config params optional (#10384)
* dont assume scheduler has optional config params * make style, make fix-copies * calculate_shift * fix-copies, usage in pipelines --------- Co-authored-by: hlky <[email protected]>
1 parent d006f07 commit f0c6d97

19 files changed

+78
-89
lines changed

examples/community/pipeline_flux_differential_img2img.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -875,10 +875,10 @@ def __call__(
875875
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
876876
mu = calculate_shift(
877877
image_seq_len,
878-
self.scheduler.config.base_image_seq_len,
879-
self.scheduler.config.max_image_seq_len,
880-
self.scheduler.config.base_shift,
881-
self.scheduler.config.max_shift,
878+
self.scheduler.config.get("base_image_seq_len", 256),
879+
self.scheduler.config.get("max_image_seq_len", 4096),
880+
self.scheduler.config.get("base_shift", 0.5),
881+
self.scheduler.config.get("max_shift", 1.16),
882882
)
883883
timesteps, num_inference_steps = retrieve_timesteps(
884884
self.scheduler,

examples/community/pipeline_flux_rf_inversion.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -820,10 +820,10 @@ def __call__(
820820
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
821821
mu = calculate_shift(
822822
image_seq_len,
823-
self.scheduler.config.base_image_seq_len,
824-
self.scheduler.config.max_image_seq_len,
825-
self.scheduler.config.base_shift,
826-
self.scheduler.config.max_shift,
823+
self.scheduler.config.get("base_image_seq_len", 256),
824+
self.scheduler.config.get("max_image_seq_len", 4096),
825+
self.scheduler.config.get("base_shift", 0.5),
826+
self.scheduler.config.get("max_shift", 1.16),
827827
)
828828
timesteps, num_inference_steps = retrieve_timesteps(
829829
self.scheduler,
@@ -990,10 +990,10 @@ def invert(
990990
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
991991
mu = calculate_shift(
992992
image_seq_len,
993-
self.scheduler.config.base_image_seq_len,
994-
self.scheduler.config.max_image_seq_len,
995-
self.scheduler.config.base_shift,
996-
self.scheduler.config.max_shift,
993+
self.scheduler.config.get("base_image_seq_len", 256),
994+
self.scheduler.config.get("max_image_seq_len", 4096),
995+
self.scheduler.config.get("base_shift", 0.5),
996+
self.scheduler.config.get("max_shift", 1.16),
997997
)
998998
timesteps, num_inversion_steps = retrieve_timesteps(
999999
self.scheduler,

examples/community/pipeline_flux_with_cfg.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
"""
6565

6666

67+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
6768
def calculate_shift(
6869
image_seq_len,
6970
base_seq_len: int = 256,
@@ -755,10 +756,10 @@ def __call__(
755756
image_seq_len = latents.shape[1]
756757
mu = calculate_shift(
757758
image_seq_len,
758-
self.scheduler.config.base_image_seq_len,
759-
self.scheduler.config.max_image_seq_len,
760-
self.scheduler.config.base_shift,
761-
self.scheduler.config.max_shift,
759+
self.scheduler.config.get("base_image_seq_len", 256),
760+
self.scheduler.config.get("max_image_seq_len", 4096),
761+
self.scheduler.config.get("base_shift", 0.5),
762+
self.scheduler.config.get("max_shift", 1.16),
762763
)
763764
timesteps, num_inference_steps = retrieve_timesteps(
764765
self.scheduler,

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -822,10 +822,10 @@ def __call__(
822822
image_seq_len = latents.shape[1]
823823
mu = calculate_shift(
824824
image_seq_len,
825-
self.scheduler.config.base_image_seq_len,
826-
self.scheduler.config.max_image_seq_len,
827-
self.scheduler.config.base_shift,
828-
self.scheduler.config.max_shift,
825+
self.scheduler.config.get("base_image_seq_len", 256),
826+
self.scheduler.config.get("max_image_seq_len", 4096),
827+
self.scheduler.config.get("base_shift", 0.5),
828+
self.scheduler.config.get("max_shift", 1.16),
829829
)
830830
timesteps, num_inference_steps = retrieve_timesteps(
831831
self.scheduler,

src/diffusers/pipelines/flux/pipeline_flux_control.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
"""
8383

8484

85+
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
8586
def calculate_shift(
8687
image_seq_len,
8788
base_seq_len: int = 256,
@@ -798,10 +799,10 @@ def __call__(
798799
image_seq_len = latents.shape[1]
799800
mu = calculate_shift(
800801
image_seq_len,
801-
self.scheduler.config.base_image_seq_len,
802-
self.scheduler.config.max_image_seq_len,
803-
self.scheduler.config.base_shift,
804-
self.scheduler.config.max_shift,
802+
self.scheduler.config.get("base_image_seq_len", 256),
803+
self.scheduler.config.get("max_image_seq_len", 4096),
804+
self.scheduler.config.get("base_shift", 0.5),
805+
self.scheduler.config.get("max_shift", 1.16),
805806
)
806807
timesteps, num_inference_steps = retrieve_timesteps(
807808
self.scheduler,

src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -807,10 +807,10 @@ def __call__(
807807
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
808808
mu = calculate_shift(
809809
image_seq_len,
810-
self.scheduler.config.base_image_seq_len,
811-
self.scheduler.config.max_image_seq_len,
812-
self.scheduler.config.base_shift,
813-
self.scheduler.config.max_shift,
810+
self.scheduler.config.get("base_image_seq_len", 256),
811+
self.scheduler.config.get("max_image_seq_len", 4096),
812+
self.scheduler.config.get("base_shift", 0.5),
813+
self.scheduler.config.get("max_shift", 1.16),
814814
)
815815
timesteps, num_inference_steps = retrieve_timesteps(
816816
self.scheduler,

src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -984,10 +984,10 @@ def __call__(
984984
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
985985
mu = calculate_shift(
986986
image_seq_len,
987-
self.scheduler.config.base_image_seq_len,
988-
self.scheduler.config.max_image_seq_len,
989-
self.scheduler.config.base_shift,
990-
self.scheduler.config.max_shift,
987+
self.scheduler.config.get("base_image_seq_len", 256),
988+
self.scheduler.config.get("max_image_seq_len", 4096),
989+
self.scheduler.config.get("base_shift", 0.5),
990+
self.scheduler.config.get("max_shift", 1.16),
991991
)
992992
timesteps, num_inference_steps = retrieve_timesteps(
993993
self.scheduler,

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -874,10 +874,10 @@ def __call__(
874874
image_seq_len = latents.shape[1]
875875
mu = calculate_shift(
876876
image_seq_len,
877-
self.scheduler.config.base_image_seq_len,
878-
self.scheduler.config.max_image_seq_len,
879-
self.scheduler.config.base_shift,
880-
self.scheduler.config.max_shift,
877+
self.scheduler.config.get("base_image_seq_len", 256),
878+
self.scheduler.config.get("max_image_seq_len", 4096),
879+
self.scheduler.config.get("base_shift", 0.5),
880+
self.scheduler.config.get("max_shift", 1.16),
881881
)
882882
timesteps, num_inference_steps = retrieve_timesteps(
883883
self.scheduler,

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -862,10 +862,10 @@ def __call__(
862862
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
863863
mu = calculate_shift(
864864
image_seq_len,
865-
self.scheduler.config.base_image_seq_len,
866-
self.scheduler.config.max_image_seq_len,
867-
self.scheduler.config.base_shift,
868-
self.scheduler.config.max_shift,
865+
self.scheduler.config.get("base_image_seq_len", 256),
866+
self.scheduler.config.get("max_image_seq_len", 4096),
867+
self.scheduler.config.get("base_shift", 0.5),
868+
self.scheduler.config.get("max_shift", 1.16),
869869
)
870870
timesteps, num_inference_steps = retrieve_timesteps(
871871
self.scheduler,

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,10 +1016,10 @@ def __call__(
10161016
)
10171017
mu = calculate_shift(
10181018
image_seq_len,
1019-
self.scheduler.config.base_image_seq_len,
1020-
self.scheduler.config.max_image_seq_len,
1021-
self.scheduler.config.base_shift,
1022-
self.scheduler.config.max_shift,
1019+
self.scheduler.config.get("base_image_seq_len", 256),
1020+
self.scheduler.config.get("max_image_seq_len", 4096),
1021+
self.scheduler.config.get("base_shift", 0.5),
1022+
self.scheduler.config.get("max_shift", 1.16),
10231023
)
10241024
timesteps, num_inference_steps = retrieve_timesteps(
10251025
self.scheduler,

0 commit comments

Comments
 (0)