diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 0fab6d910a82..753df90e5aa3 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -172,6 +172,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
             dark samples instead of limiting it to samples with medium brightness. Loosely related to
             [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+        scale_betas_for_timesteps (`bool`, defaults to `False`):
+            Whether to scale the `beta_end` parameter based on the number of training timesteps. The original DDPM
+            paper's parameters are tuned for `num_train_timesteps=1000`, so scaling may be required for other values
+            to maintain a similar noise schedule.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -195,23 +199,32 @@ def __init__(
         timestep_spacing: str = "leading",
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
+        scale_betas_for_timesteps: bool = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
-        elif beta_schedule == "linear":
-            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
-        elif beta_schedule == "scaled_linear":
-            # this schedule is very specific to the latent diffusion model.
-            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-        elif beta_schedule == "squaredcos_cap_v2":
-            # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(num_train_timesteps)
-        elif beta_schedule == "sigmoid":
-            # GeoDiff sigmoid schedule
-            betas = torch.linspace(-6, 6, num_train_timesteps)
-            self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
         else:
-            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+            if scale_betas_for_timesteps and num_train_timesteps != 1000:
+                # scale beta_end so schedules with num_train_timesteps != 1000 keep a comparable overall noise level
+                scale_factor = 1000 / num_train_timesteps
+                beta_end = beta_end * scale_factor
+
+            if beta_schedule == "linear":
+                self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+            elif beta_schedule == "scaled_linear":
+                # this schedule is very specific to the latent diffusion model.
+                self.betas = (
+                    torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+                )
+            elif beta_schedule == "squaredcos_cap_v2":
+                # Glide cosine schedule
+                self.betas = betas_for_alpha_bar(num_train_timesteps)
+            elif beta_schedule == "sigmoid":
+                # GeoDiff sigmoid schedule
+                betas = torch.linspace(-6, 6, num_train_timesteps)
+                self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+            else:
+                raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

         # Rescale for zero SNR
         if rescale_betas_zero_snr:
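Not part of the patch, but a quick usage sketch of the new flag against the public `DDPMScheduler` constructor, assuming this change is installed. With half the default step count, `beta_end` is scaled by `1000 / 500 = 2`:

    from diffusers import DDPMScheduler

    # Default linear schedule: beta_start=0.0001, beta_end=0.02, tuned for 1000 steps.
    baseline = DDPMScheduler(num_train_timesteps=1000)

    # Halving the step count doubles beta_end (0.02 * 1000 / 500 = 0.04), so each of the
    # fewer steps adds more noise and the schedule reaches a similar terminal noise level.
    scaled = DDPMScheduler(num_train_timesteps=500, scale_betas_for_timesteps=True)

    print(baseline.betas[-1].item())  # 0.02
    print(scaled.betas[-1].item())    # 0.04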
diff --git a/tests/schedulers/test_scheduler_ddpm.py b/tests/schedulers/test_scheduler_ddpm.py
index 056b5d83350e..12a8bb091d44 100644
--- a/tests/schedulers/test_scheduler_ddpm.py
+++ b/tests/schedulers/test_scheduler_ddpm.py
@@ -72,6 +72,29 @@ def test_rescale_betas_zero_snr(self):
         for rescale_betas_zero_snr in [True, False]:
             self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)

+    def test_scale_betas_for_timesteps(self):
+        scheduler_class = self.scheduler_classes[0]
+
+        # 1. Test that betas are scaled when the flag is True and T != 1000.
+        config = self.get_scheduler_config(num_train_timesteps=2000, scale_betas_for_timesteps=True)
+        scheduler = scheduler_class(**config)
+        last_beta = scheduler.betas[-1].item()
+        # The original beta_end is 0.02; scaled, it should be 0.02 * (1000 / 2000) = 0.01.
+        self.assertAlmostEqual(last_beta, 0.01)
+
+        # 2. Test that betas are NOT scaled when the flag is False.
+        config = self.get_scheduler_config(num_train_timesteps=2000, scale_betas_for_timesteps=False)
+        scheduler = scheduler_class(**config)
+        last_beta = scheduler.betas[-1].item()
+        # Should be the original, unscaled value.
+        self.assertAlmostEqual(last_beta, 0.02)
+
+        # 3. Test that betas are NOT scaled when T=1000, even if the flag is True.
+        config = self.get_scheduler_config(num_train_timesteps=1000, scale_betas_for_timesteps=True)
+        scheduler = scheduler_class(**config)
+        last_beta = scheduler.betas[-1].item()
+        self.assertAlmostEqual(last_beta, 0.02)
+
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
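As a rough sanity check on the "maintain a similar noise schedule" claim (again a sketch, assuming the patch is applied): for the linear schedule the sum of the betas is approximately preserved by the scaling, so the terminal `alphas_cumprod` of a scaled 2000-step schedule should land near the 1000-step default, while an unscaled 2000-step schedule over-noises:

    from diffusers import DDPMScheduler

    ref = DDPMScheduler(num_train_timesteps=1000)
    scaled = DDPMScheduler(num_train_timesteps=2000, scale_betas_for_timesteps=True)
    unscaled = DDPMScheduler(num_train_timesteps=2000)

    # Terminal cumulative signal level alpha_bar_T; smaller means noisier at t = T.
    print(ref.alphas_cumprod[-1].item())       # ~4e-5
    print(scaled.alphas_cumprod[-1].item())    # close to the reference
    print(unscaled.alphas_cumprod[-1].item())  # orders of magnitude smaller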