|  | 
| 30 | 30 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler | 
| 31 | 31 | from diffusers.utils import ( | 
| 32 | 32 |     USE_PEFT_BACKEND, | 
| 33 |  | -    deprecate, | 
| 34 | 33 |     is_torch_xla_available, | 
| 35 | 34 |     logging, | 
| 36 | 35 |     replace_example_docstring, | 
|  | 
| 92 | 91 | # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift | 
| 93 | 92 | def calculate_shift( | 
| 94 | 93 |     image_seq_len, | 
| 95 |  | -    base_seq_len: Optional[int] = 256, | 
| 96 |  | -    max_seq_len: Optional[int] = 4096, | 
| 97 |  | -    base_shift: Optional[float] = 0.5, | 
| 98 |  | -    max_shift: Optional[float] = 1.16, | 
| 99 |  | -    scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, | 
|  | 94 | +    base_seq_len: int = 256, | 
|  | 95 | +    max_seq_len: int = 4096, | 
|  | 96 | +    base_shift: float = 0.5, | 
|  | 97 | +    max_shift: float = 1.16, | 
| 100 | 98 | ): | 
| 101 |  | -    if base_seq_len or max_seq_len or base_shift or max_shift or scheduler is None: | 
| 102 |  | -        deprecation_message = "Pass `scheduler` to `calculate_shift`." | 
| 103 |  | -        deprecate( | 
| 104 |  | -            "calculate_shift scheduler", | 
| 105 |  | -            "1.0.0", | 
| 106 |  | -            deprecation_message, | 
| 107 |  | -            standard_warn=False, | 
| 108 |  | -        ) | 
| 109 |  | -    base_seq_len = base_seq_len or scheduler.config.get("base_image_seq_len", 256) | 
| 110 |  | -    max_seq_len = max_seq_len or scheduler.config.get("max_image_seq_len", 4096) | 
| 111 |  | -    base_shift = base_shift or scheduler.config.get("base_shift", 0.5) | 
| 112 |  | -    max_shift = max_shift or scheduler.config.get("max_shift", 1.16) | 
| 113 | 99 |     m = (max_shift - base_shift) / (max_seq_len - base_seq_len) | 
| 114 | 100 |     b = base_shift - m * base_seq_len | 
| 115 | 101 |     mu = image_seq_len * m + b | 
| @@ -836,7 +822,10 @@ def __call__( | 
| 836 | 822 |         image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) | 
| 837 | 823 |         mu = calculate_shift( | 
| 838 | 824 |             image_seq_len, | 
| 839 |  | -            scheduler=self.scheduler, | 
|  | 825 | +            self.scheduler.config.get("base_image_seq_len", 256), | 
|  | 826 | +            self.scheduler.config.get("max_image_seq_len", 4096), | 
|  | 827 | +            self.scheduler.config.get("base_shift", 0.5), | 
|  | 828 | +            self.scheduler.config.get("max_shift", 1.16), | 
| 840 | 829 |         ) | 
| 841 | 830 |         timesteps, num_inference_steps = retrieve_timesteps( | 
| 842 | 831 |             self.scheduler, | 
| @@ -1003,7 +992,10 @@ def invert( | 
| 1003 | 992 |         image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) | 
| 1004 | 993 |         mu = calculate_shift( | 
| 1005 | 994 |             image_seq_len, | 
| 1006 |  | -            scheduler=self.scheduler, | 
|  | 995 | +            self.scheduler.config.get("base_image_seq_len", 256), | 
|  | 996 | +            self.scheduler.config.get("max_image_seq_len", 4096), | 
|  | 997 | +            self.scheduler.config.get("base_shift", 0.5), | 
|  | 998 | +            self.scheduler.config.get("max_shift", 1.16), | 
| 1007 | 999 |         ) | 
| 1008 | 1000 |         timesteps, num_inversion_steps = retrieve_timesteps( | 
| 1009 | 1001 |             self.scheduler, | 
|  | 
0 commit comments