@@ -71,6 +71,43 @@ def alpha_bar_fn(t):
7171 return torch .tensor (betas , dtype = torch .float32 )
7272
7373
74+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """
    Rescales betas to have zero terminal SNR. Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    The square root of the cumulative alpha product is shifted so that its last entry becomes exactly zero and then
    scaled so that its first entry keeps its original value. The result is converted back to betas. The input tensor
    is not modified.

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR

    Raises:
        `ValueError`: if the first and last entries of `sqrt(alphas_cumprod)` are equal (for example a
            single-timestep schedule or all-zero betas), because the rescaling factor would divide by zero.
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Reference values at the first and last timestep. The arithmetic below is out-of-place, so these views keep
    # pointing at the original (unshifted) values without needing a clone.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0]
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1]

    # A zero span would turn the scale factor into a division by zero and silently yield NaN/inf betas.
    span = alphas_bar_sqrt_0 - alphas_bar_sqrt_T
    if torch.any(span == 0):
        raise ValueError(
            "Cannot rescale betas to zero terminal SNR: the first and last values of sqrt(alphas_cumprod) are "
            "equal, so the schedule cannot be shifted and rescaled."
        )

    # Shift so the last timestep is zero, then scale so the first timestep is back to the old value.
    alphas_bar_sqrt = (alphas_bar_sqrt - alphas_bar_sqrt_T) * (alphas_bar_sqrt_0 / span)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    # The first alpha has no predecessor to divide by; it equals alphas_bar[0] itself.
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas
109+
110+
74111class UniPCMultistepScheduler (SchedulerMixin , ConfigMixin ):
75112 """
76113 `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
@@ -130,6 +167,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
130167 final_sigmas_type (`str`, defaults to `"zero"`):
131168 The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
132169 sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
170+ rescale_betas_zero_snr (`bool`, defaults to `False`):
171+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
172+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
173+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
133174 """
134175
135176 _compatibles = [e .name for e in KarrasDiffusionSchedulers ]
@@ -157,6 +198,7 @@ def __init__(
157198 timestep_spacing : str = "linspace" ,
158199 steps_offset : int = 0 ,
159200 final_sigmas_type : Optional [str ] = "zero" , # "zero", "sigma_min"
201+ rescale_betas_zero_snr : bool = False ,
160202 ):
161203 if trained_betas is not None :
162204 self .betas = torch .tensor (trained_betas , dtype = torch .float32 )
@@ -171,8 +213,17 @@ def __init__(
171213 else :
172214 raise NotImplementedError (f"{ beta_schedule } does is not implemented for { self .__class__ } " )
173215
216+ if rescale_betas_zero_snr :
217+ self .betas = rescale_zero_terminal_snr (self .betas )
218+
174219 self .alphas = 1.0 - self .betas
175220 self .alphas_cumprod = torch .cumprod (self .alphas , dim = 0 )
221+
222+ if rescale_betas_zero_snr :
223+ # Close to 0 without being 0 so first sigma is not inf
224+ # FP16 smallest positive subnormal works well here
225+ self .alphas_cumprod [- 1 ] = 2 ** - 24
226+
176227 # Currently we only support VP-type noise schedule
177228 self .alpha_t = torch .sqrt (self .alphas_cumprod )
178229 self .sigma_t = torch .sqrt (1 - self .alphas_cumprod )
0 commit comments