
Commit 3706aa3

dg845 and sayakpaul authored
Add rescale_betas_zero_snr Argument to DDPMScheduler (#6305)
* Add rescale_betas_zero_snr argument to DDPMScheduler.
* Propagate rescale_betas_zero_snr changes to DDPMParallelScheduler.

Co-authored-by: Sayak Paul <[email protected]>
1 parent d4f10ea commit 3706aa3

File tree

src/diffusers/schedulers/scheduling_ddpm.py
src/diffusers/schedulers/scheduling_ddpm_parallel.py
tests/schedulers/test_scheduler_ddpm.py
tests/schedulers/test_scheduler_ddpm_parallel.py

4 files changed: +100 -0 lines changed

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 46 additions & 0 deletions

@@ -89,6 +89,43 @@ def alpha_bar_fn(t):
     return torch.tensor(betas, dtype=torch.float32)


+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
 class DDPMScheduler(SchedulerMixin, ConfigMixin):
     """
     `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.

@@ -131,6 +168,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             An offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
             Diffusion.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]

@@ -153,6 +194,7 @@ def __init__(
         sample_max_value: float = 1.0,
         timestep_spacing: str = "leading",
         steps_offset: int = 0,
+        rescale_betas_zero_snr: int = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -171,6 +213,10 @@ def __init__(
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

+        # Rescale for zero SNR
+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
         self.one = torch.tensor(1.0)
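
For context, a minimal sketch (not part of the commit) of what the copied helper guarantees: it shifts and rescales the square roots of the cumulative alpha products so the first value is preserved and the last becomes exactly zero, which is what "zero terminal SNR" means. It assumes torch is installed and a diffusers build that contains this change, so the function is importable from the module above.

    # Sketch only, not part of the diff: demonstrates rescale_zero_terminal_snr
    # on the default linear schedule (beta_start=0.0001, beta_end=0.02, 1000 steps).
    import torch
    from diffusers.schedulers.scheduling_ddpm import rescale_zero_terminal_snr

    betas = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32)
    rescaled = rescale_zero_terminal_snr(betas)

    abar_old = torch.cumprod(1.0 - betas, dim=0)
    abar_new = torch.cumprod(1.0 - rescaled, dim=0)

    print(abar_old[0].item(), abar_new[0].item())    # first alpha_bar preserved (up to float rounding)
    print(abar_old[-1].item(), abar_new[-1].item())  # terminal alpha_bar goes to exactly 0.0

Because the terminal alpha_bar is exactly zero, the signal-to-noise ratio alpha_bar / (1 - alpha_bar) at the last timestep is exactly zero as well, matching Algorithm 1 of the cited paper.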

src/diffusers/schedulers/scheduling_ddpm_parallel.py

Lines changed: 46 additions & 0 deletions

@@ -91,6 +91,43 @@ def alpha_bar_fn(t):
     return torch.tensor(betas, dtype=torch.float32)


+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+    """
+    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+    Args:
+        betas (`torch.FloatTensor`):
+            the betas that the scheduler is being initialized with.
+
+    Returns:
+        `torch.FloatTensor`: rescaled betas with zero terminal SNR
+    """
+    # Convert betas to alphas_bar_sqrt
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
+    alphas = torch.cat([alphas_bar[0:1], alphas])
+    betas = 1 - alphas
+
+    return betas
+
+
 class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
     """
     Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and

@@ -139,6 +176,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
             an offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]

@@ -163,6 +204,7 @@ def __init__(
         sample_max_value: float = 1.0,
         timestep_spacing: str = "leading",
         steps_offset: int = 0,
+        rescale_betas_zero_snr: int = False,
     ):
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)

@@ -181,6 +223,10 @@ def __init__(
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

+        # Rescale for zero SNR
+        if rescale_betas_zero_snr:
+            self.betas = rescale_zero_terminal_snr(self.betas)
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
         self.one = torch.tensor(1.0)
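
Usage-wise, the new argument is just a constructor flag on both schedulers, and as the diff shows, the rescale is applied after the betas are constructed, so it composes with whatever beta_schedule (or trained_betas) the scheduler is configured with. A short sketch, again not part of the commit and assuming a diffusers version that includes it:

    # Sketch only: the flag leaves default behavior untouched and, when enabled,
    # zeroes the terminal SNR = alpha_bar_T / (1 - alpha_bar_T).
    from diffusers import DDPMScheduler, DDPMParallelScheduler

    def terminal_snr(s):
        abar_T = s.alphas_cumprod[-1]
        return (abar_T / (1 - abar_T)).item()

    for cls in (DDPMScheduler, DDPMParallelScheduler):
        default = cls()
        rescaled = cls(rescale_betas_zero_snr=True)
        print(cls.__name__, terminal_snr(default), terminal_snr(rescaled))
        # roughly 4e-5 with the default linear schedule vs. exactly 0.0 rescaled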

tests/schedulers/test_scheduler_ddpm.py

Lines changed: 4 additions & 0 deletions

@@ -68,6 +68,10 @@ def test_variance(self):
         assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
         assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5

+    def test_rescale_betas_zero_snr(self):
+        for rescale_betas_zero_snr in [True, False]:
+            self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
+
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()

tests/schedulers/test_scheduler_ddpm_parallel.py

Lines changed: 4 additions & 0 deletions

@@ -82,6 +82,10 @@ def test_variance(self):
         assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.00979)) < 1e-5
         assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.02)) < 1e-5

+    def test_rescale_betas_zero_snr(self):
+        for rescale_betas_zero_snr in [True, False]:
+            self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
+
     def test_batch_step_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
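
Both new tests reuse each suite's existing check_over_configs helper, exercising the flag in both states. To run only these tests locally, something like the following should work from the repository root, assuming the project's standard pytest setup:

    # Sketch only: invoke pytest programmatically on the two touched suites,
    # filtering for the new test name.
    import pytest

    pytest.main([
        "tests/schedulers/test_scheduler_ddpm.py",
        "tests/schedulers/test_scheduler_ddpm_parallel.py",
        "-k", "rescale_betas_zero_snr",
    ])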
