|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import inspect |
15 | 16 | import math |
16 | 17 | from typing import List, Optional, Union |
17 | 18 |
|
|
20 | 21 | from ...utils import deprecate |
21 | 22 |
|
22 | 23 |
|
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map an image token count to a timestep-schedule shift ("mu").

    The mapping is the straight line through (`base_seq_len`, `base_shift`) and
    (`max_seq_len`, `max_shift`); sequence lengths outside that range are
    extrapolated along the same line.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
| 36 | + |
| 37 | + |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # The two custom-schedule knobs are mutually exclusive.
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _scheduler_accepts(keyword):
        # Not every scheduler's `set_timesteps` supports custom schedules; probe its signature.
        return keyword in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if not _scheduler_accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        # The custom schedule's length overrides the requested step count.
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        if not _scheduler_accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
| 96 | + |
| 97 | + |
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from a VAE encoder output.

    Supports both outputs that expose a `latent_dist` (sampled stochastically with
    `generator` for `sample_mode="sample"`, or deterministically via the mode for
    `sample_mode="argmax"`) and outputs that carry precomputed `latents` directly.
    Raises `AttributeError` if neither representation is present.
    """
    has_dist = hasattr(encoder_output, "latent_dist")
    if has_dist and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_dist and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
| 110 | + |
| 111 | + |
def calculate_dimensions(target_area, ratio, multiple=32):
    """Compute a width/height pair covering roughly `target_area` pixels at aspect `ratio`.

    The ideal width is `sqrt(target_area * ratio)` and the ideal height is
    `width / ratio`; both are then snapped to the nearest multiple of `multiple`
    (default 32, matching common VAE/patch-size constraints), so the resulting
    area and aspect ratio are approximate.

    Args:
        target_area (`float`): Desired pixel area (width * height).
        ratio (`float`): Desired aspect ratio, width / height. Must be non-zero.
        multiple (`int`, *optional*, defaults to 32): Granularity to snap each
            dimension to. The default preserves the previous hard-coded behavior.

    Returns:
        `Tuple[int, int, None]`: `(width, height, None)`. The trailing `None` is
        kept for backward compatibility with existing 3-tuple unpacking callers.
    """
    # Solve width * (width / ratio) == target_area for width.
    width = math.sqrt(target_area * ratio)
    # Height is derived from the *unsnapped* width so the aspect ratio error
    # from snapping width does not compound into height.
    height = width / ratio

    width = round(width / multiple) * multiple
    height = round(height / multiple) * multiple

    return width, height, None
| 120 | + |
| 121 | + |
23 | 122 | class QwenImageMixin: |
24 | 123 | @property |
25 | 124 | def guidance_scale(self): |
@@ -340,13 +439,3 @@ def _get_qwen_prompt_embeds( |
340 | 439 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) |
341 | 440 |
|
342 | 441 | return prompt_embeds, encoder_attention_mask |
343 | | - |
344 | | - |
345 | | -def calculate_dimensions(target_area, ratio): |
346 | | - width = math.sqrt(target_area * ratio) |
347 | | - height = width / ratio |
348 | | - |
349 | | - width = round(width / 32) * 32 |
350 | | - height = round(height / 32) * 32 |
351 | | - |
352 | | - return width, height, None |
|
0 commit comments