Skip to content

Commit aa19025

Browse files
authored
UniPC Multistep add rescale_betas_zero_snr (#7531)
* UniPC Multistep add `rescale_betas_zero_snr` Same patch as DPM and Euler with the patched final alpha cumprod BF16 doesn't seem to break down, I think cause UniPC upcasts during some phases already? We could still force an upcast since it only loses ≈ 0.005 it/s for me but the difference in output is very small. A better endeavor might upcasting in step() and removing all the other upcasts elsewhere? * UniPC ZSNR UT * Re-add `rescale_betas_zsnr` doc oops
1 parent 19ab04f commit aa19025

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,43 @@ def alpha_bar_fn(t):
7171
return torch.tensor(betas, dtype=torch.float32)
7272

7373

74+
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
75+
def rescale_zero_terminal_snr(betas):
76+
"""
77+
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
78+
79+
80+
Args:
81+
betas (`torch.FloatTensor`):
82+
the betas that the scheduler is being initialized with.
83+
84+
Returns:
85+
`torch.FloatTensor`: rescaled betas with zero terminal SNR
86+
"""
87+
# Convert betas to alphas_bar_sqrt
88+
alphas = 1.0 - betas
89+
alphas_cumprod = torch.cumprod(alphas, dim=0)
90+
alphas_bar_sqrt = alphas_cumprod.sqrt()
91+
92+
# Store old values.
93+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
94+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
95+
96+
# Shift so the last timestep is zero.
97+
alphas_bar_sqrt -= alphas_bar_sqrt_T
98+
99+
# Scale so the first timestep is back to the old value.
100+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
101+
102+
# Convert alphas_bar_sqrt to betas
103+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
104+
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
105+
alphas = torch.cat([alphas_bar[0:1], alphas])
106+
betas = 1 - alphas
107+
108+
return betas
109+
110+
74111
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
75112
"""
76113
`UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
@@ -130,6 +167,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
130167
final_sigmas_type (`str`, defaults to `"zero"`):
131168
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
132169
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
170+
rescale_betas_zero_snr (`bool`, defaults to `False`):
171+
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
172+
dark samples instead of limiting it to samples with medium brightness. Loosely related to
173+
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
133174
"""
134175

135176
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -157,6 +198,7 @@ def __init__(
157198
timestep_spacing: str = "linspace",
158199
steps_offset: int = 0,
159200
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
201+
rescale_betas_zero_snr: bool = False,
160202
):
161203
if trained_betas is not None:
162204
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -171,8 +213,17 @@ def __init__(
171213
else:
172214
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
173215

216+
if rescale_betas_zero_snr:
217+
self.betas = rescale_zero_terminal_snr(self.betas)
218+
174219
self.alphas = 1.0 - self.betas
175220
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
221+
222+
if rescale_betas_zero_snr:
223+
# Close to 0 without being 0 so first sigma is not inf
224+
# FP16 smallest positive subnormal works well here
225+
self.alphas_cumprod[-1] = 2**-24
226+
176227
# Currently we only support VP-type noise schedule
177228
self.alpha_t = torch.sqrt(self.alphas_cumprod)
178229
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)

tests/schedulers/test_scheduler_unipc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,10 @@ def test_prediction_type(self):
180180
for prediction_type in ["epsilon", "v_prediction"]:
181181
self.check_over_configs(prediction_type=prediction_type)
182182

183+
def test_rescale_betas_zero_snr(self):
184+
for rescale_betas_zero_snr in [True, False]:
185+
self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
186+
183187
def test_solver_order_and_type(self):
184188
for solver_type in ["bh1", "bh2"]:
185189
for order in [1, 2, 3]:

0 commit comments

Comments
 (0)