|
| 1 | +from typing import List, Optional, Union |
| 2 | + |
| 3 | +import math |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | + |
| 7 | +from ...configuration_utils import ConfigMixin, register_to_config |
| 8 | +from ..sigmas.beta_sigmas import BetaSigmas |
| 9 | +from ..sigmas.exponential_sigmas import ExponentialSigmas |
| 10 | +from ..sigmas.karras_sigmas import KarrasSigmas |
| 11 | + |
| 12 | +def betas_for_alpha_bar( |
| 13 | + num_diffusion_timesteps, |
| 14 | + max_beta=0.999, |
| 15 | + alpha_transform_type="cosine", |
| 16 | +): |
| 17 | + """ |
| 18 | + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of |
| 19 | + (1-beta) over time from t = [0,1]. |
| 20 | +
|
| 21 | + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up |
| 22 | + to that part of the diffusion process. |
| 23 | +
|
| 24 | +
|
| 25 | + Args: |
| 26 | + num_diffusion_timesteps (`int`): the number of betas to produce. |
| 27 | + max_beta (`float`): the maximum beta to use; use values lower than 1 to |
| 28 | + prevent singularities. |
| 29 | + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. |
| 30 | + Choose from `cosine` or `exp` |
| 31 | +
|
| 32 | + Returns: |
| 33 | + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs |
| 34 | + """ |
| 35 | + if alpha_transform_type == "cosine": |
| 36 | + |
| 37 | + def alpha_bar_fn(t): |
| 38 | + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 |
| 39 | + |
| 40 | + elif alpha_transform_type == "exp": |
| 41 | + |
| 42 | + def alpha_bar_fn(t): |
| 43 | + return math.exp(t * -12.0) |
| 44 | + |
| 45 | + else: |
| 46 | + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") |
| 47 | + |
| 48 | + betas = [] |
| 49 | + for i in range(num_diffusion_timesteps): |
| 50 | + t1 = i / num_diffusion_timesteps |
| 51 | + t2 = (i + 1) / num_diffusion_timesteps |
| 52 | + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) |
| 53 | + return torch.tensor(betas, dtype=torch.float32) |
| 54 | + |
| 55 | +def rescale_zero_terminal_snr(betas): |
| 56 | + """ |
| 57 | + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) |
| 58 | +
|
| 59 | +
|
| 60 | + Args: |
| 61 | + betas (`torch.Tensor`): |
| 62 | + the betas that the scheduler is being initialized with. |
| 63 | +
|
| 64 | + Returns: |
| 65 | + `torch.Tensor`: rescaled betas with zero terminal SNR |
| 66 | + """ |
| 67 | + # Convert betas to alphas_bar_sqrt |
| 68 | + alphas = 1.0 - betas |
| 69 | + alphas_cumprod = torch.cumprod(alphas, dim=0) |
| 70 | + alphas_bar_sqrt = alphas_cumprod.sqrt() |
| 71 | + |
| 72 | + # Store old values. |
| 73 | + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() |
| 74 | + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() |
| 75 | + |
| 76 | + # Shift so the last timestep is zero. |
| 77 | + alphas_bar_sqrt -= alphas_bar_sqrt_T |
| 78 | + |
| 79 | + # Scale so the first timestep is back to the old value. |
| 80 | + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) |
| 81 | + |
| 82 | + # Convert alphas_bar_sqrt to betas |
| 83 | + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt |
| 84 | + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod |
| 85 | + alphas = torch.cat([alphas_bar[0:1], alphas]) |
| 86 | + betas = 1 - alphas |
| 87 | + |
| 88 | + return betas |
| 89 | + |
| 90 | + |
| 91 | +class BetaSchedule: |
| 92 | + |
| 93 | + scale_model_input = True |
| 94 | + |
| 95 | + def __init__( |
| 96 | + self, |
| 97 | + num_train_timesteps: int = 1000, |
| 98 | + beta_start: float = 0.0001, |
| 99 | + beta_end: float = 0.02, |
| 100 | + beta_schedule: str = "linear", |
| 101 | + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, |
| 102 | + rescale_betas_zero_snr: bool = False, |
| 103 | + interpolation_type: str = "linear", |
| 104 | + timestep_spacing: str = "linspace", |
| 105 | + timestep_type: str = "discrete", # can be "discrete" or "continuous" |
| 106 | + steps_offset: int = 0, |
| 107 | + sigma_min: Optional[float] = None, |
| 108 | + sigma_max: Optional[float] = None, |
| 109 | + final_sigmas_type: str = "zero", # can be "zero" or "sigma_min" |
| 110 | + **kwargs, |
| 111 | + ): |
| 112 | + if trained_betas is not None: |
| 113 | + self.betas = torch.tensor(trained_betas, dtype=torch.float32) |
| 114 | + elif beta_schedule == "linear": |
| 115 | + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) |
| 116 | + elif beta_schedule == "scaled_linear": |
| 117 | + # this schedule is very specific to the latent diffusion model. |
| 118 | + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 |
| 119 | + elif beta_schedule == "squaredcos_cap_v2": |
| 120 | + # Glide cosine schedule |
| 121 | + self.betas = betas_for_alpha_bar(num_train_timesteps) |
| 122 | + else: |
| 123 | + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") |
| 124 | + |
| 125 | + if rescale_betas_zero_snr: |
| 126 | + self.betas = rescale_zero_terminal_snr(self.betas) |
| 127 | + |
| 128 | + self.alphas = 1.0 - self.betas |
| 129 | + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) |
| 130 | + |
| 131 | + if rescale_betas_zero_snr: |
| 132 | + # Close to 0 without being 0 so first sigma is not inf |
| 133 | + # FP16 smallest positive subnormal works well here |
| 134 | + self.alphas_cumprod[-1] = 2**-24 |
| 135 | + |
| 136 | + self.num_train_timesteps = num_train_timesteps |
| 137 | + self.beta_start = beta_start |
| 138 | + self.beta_end = beta_end |
| 139 | + self.beta_schedule = beta_schedule |
| 140 | + self.trained_betas = trained_betas |
| 141 | + self.rescale_betas_zero_snr = rescale_betas_zero_snr |
| 142 | + self.interpolation_type = interpolation_type |
| 143 | + self.timestep_spacing = timestep_spacing |
| 144 | + self.timestep_type = timestep_type |
| 145 | + self.steps_offset = steps_offset |
| 146 | + self.sigma_min = sigma_min |
| 147 | + self.sigma_max = sigma_max |
| 148 | + self.final_sigmas_type = final_sigmas_type |
| 149 | + |
| 150 | + def _sigma_to_t(self, sigma, log_sigmas): |
| 151 | + # get log sigma |
| 152 | + log_sigma = np.log(np.maximum(sigma, 1e-10)) |
| 153 | + |
| 154 | + # get distribution |
| 155 | + dists = log_sigma - log_sigmas[:, np.newaxis] |
| 156 | + |
| 157 | + # get sigmas range |
| 158 | + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) |
| 159 | + high_idx = low_idx + 1 |
| 160 | + |
| 161 | + low = log_sigmas[low_idx] |
| 162 | + high = log_sigmas[high_idx] |
| 163 | + |
| 164 | + # interpolate sigmas |
| 165 | + w = (low - log_sigma) / (low - high) |
| 166 | + w = np.clip(w, 0, 1) |
| 167 | + |
| 168 | + # transform interpolation to time range |
| 169 | + t = (1 - w) * low_idx + w * high_idx |
| 170 | + t = t.reshape(sigma.shape) |
| 171 | + return t |
| 172 | + |
| 173 | + def __call__( |
| 174 | + self, |
| 175 | + num_inference_steps: int = None, |
| 176 | + device: Union[str, torch.device] = None, |
| 177 | + timesteps: Optional[List[int]] = None, |
| 178 | + sigmas: Optional[List[float]] = None, |
| 179 | + sigma_schedule: Optional[Union[KarrasSigmas, ExponentialSigmas, BetaSigmas]] = None, |
| 180 | + **kwargs, |
| 181 | + ): |
| 182 | + if sigmas is not None: |
| 183 | + log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)) |
| 184 | + sigmas = np.array(sigmas).astype(np.float32) |
| 185 | + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]]) |
| 186 | + |
| 187 | + else: |
| 188 | + if timesteps is not None: |
| 189 | + timesteps = np.array(timesteps).astype(np.float32) |
| 190 | + else: |
| 191 | + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 |
| 192 | + if self.timestep_spacing == "linspace": |
| 193 | + timesteps = np.linspace( |
| 194 | + 0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32 |
| 195 | + )[::-1].copy() |
| 196 | + elif self.timestep_spacing == "leading": |
| 197 | + step_ratio = self.num_train_timesteps // num_inference_steps |
| 198 | + # creates integer timesteps by multiplying by ratio |
| 199 | + # casting to int to avoid issues when num_inference_step is power of 3 |
| 200 | + timesteps = ( |
| 201 | + (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32) |
| 202 | + ) |
| 203 | + timesteps += self.steps_offset |
| 204 | + elif self.timestep_spacing == "trailing": |
| 205 | + step_ratio = self.num_train_timesteps / num_inference_steps |
| 206 | + # creates integer timesteps by multiplying by ratio |
| 207 | + # casting to int to avoid issues when num_inference_step is power of 3 |
| 208 | + timesteps = ( |
| 209 | + (np.arange(self.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32) |
| 210 | + ) |
| 211 | + timesteps -= 1 |
| 212 | + else: |
| 213 | + raise ValueError( |
| 214 | + f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." |
| 215 | + ) |
| 216 | + |
| 217 | + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) |
| 218 | + log_sigmas = np.log(sigmas) |
| 219 | + if self.interpolation_type == "linear": |
| 220 | + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) |
| 221 | + elif self.interpolation_type == "log_linear": |
| 222 | + sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy() |
| 223 | + else: |
| 224 | + raise ValueError( |
| 225 | + f"{self.interpolation_type} is not implemented. Please specify interpolation_type to either" |
| 226 | + " 'linear' or 'log_linear'" |
| 227 | + ) |
| 228 | + |
| 229 | + if sigma_schedule is not None: |
| 230 | + sigmas = sigma_schedule(sigmas) |
| 231 | + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) |
| 232 | + |
| 233 | + if self.final_sigmas_type == "sigma_min": |
| 234 | + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 |
| 235 | + elif self.final_sigmas_type == "zero": |
| 236 | + sigma_last = 0 |
| 237 | + else: |
| 238 | + raise ValueError( |
| 239 | + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.final_sigmas_type}" |
| 240 | + ) |
| 241 | + |
| 242 | + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) |
| 243 | + |
| 244 | + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) |
| 245 | + |
| 246 | + # TODO: Support the full EDM scalings for all prediction types and timestep types |
| 247 | + if self.timestep_type == "continuous" and self.prediction_type == "v_prediction": |
| 248 | + timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device) |
| 249 | + else: |
| 250 | + timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device) |
| 251 | + |
| 252 | + return sigmas, timesteps |
0 commit comments