# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

import numpy as np
import torch

from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        trained_betas=None,
        timestep_values=None,
        tensor_format="pt",
    ):
        """
        Linear Multistep Scheduler for discrete beta schedules.
        Based on the original k-diffusion implementation by Katherine Crowson:
        https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
        """

        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

        # sigma_t = sqrt((1 - alphas_cumprod_t) / alphas_cumprod_t): the per-timestep noise scale of the
        # discrete schedule, expressed in the sigma parameterization used by k-diffusion
        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

        # setable values
        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
        self.derivatives = []

        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def get_lms_coefficient(self, order, t, current_order):
        """
        Compute a linear multistep coefficient.

        Args:
            order: number of derivatives kept in the multistep history.
            t: index of the current step in `self.sigmas`.
            current_order: index of the stored derivative this coefficient will weight.
        """

        def lms_derivative(tau):
            prod = 1.0
            for k in range(order):
                if current_order == k:
                    continue
                prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
            return prod

        integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]

        return integrated_coeff

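    # The coefficient above is the integral of the Lagrange basis polynomial for `current_order`
    # over [sigmas[t], sigmas[t + 1]], i.e. an Adams-Bashforth weight on a non-uniform sigma grid.
    # Minimal sketch of the degenerate case (hypothetical values, not part of the API):
    #
    #   coeff = scheduler.get_lms_coefficient(order=1, t=0, current_order=0)
    #   # with order=1 the basis polynomial is the constant 1, so coeff == sigmas[1] - sigmas[0]
    #   # and the update in `step` reduces to a plain Euler step.
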
    def set_timesteps(self, num_inference_steps):
        self.num_inference_steps = num_inference_steps
        self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)

        # interpolate the training sigmas at the (fractional) inference timesteps and append a final 0.0
        low_idx = np.floor(self.timesteps).astype(int)
        high_idx = np.ceil(self.timesteps).astype(int)
        frac = np.mod(self.timesteps, 1.0)
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
        self.sigmas = np.concatenate([sigmas, [0.0]])

        self.derivatives = []

        self.set_format(tensor_format=self.tensor_format)

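    # Rough shape of the resulting schedule (values depend on the configured betas), assuming the
    # scheduler was built with tensor_format="np":
    #
    #   scheduler.set_timesteps(50)
    #   scheduler.timesteps.shape   # (50,), descending from num_train_timesteps - 1 to 0.0
    #   scheduler.sigmas.shape      # (51,), descending, with a trailing 0.0
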
    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        order: int = 4,
    ):
        # note: `timestep` is the index of the current inference step (an index into `self.sigmas`),
        # not a value from the training schedule
        sigma = self.sigmas[timestep]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = sample - sigma * model_output

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma
        self.derivatives.append(derivative)
        if len(self.derivatives) > order:
            self.derivatives.pop(0)

        # 3. Compute linear multistep coefficients
        order = min(timestep + 1, order)
        lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]

        # 4. Compute previous sample based on the derivatives path
        prev_sample = sample + sum(
            coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
        )

        return {"prev_sample": prev_sample}

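    # Usage sketch (not part of this file): how `step` slots into a denoising loop. `unet` is a
    # hypothetical noise-prediction callable; the input rescaling follows the k-diffusion
    # parameterization this scheduler assumes.
    #
    #   scheduler = LMSDiscreteScheduler(tensor_format="np")
    #   scheduler.set_timesteps(50)
    #   sample = np.random.randn(1, 4, 64, 64) * scheduler.sigmas[0]
    #   for i, t in enumerate(scheduler.timesteps):
    #       scaled = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
    #       model_output = unet(scaled, t)
    #       sample = scheduler.step(model_output, i, sample)["prev_sample"]
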
    def add_noise(self, original_samples, noise, timesteps):
        alpha_prod = self.alphas_cumprod[timesteps]
        alpha_prod = self.match_shape(alpha_prod, original_samples)

        noisy_samples = (alpha_prod**0.5) * original_samples + ((1 - alpha_prod) ** 0.5) * noise
        return noisy_samples

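    # Sketch of training-time noising (hypothetical tensors, scheduler as in the sketch above): this is
    # the usual variance-preserving forward process,
    # x_t = sqrt(alpha_prod_t) * x_0 + sqrt(1 - alpha_prod_t) * noise.
    #
    #   x0 = np.random.randn(4, 3, 32, 32)
    #   noise = np.random.randn(4, 3, 32, 32)
    #   t = np.random.randint(0, 1000, (4,))
    #   noisy = scheduler.add_noise(x0, noise, t)
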
    def __len__(self):
        return self.config.num_train_timesteps