|
34 | 34 | from ...utils import ( |
35 | 35 | USE_PEFT_BACKEND, |
36 | 36 | is_torch_xla_available, |
| 37 | + deprecate, |
37 | 38 | logging, |
38 | 39 | replace_example_docstring, |
39 | 40 | scale_lora_layers, |
|
73 | 74 |
|
74 | 75 | def calculate_shift( |
75 | 76 | image_seq_len, |
76 | | - scheduler, |
| 77 | + base_seq_len: Optional[int] = 256, |
| 78 | + max_seq_len: Optional[int] = 4096, |
| 79 | + base_shift: Optional[float] = 0.5, |
| 80 | + max_shift: Optional[float] = 1.16, |
| 81 | + scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None, |
77 | 82 | ): |
78 | | - base_seq_len = scheduler.config.get("base_image_seq_len", 256) |
79 | | - max_seq_len = scheduler.config.get("max_image_seq_len", 4096) |
80 | | - base_shift = scheduler.config.get("base_shift", 0.5) |
81 | | - max_shift = scheduler.config.get("max_shift", 1.16) |
| 83 | + if base_seq_len or max_seq_len or base_shift or max_shift or scheduler is None: |
| 84 | + deprecation_message = "Pass `scheduler` to `calculate_shift`." |
| 85 | + deprecate( |
| 86 | + "calculate_shift scheduler", |
| 87 | + "1.0.0", |
| 88 | + deprecation_message, |
| 89 | + standard_warn=False, |
| 90 | + ) |
| 91 | + base_seq_len = base_seq_len or scheduler.config.get("base_image_seq_len", 256) |
| 92 | + max_seq_len = max_seq_len or scheduler.config.get("max_image_seq_len", 4096) |
| 93 | + base_shift = base_shift or scheduler.config.get("base_shift", 0.5) |
| 94 | + max_shift = max_shift or scheduler.config.get("max_shift", 1.16) |
82 | 95 | m = (max_shift - base_shift) / (max_seq_len - base_seq_len) |
83 | 96 | b = base_shift - m * base_seq_len |
84 | 97 | mu = image_seq_len * m + b |
|
0 commit comments