Merged
src/diffusers/pipelines/flux/pipeline_flux.py (3 additions & 1 deletion)
@@ -29,7 +29,7 @@
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, FluxTransformer2DModel
-from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler
 from ...utils import (
     USE_PEFT_BACKEND,
     is_torch_xla_available,
@@ -840,6 +840,8 @@ def __call__(
 
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        if isinstance(self.scheduler, (UniPCMultistepScheduler, DPMSolverMultistepScheduler)):
+            sigmas = None
         image_seq_len = latents.shape[1]
         mu = calculate_shift(
             image_seq_len,
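With the new import and the `sigmas = None` override above, the Flux pipeline can run on either multistep scheduler while still computing the resolution-dependent `mu`. A minimal usage sketch, assuming a diffusers build that includes this change; the model id, step count, and the `use_flow_sigmas` / `prediction_type="flow_prediction"` options come from recent diffusers releases and are illustrative rather than prescribed by this PR:

import torch
from diffusers import FluxPipeline, UniPCMultistepScheduler

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Swap the default FlowMatchEulerDiscreteScheduler for UniPC; the two new config flags
# allow the pipeline to forward `mu` for exponential dynamic shifting.
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config,
    use_flow_sigmas=True,               # assumed: flow-matching sigma schedule
    prediction_type="flow_prediction",  # assumed: flow-matching prediction type
    use_dynamic_shifting=True,
    time_shift_type="exponential",
)
pipe.to("cuda")
image = pipe("a cat holding a sign", num_inference_steps=20).images[0]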
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py (6 additions & 0 deletions)
@@ -230,6 +230,8 @@ def __init__(
         timestep_spacing: str = "linspace",
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -330,6 +332,7 @@ def set_timesteps(
         self,
         num_inference_steps: int = None,
         device: Union[str, torch.device] = None,
+        mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
     ):
         """
@@ -345,6 +348,9 @@
                 based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
                 must be `None`, and `timestep_spacing` attribute will be ignored.
         """
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == 'exponential'
+            self.config.flow_shift = np.exp(mu)
         if num_inference_steps is None and timesteps is None:
             raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
         if num_inference_steps is not None and timesteps is not None:
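Setting `flow_shift = exp(mu)` reuses the scheduler's existing static flow shift to reproduce the exponential dynamic shift of FlowMatchEulerDiscreteScheduler: with s = exp(mu), s * sigma / (1 + (s - 1) * sigma) equals exp(mu) / (exp(mu) + (1 / sigma - 1)). A quick numerical check of that identity, a sketch that assumes the scheduler applies the standard flow-sigma shift when `use_flow_sigmas` is enabled:

import numpy as np

mu = 1.15                              # illustrative resolution-dependent shift
sigmas = np.linspace(1.0, 1 / 28, 28)  # linear sigma grid like the Flux pipeline's default

s = np.exp(mu)
static_shift = s * sigmas / (1 + (s - 1) * sigmas)                 # effect of flow_shift = exp(mu)
exponential_shift = np.exp(mu) / (np.exp(mu) + (1 / sigmas - 1))   # exponential time shift with mu
assert np.allclose(static_shift, exponential_shift)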
src/diffusers/schedulers/scheduling_unipc_multistep.py (6 additions & 1 deletion)
@@ -212,6 +212,8 @@ def __init__(
         steps_offset: int = 0,
         final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
         rescale_betas_zero_snr: bool = False,
+        use_dynamic_shifting: bool = False,
+        time_shift_type: str = "exponential",
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -298,7 +300,7 @@ def set_begin_index(self, begin_index: int = 0):
         """
         self._begin_index = begin_index
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -309,6 +311,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
         # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
+        if mu is not None:
+            assert self.config.use_dynamic_shifting and self.config.time_shift_type == 'exponential'
+            self.config.flow_shift = np.exp(mu)
         if self.config.timestep_spacing == "linspace":
             timesteps = (
                 np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
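In the pipeline, `mu` is forwarded to `scheduler.set_timesteps`, so accepting the keyword here is what lets UniPCMultistepScheduler and DPMSolverMultistepScheduler participate in dynamic shifting. Calling a scheduler directly looks roughly like this, a sketch where the `mu` value and the flow-matching options are illustrative assumptions rather than part of the diff:

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(
    use_flow_sigmas=True,               # assumed: flow-matching sigma schedule
    prediction_type="flow_prediction",  # assumed: flow-matching prediction type
    use_dynamic_shifting=True,
    time_shift_type="exponential",
)
# mu would normally come from calculate_shift(...) in pipeline_flux.py
scheduler.set_timesteps(num_inference_steps=28, device="cpu", mu=0.8)
print(scheduler.timesteps[:5])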