|
30 | 30 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler |
31 | 31 | from diffusers.utils import ( |
32 | 32 | USE_PEFT_BACKEND, |
| 33 | + deprecate, |
33 | 34 | is_torch_xla_available, |
34 | 35 | logging, |
35 | 36 | replace_example_docstring, |
|
91 | 92 | # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift |
92 | 93 | def calculate_shift( |
93 | 94 | image_seq_len, |
94 | | - base_seq_len: int = 256, |
95 | | - max_seq_len: int = 4096, |
96 | | - base_shift: float = 0.5, |
97 | | - max_shift: float = 1.16, |
| 95 | + base_seq_len: Optional[int] = 256, |
| 96 | + max_seq_len: Optional[int] = 4096, |
| 97 | + base_shift: Optional[float] = 0.5, |
| 98 | + max_shift: Optional[float] = 1.16, |
| 99 | + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, |
98 | 100 | ): |
| 101 | + if base_seq_len or max_seq_len or base_shift or max_shift or scheduler is None: |
| 102 | + deprecation_message = "Pass `scheduler` to `calculate_shift`." |
| 103 | + deprecate( |
| 104 | + "calculate_shift scheduler", |
| 105 | + "1.0.0", |
| 106 | + deprecation_message, |
| 107 | + standard_warn=False, |
| 108 | + ) |
| 109 | + base_seq_len = base_seq_len or scheduler.config.get("base_image_seq_len", 256) |
| 110 | + max_seq_len = max_seq_len or scheduler.config.get("max_image_seq_len", 4096) |
| 111 | + base_shift = base_shift or scheduler.config.get("base_shift", 0.5) |
| 112 | + max_shift = max_shift or scheduler.config.get("max_shift", 1.16) |
99 | 113 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) |
100 | 114 | b = base_shift - m * base_seq_len |
101 | 115 | mu = image_seq_len * m + b |
@@ -822,10 +836,7 @@ def __call__( |
822 | 836 | image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) |
823 | 837 | mu = calculate_shift( |
824 | 838 | image_seq_len, |
825 | | - self.scheduler.config.base_image_seq_len, |
826 | | - self.scheduler.config.max_image_seq_len, |
827 | | - self.scheduler.config.base_shift, |
828 | | - self.scheduler.config.max_shift, |
| 839 | + scheduler=self.scheduler, |
829 | 840 | ) |
830 | 841 | timesteps, num_inference_steps = retrieve_timesteps( |
831 | 842 | self.scheduler, |
@@ -992,10 +1003,7 @@ def invert( |
992 | 1003 | image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) |
993 | 1004 | mu = calculate_shift( |
994 | 1005 | image_seq_len, |
995 | | - self.scheduler.config.base_image_seq_len, |
996 | | - self.scheduler.config.max_image_seq_len, |
997 | | - self.scheduler.config.base_shift, |
998 | | - self.scheduler.config.max_shift, |
| 1006 | + scheduler=self.scheduler, |
999 | 1007 | ) |
1000 | 1008 | timesteps, num_inversion_steps = retrieve_timesteps( |
1001 | 1009 | self.scheduler, |
|
0 commit comments