Skip to content

Commit 2185553

Browse files
authored
Merge branch 'main' into guidance-scale-docs
2 parents d146f12 + 5c52097 commit 2185553

11 files changed

+238
-99
lines changed

src/diffusers/configuration_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,4 +763,7 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
763763
# resolve remapping
764764
remapped_class = _fetch_remapped_cls_from_config(config, cls)
765765

766-
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
766+
if remapped_class is cls:
767+
return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
768+
else:
769+
return remapped_class.from_config(config, return_unused_kwargs, **kwargs)

src/diffusers/models/modeling_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1877,4 +1877,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
18771877
# resolve remapping
18781878
remapped_class = _fetch_remapped_cls_from_config(config, cls)
18791879

1880-
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
1880+
if remapped_class is cls:
1881+
return super(LegacyModelMixin, remapped_class).from_pretrained(
1882+
pretrained_model_name_or_path, **kwargs_copy
1883+
)
1884+
else:
1885+
return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py

Lines changed: 194 additions & 92 deletions
Large diffs are not rendered by default.

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,8 @@ def __call__(
844844

845845
# 5. Prepare timesteps
846846
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
847+
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
848+
sigmas = None
847849
image_seq_len = latents.shape[1]
848850
mu = calculate_shift(
849851
image_seq_len,

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,8 @@ def __call__(
383383
# set timesteps
384384
self.scheduler.set_timesteps(num_inference_steps)
385385

386-
latents = latents * np.float64(self.scheduler.init_noise_sigma)
386+
# scale the initial noise by the standard deviation required by the scheduler
387+
latents = latents * self.scheduler.init_noise_sigma
387388

388389
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
389390
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def __call__(
483483
self.scheduler.set_timesteps(num_inference_steps)
484484

485485
# scale the initial noise by the standard deviation required by the scheduler
486-
latents = latents * np.float64(self.scheduler.init_noise_sigma)
486+
latents = latents * self.scheduler.init_noise_sigma
487487

488488
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
489489
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def __call__(
481481
timesteps = self.scheduler.timesteps
482482

483483
# Scale the initial noise by the standard deviation required by the scheduler
484-
latents = latents * np.float64(self.scheduler.init_noise_sigma)
484+
latents = latents * self.scheduler.init_noise_sigma
485485

486486
# 5. Add noise to image
487487
noise_level = np.array([noise_level]).astype(np.int64)

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ def __init__(
153153
flow_shift: Optional[float] = 1.0,
154154
timestep_spacing: str = "linspace",
155155
steps_offset: int = 0,
156+
use_dynamic_shifting: bool = False,
157+
time_shift_type: str = "exponential",
156158
):
157159
if self.config.use_beta_sigmas and not is_scipy_available():
158160
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -232,7 +234,9 @@ def set_begin_index(self, begin_index: int = 0):
232234
"""
233235
self._begin_index = begin_index
234236

235-
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
237+
def set_timesteps(
238+
self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
239+
):
236240
"""
237241
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
238242
@@ -242,6 +246,9 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
242246
device (`str` or `torch.device`, *optional*):
243247
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
244248
"""
249+
if mu is not None:
250+
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
251+
self.config.flow_shift = np.exp(mu)
245252
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
246253
if self.config.timestep_spacing == "linspace":
247254
timesteps = (

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,8 @@ def __init__(
230230
timestep_spacing: str = "linspace",
231231
steps_offset: int = 0,
232232
rescale_betas_zero_snr: bool = False,
233+
use_dynamic_shifting: bool = False,
234+
time_shift_type: str = "exponential",
233235
):
234236
if self.config.use_beta_sigmas and not is_scipy_available():
235237
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -330,6 +332,7 @@ def set_timesteps(
330332
self,
331333
num_inference_steps: int = None,
332334
device: Union[str, torch.device] = None,
335+
mu: Optional[float] = None,
333336
timesteps: Optional[List[int]] = None,
334337
):
335338
"""
@@ -345,6 +348,9 @@ def set_timesteps(
345348
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
346349
must be `None`, and `timestep_spacing` attribute will be ignored.
347350
"""
351+
if mu is not None:
352+
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
353+
self.config.flow_shift = np.exp(mu)
348354
if num_inference_steps is None and timesteps is None:
349355
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
350356
if num_inference_steps is not None and timesteps is not None:

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ def __init__(
169169
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
170170
lambda_min_clipped: float = -float("inf"),
171171
variance_type: Optional[str] = None,
172+
use_dynamic_shifting: bool = False,
173+
time_shift_type: str = "exponential",
172174
):
173175
if self.config.use_beta_sigmas and not is_scipy_available():
174176
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -301,6 +303,7 @@ def set_timesteps(
301303
self,
302304
num_inference_steps: int = None,
303305
device: Union[str, torch.device] = None,
306+
mu: Optional[float] = None,
304307
timesteps: Optional[List[int]] = None,
305308
):
306309
"""
@@ -316,6 +319,9 @@ def set_timesteps(
316319
timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
317320
passed, `num_inference_steps` must be `None`.
318321
"""
322+
if mu is not None:
323+
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
324+
self.config.flow_shift = np.exp(mu)
319325
if num_inference_steps is None and timesteps is None:
320326
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
321327
if num_inference_steps is not None and timesteps is not None:

0 commit comments

Comments (0)