|
| 1 | +"""Variance schedule for diffusion models.""" |
| 2 | +from __future__ import annotations |
| 3 | + |
| 4 | +import enum |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +from jax import numpy as jnp |
| 8 | + |
| 9 | + |
| 10 | +class DiffusionBetaSchedule(enum.Enum): |
| 11 | + """Class to define beta schedule.""" |
| 12 | + |
| 13 | + LINEAR = enum.auto() |
| 14 | + QUADRADIC = enum.auto() |
| 15 | + COSINE = enum.auto() |
| 16 | + WARMUP10 = enum.auto() |
| 17 | + WARMUP50 = enum.auto() |
| 18 | + |
| 19 | + |
| 20 | +def get_beta_schedule( |
| 21 | + num_timesteps: int, |
| 22 | + beta_schedule: DiffusionBetaSchedule, |
| 23 | + beta_start: float, |
| 24 | + beta_end: float, |
| 25 | +) -> jnp.ndarray: |
| 26 | + """Get variance (beta) schedule for q(x_t | x_{t-1}). |
| 27 | +
|
| 28 | + Args: |
| 29 | + num_timesteps: number of time steps in total, T. |
| 30 | + beta_schedule: schedule for beta. |
| 31 | + beta_start: beta for t=0. |
| 32 | + beta_end: beta for t=T-1. |
| 33 | +
|
| 34 | + Returns: |
| 35 | + Shape (num_timesteps,) array of beta values, for t=0, ..., T-1. |
| 36 | + Values are in ascending order. |
| 37 | +
|
| 38 | + Raises: |
| 39 | + ValueError: for unknown schedule. |
| 40 | + """ |
| 41 | + if beta_schedule == DiffusionBetaSchedule.LINEAR: |
| 42 | + return jnp.linspace( |
| 43 | + beta_start, |
| 44 | + beta_end, |
| 45 | + num_timesteps, |
| 46 | + ) |
| 47 | + if beta_schedule == DiffusionBetaSchedule.QUADRADIC: |
| 48 | + return ( |
| 49 | + jnp.linspace( |
| 50 | + beta_start**0.5, |
| 51 | + beta_end**0.5, |
| 52 | + num_timesteps, |
| 53 | + ) |
| 54 | + ** 2 |
| 55 | + ) |
| 56 | + if beta_schedule == DiffusionBetaSchedule.COSINE: |
| 57 | + |
| 58 | + def f(t: float) -> float: |
| 59 | + """Eq 17 in https://arxiv.org/abs/2102.09672. |
| 60 | +
|
| 61 | + Args: |
| 62 | + t: time step with values in [0, 1]. |
| 63 | +
|
| 64 | + Returns: |
| 65 | + Cumulative product of alpha. |
| 66 | + """ |
| 67 | + return np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2 |
| 68 | + |
| 69 | + betas = [0.0] |
| 70 | + alphas_cumprod_prev = 1.0 |
| 71 | + for i in range(1, num_timesteps): |
| 72 | + t = i / (num_timesteps - 1) |
| 73 | + alphas_cumprod = f(t) |
| 74 | + beta = 1 - alphas_cumprod / alphas_cumprod_prev |
| 75 | + betas.append(beta) |
| 76 | + return jnp.array(betas) * (beta_end - beta_start) + beta_start |
| 77 | + |
| 78 | + if beta_schedule == DiffusionBetaSchedule.WARMUP10: |
| 79 | + num_timesteps_warmup = max(num_timesteps // 10, 1) |
| 80 | + betas_warmup = ( |
| 81 | + jnp.linspace( |
| 82 | + beta_start**0.5, |
| 83 | + beta_end**0.5, |
| 84 | + num_timesteps_warmup, |
| 85 | + ) |
| 86 | + ** 2 |
| 87 | + ) |
| 88 | + return jnp.concatenate( |
| 89 | + [ |
| 90 | + betas_warmup, |
| 91 | + jnp.ones((num_timesteps - num_timesteps_warmup,)) * beta_end, |
| 92 | + ] |
| 93 | + ) |
| 94 | + if beta_schedule == DiffusionBetaSchedule.WARMUP50: |
| 95 | + num_timesteps_warmup = max(num_timesteps // 2, 1) |
| 96 | + betas_warmup = ( |
| 97 | + jnp.linspace( |
| 98 | + beta_start**0.5, |
| 99 | + beta_end**0.5, |
| 100 | + num_timesteps_warmup, |
| 101 | + ) |
| 102 | + ** 2 |
| 103 | + ) |
| 104 | + return jnp.concatenate( |
| 105 | + [ |
| 106 | + betas_warmup, |
| 107 | + jnp.ones((num_timesteps - num_timesteps_warmup,)) * beta_end, |
| 108 | + ] |
| 109 | + ) |
| 110 | + raise ValueError(f"Unknown beta_schedule {beta_schedule}.") |
| 111 | + |
| 112 | + |
| 113 | +def downsample_beta_schedule( |
| 114 | + betas: jnp.ndarray, |
| 115 | + num_timesteps: int, |
| 116 | + num_timesteps_to_keep: int, |
| 117 | +) -> jnp.ndarray: |
| 118 | + """Downsample beta schedule. |
| 119 | +
|
| 120 | + Args: |
| 121 | + betas: beta schedule, shape (num_timesteps,). |
| 122 | + Values are in ascending order. |
| 123 | + num_timesteps: number of time steps in total, T. |
| 124 | + num_timesteps_to_keep: number of time steps to keep. |
| 125 | +
|
| 126 | + Returns: |
| 127 | + Downsampled beta schedule, shape (num_timesteps_to_keep,). |
| 128 | + """ |
| 129 | + if betas.shape != (num_timesteps,): |
| 130 | + raise ValueError( |
| 131 | + f"betas.shape ({betas.shape}) must be equal to " |
| 132 | + f"(num_timesteps,)=({num_timesteps},)" |
| 133 | + ) |
| 134 | + if (num_timesteps - 1) % (num_timesteps_to_keep - 1) != 0: |
| 135 | + raise ValueError( |
| 136 | + f"num_timesteps-1={num_timesteps-1} can't be evenly divided by " |
| 137 | + f"num_timesteps_to_keep-1={num_timesteps_to_keep-1}." |
| 138 | + ) |
| 139 | + if num_timesteps_to_keep < 2: |
| 140 | + raise ValueError( |
| 141 | + f"num_timesteps_to_keep ({num_timesteps_to_keep}) must be >= 2." |
| 142 | + ) |
| 143 | + if num_timesteps_to_keep == num_timesteps: |
| 144 | + return betas |
| 145 | + if num_timesteps_to_keep < num_timesteps: |
| 146 | + step_scale = (num_timesteps - 1) // (num_timesteps_to_keep - 1) |
| 147 | + beta0 = betas[0] |
| 148 | + alphas = 1.0 - betas |
| 149 | + alphas_cumprod = jnp.cumprod(alphas) |
| 150 | + # (num_timesteps_to_keep,) |
| 151 | + alphas_cumprod = alphas_cumprod[::step_scale] |
| 152 | + # (num_timesteps_to_keep-1,) |
| 153 | + betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1] |
| 154 | + # (num_timesteps_to_keep,) |
| 155 | + betas = jnp.append(beta0, betas) |
| 156 | + return betas |
| 157 | + raise ValueError( |
| 158 | + f"num_timesteps_to_keep ({num_timesteps_to_keep}) " |
| 159 | + f"must be <= num_timesteps ({num_timesteps})" |
| 160 | + ) |
0 commit comments