# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
# Codebase ref: https://github.com/scxue/SA-Solver

import math
from typing import Callable, Union

import torch


def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
    """Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t, with exp((1 + tau^2) * t) factored out, using integration by parts.

    Integral of exp((1 + tau^2) * x) * x^p dx
        = product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
    with base case p = 0, where the integral equals product_terms[0],

    where
        product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).

    Construct a lower-triangular coefficient matrix that unrolls the above recursion to compute all integral terms up to p = solver_order - 1.
    Return the coefficients used by SA-Solver in data prediction mode.

    Args:
        s: Start time s.
        t: End time t.
        solver_order: Current order of the solver.
        tau_t: Stochastic strength parameter in the SDE.

    Returns:
        Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p = 0 to p = solver_order - 1, shape (solver_order,).
    """
    tau_mul = 1 + tau_t ** 2
    h = t - s
    p = torch.arange(solver_order, dtype=s.dtype, device=s.device)

    # Boundary terms product_terms[p] evaluated from s to t, after factoring out exp((1 + tau^2) * t);
    # the (1 + tau^2) factor from outside the integral cancels the 1 / (1 + tau^2) in product_terms
    product_terms_factored = t ** p - s ** p * (-tau_mul * h).exp()

    # Lower-triangular recursive coefficient matrix: entry (p, k) holds
    # (-1)^(p - k) * p! / (k! * (1 + tau^2)^(p - k)), obtained by unrolling the recursion
    recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
    log_factorial = (p + 1).lgamma()
    recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
    if tau_t > 0:
        # log(1 + tau^2) = 0 when tau_t == 0, so the term can be skipped in that case
        recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
    signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
    recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()

    return recursive_coeff_mat @ product_terms_factored


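# Illustrative sanity check (not part of the original file; values are arbitrary):
# the recursion should agree with brute-force quadrature of
# (1 + tau^2) * integral of exp((1 + tau^2) * (x - t)) * x^p dx from s to t.
#
#   s, t, tau, order = torch.tensor(0.25), torch.tensor(1.5), 0.4, 3
#   coeffs = compute_exponential_coeffs(s, t, order, tau)
#   x = torch.linspace(s.item(), t.item(), 100_001, dtype=torch.float64)
#   tau_mul = 1 + tau ** 2
#   for p_exp in range(order):
#       integrand = tau_mul * (tau_mul * (x - t)).exp() * x ** p_exp
#       assert torch.allclose(coeffs[p_exp].double(), torch.trapezoid(integrand, x), atol=1e-4)

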
def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
    """Compute the simplified order-2 b coefficients from the SA-Solver paper (Appendix D, Implementation Details)."""
    tau_mul = 1 + tau_t ** 2
    h = lambda_t - lambda_s
    alpha_t = sigma_next * lambda_t.exp()
    if is_corrector_step:
        # Simplified 1-step (order-2) corrector
        b_1 = alpha_t * (0.5 * tau_mul * h)
        b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
    else:
        # Simplified 2-step predictor
        b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
        b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
    return torch.stack([b_2, b_1])


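# Note (illustrative sketch, values are arbitrary and not from the original repo):
# by construction b_1 + b_2 = alpha_t * (1 - exp(-(1 + tau^2) * h)), the p = 0
# exponential coefficient, so the order-2 update degrades gracefully to the
# order-1 update when both model outputs coincide.
#
#   lambdas = torch.tensor([0.1, 0.4])  # hypothetical half-logSNR history
#   b = compute_simple_stochastic_adams_b_coeffs(
#       torch.tensor(0.7), lambdas, lambdas[-1], torch.tensor(0.9), tau_t=1.0)
#   print(b.sum())  # equals alpha_t * (1 - exp(-2 * 0.5)) up to float error

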
def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
    """Compute the b_i coefficients for SA-Solver (Eqs. 15 and 18 in the paper).

    The solver order corresponds to the number of input lambdas (half-logSNR points).

    Args:
        sigma_next: Sigma at end time t.
        curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
        lambda_s: Lambda at start time s.
        lambda_t: Lambda at end time t.
        tau_t: Stochastic strength parameter in the SDE.
        simple_order_2: Whether to enable the simple order-2 scheme.
        is_corrector_step: Flag for corrector step in simple order-2 mode.

    Returns:
        b_i coefficients for the SA-Solver, shape (N,), where N is the solver order.
    """
    num_timesteps = curr_lambdas.shape[0]

    if simple_order_2 and num_timesteps == 2:
        return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)

    # Compute coefficients by solving a linear system arising from Lagrange basis interpolation
    exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
    vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
    lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)

    # (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
    #   = sigma_t * exp(lambda_t) = alpha_t,
    # where exp((1 + tau^2) * lambda_t) is the factor extracted from the integral
    alpha_t = sigma_next * lambda_t.exp()
    return alpha_t * lagrange_integrals


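# Illustrative check (not part of the original file; values are arbitrary): because
# the Lagrange basis polynomials sum to one, the b_i should sum to
# alpha_t * (1 - exp(-(1 + tau^2) * (lambda_t - lambda_s))) for any order.
#
#   lambdas = torch.tensor([-0.2, 0.1, 0.4])  # hypothetical half-logSNR history
#   lambda_s, lambda_t = lambdas[-1], torch.tensor(0.9)
#   sigma_next, tau = torch.tensor(0.7), 0.5
#   b = compute_stochastic_adams_b_coeffs(sigma_next, lambdas, lambda_s, lambda_t, tau)
#   alpha_t = sigma_next * lambda_t.exp()
#   expected = alpha_t * (1 - torch.exp(-(1 + tau ** 2) * (lambda_t - lambda_s)))
#   assert torch.allclose(b.sum(), expected, rtol=1e-4)

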
def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
    """Return a function that controls the stochasticity (tau) of SA-Solver.

    When eta = 0, SA-Solver runs as an ODE solver. The official implementation uses
    the time t to determine the SDE interval; here sigma is used instead.

    See:
        https://github.com/scxue/SA-Solver/blob/main/README.md
    """

    def tau_func(sigma: Union[torch.Tensor, float]) -> float:
        if eta <= 0:
            return 0.0  # ODE

        if isinstance(sigma, torch.Tensor):
            sigma = sigma.item()
        return eta if start_sigma >= sigma >= end_sigma else 0.0

    return tau_func
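

if __name__ == "__main__":
    # Minimal smoke test (illustrative values, not part of the original file):
    # tau_func returns eta only for sigmas inside [end_sigma, start_sigma].
    tau_func = get_tau_interval_func(start_sigma=14.6, end_sigma=0.3, eta=1.0)
    print(tau_func(10.0))               # 1.0 -> stochastic (SDE) region
    print(tau_func(0.1))                # 0.0 -> deterministic (ODE) region
    print(tau_func(torch.tensor(5.0)))  # tensors are accepted as well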