74 | 74 | """ |
75 | 75 |
76 | 76 |
| 77 | +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift |
| 78 | +def calculate_shift( |
| 79 | + image_seq_len, |
| 80 | + base_seq_len: int = 256, |
| 81 | + max_seq_len: int = 4096, |
| 82 | + base_shift: float = 0.5, |
| 83 | + max_shift: float = 1.16, |
| 84 | +): |
| 85 | + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) |
| 86 | + b = base_shift - m * base_seq_len |
| 87 | + mu = image_seq_len * m + b |
| 88 | + return mu |
| 89 | + |
| 90 | + |
77 | 91 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents |
78 | 92 | def retrieve_latents( |
79 | 93 | encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" |
@@ -838,6 +852,7 @@ def __call__( |
838 | 852 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, |
839 | 853 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], |
840 | 854 | max_sequence_length: int = 256, |
| 855 | + mu: Optional[float] = None, |
841 | 856 | ): |
842 | 857 | r""" |
843 | 858 | Function invoked when calling the pipeline for generation. |
@@ -947,6 +962,7 @@ def __call__( |
947 | 962 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the |
948 | 963 | `._callback_tensor_inputs` attribute of your pipeline class. |
949 | 964 | max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. |
| 965 | + mu (`float`, *optional*): `mu` value used for `dynamic_shifting`; computed from the image sequence length when not provided. |
950 | 966 |
951 | 967 | Examples: |
952 | 968 |
@@ -1023,7 +1039,24 @@ def __call__( |
1023 | 1039 | pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) |
1024 | 1040 |
1025 | 1041 | # 3. Prepare timesteps |
1026 | | - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) |
| 1042 | + scheduler_kwargs = {} |
| 1043 | + if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: |
| 1044 | + image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * ( |
| 1045 | + int(width) // self.vae_scale_factor // self.transformer.config.patch_size |
| 1046 | + ) |
| 1047 | + mu = calculate_shift( |
| 1048 | + image_seq_len, |
| 1049 | + self.scheduler.config.base_image_seq_len, |
| 1050 | + self.scheduler.config.max_image_seq_len, |
| 1051 | + self.scheduler.config.base_shift, |
| 1052 | + self.scheduler.config.max_shift, |
| 1053 | + ) |
| 1054 | + scheduler_kwargs["mu"] = mu |
| 1055 | + elif mu is not None: |
| 1056 | + scheduler_kwargs["mu"] = mu |
| 1057 | + timesteps, num_inference_steps = retrieve_timesteps( |
| 1058 | + self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs |
| 1059 | + ) |
1027 | 1060 | timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) |
1028 | 1061 | # check that number of inference steps is not < 1 - as this doesn't make sense |
1029 | 1062 | if num_inference_steps < 1: |
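Taken together, the hunk above computes `mu` only when the scheduler is configured with `use_dynamic_shifting` and the caller did not supply one; an explicitly passed `mu` is always forwarded to `retrieve_timesteps` unchanged. A hedged usage sketch follows; it assumes the patched pipeline is `StableDiffusion3Img2ImgPipeline` (suggested by the copied img2img helpers, not stated in the diff) and a scheduler with `use_dynamic_shifting` enabled, since `mu` has no effect otherwise.

```python
# Usage sketch, not from the PR: pass `mu` explicitly to override the
# computed shift, or omit it to let the pipeline derive it from the
# image sequence length as in the hunk above.
import torch
from PIL import Image
from diffusers import StableDiffusion3Img2ImgPipeline

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
).to("cuda")

init_image = Image.new("RGB", (1024, 1024))  # stand-in for a real input image

out = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    image=init_image,
    strength=0.6,
    mu=0.8,  # optional; drop this to fall back to calculate_shift()
).images[0]
```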