39 changes: 26 additions & 13 deletions src/diffusers/schedulers/scheduling_ddpm.py
@@ -172,6 +172,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
scale_betas_for_timesteps (`bool`, defaults to `False`):
Whether to scale `beta_end` by `1000 / num_train_timesteps`. The original DDPM paper's betas are tuned for
`num_train_timesteps=1000`, so other values may require this scaling to keep a comparable noise schedule.
"""

_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -195,23 +199,32 @@ def __init__(
timestep_spacing: str = "leading",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
scale_betas_for_timesteps: bool = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
elif beta_schedule == "sigmoid":
# GeoDiff sigmoid schedule
betas = torch.linspace(-6, 6, num_train_timesteps)
self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
if scale_betas_for_timesteps and num_train_timesteps != 1000:
# Scale beta_end so schedules with T != 1000 stay comparable to the 1000-step default
scale_factor = 1000 / num_train_timesteps
beta_end = beta_end * scale_factor

if beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
elif beta_schedule == "sigmoid":
# GeoDiff sigmoid schedule
betas = torch.linspace(-6, 6, num_train_timesteps)
self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

# Rescale for zero SNR
if rescale_betas_zero_snr:
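A minimal usage sketch of the new flag, assuming the constructor signature introduced in this diff; the printed values mirror the assertions in the new test below rather than output captured from running this PR:

from diffusers import DDPMScheduler

# Default 1000-step schedule: the flag has no effect, beta_end stays at 0.02.
scheduler = DDPMScheduler(num_train_timesteps=1000, scale_betas_for_timesteps=True)
print(scheduler.betas[-1].item())  # 0.02

# 2000 steps with scaling: beta_end is multiplied by 1000 / 2000 = 0.5, giving 0.01.
scheduler = DDPMScheduler(num_train_timesteps=2000, scale_betas_for_timesteps=True)
print(scheduler.betas[-1].item())  # ~0.01

# 2000 steps without scaling: the stock beta_end of 0.02 is kept.
scheduler = DDPMScheduler(num_train_timesteps=2000, scale_betas_for_timesteps=False)
print(scheduler.betas[-1].item())  # 0.02

These three cases correspond one-to-one with the checks added in tests/schedulers/test_scheduler_ddpm.py below.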
23 changes: 23 additions & 0 deletions tests/schedulers/test_scheduler_ddpm.py
@@ -72,6 +72,29 @@ def test_rescale_betas_zero_snr(self):
for rescale_betas_zero_snr in [True, False]:
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)

def test_scale_betas_for_timesteps(self):
scheduler_class = self.scheduler_classes[0]

# 1. Test that betas are scaled when the flag is True and T != 1000.
config = self.get_scheduler_config(num_train_timesteps=2000, scale_betas_for_timesteps=True)
scheduler = scheduler_class(**config)
last_beta = scheduler.betas[-1].item()
# The original beta_end is 0.02. Scaled should be 0.02 * (1000/2000) = 0.01
self.assertAlmostEqual(last_beta, 0.01)

# 2. Test that betas are NOT scaled when the flag is False.
config = self.get_scheduler_config(num_train_timesteps=2000, scale_betas_for_timesteps=False)
scheduler = scheduler_class(**config)
last_beta = scheduler.betas[-1].item()
# Should be the original, unscaled value
self.assertAlmostEqual(last_beta, 0.02)

# 3. Test that betas are NOT scaled when T=1000, even if the flag is True.
config = self.get_scheduler_config(num_train_timesteps=1000, scale_betas_for_timesteps=True)
scheduler = scheduler_class(**config)
last_beta = scheduler.betas[-1].item()
self.assertAlmostEqual(last_beta, 0.02)

def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()