Commit aac10ad

Add SA-Solver sampler (Comfy-Org#8834)

1 parent 9742542 commit aac10ad

File tree

4 files changed: +278, -1 lines changed

comfy/k_diffusion/sa_solver.py

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
# SA-Solver: Stochastic Adams Solver (NeurIPS 2023, arXiv:2309.05019)
# Conference: https://proceedings.neurips.cc/paper_files/paper/2023/file/f4a6806490d31216a3ba667eb240c897-Paper-Conference.pdf
# Codebase ref: https://github.com/scxue/SA-Solver

import math
from typing import Union, Callable
import torch


def compute_exponential_coeffs(s: torch.Tensor, t: torch.Tensor, solver_order: int, tau_t: float) -> torch.Tensor:
    """Compute (1 + tau^2) * integral of exp((1 + tau^2) * x) * x^p dx from s to t, with exp((1 + tau^2) * t) factored out, using integration by parts.

    Integral of exp((1 + tau^2) * x) * x^p dx
        = product_terms[p] - (p / (1 + tau^2)) * integral of exp((1 + tau^2) * x) * x^(p-1) dx,
    with the base case p=0, where the integral equals product_terms[0],

    where
        product_terms[p] = x^p * exp((1 + tau^2) * x) / (1 + tau^2).

    Construct a recursive coefficient matrix following the relation above to compute all integral terms up to p = solver_order - 1.
    Return the coefficients used by SA-Solver in data prediction mode.

    Args:
        s: Start time s.
        t: End time t.
        solver_order: Current order of the solver.
        tau_t: Stochastic strength parameter in the SDE.

    Returns:
        Exponential coefficients used in data prediction, with exp((1 + tau^2) * t) factored out, ordered from p=0 to p=solver_order-1, shape (solver_order,).
    """
    tau_mul = 1 + tau_t ** 2
    h = t - s
    p = torch.arange(solver_order, dtype=s.dtype, device=s.device)

    # product_terms after factoring out exp((1 + tau^2) * t)
    # Includes the (1 + tau^2) factor from outside the integral
    product_terms_factored = (t ** p - s ** p * (-tau_mul * h).exp())

    # Lower-triangular recursive coefficient matrix
    # Accumulates recursive coefficients based on p / (1 + tau^2)
    recursive_depth_mat = p.unsqueeze(1) - p.unsqueeze(0)
    log_factorial = (p + 1).lgamma()
    recursive_coeff_mat = log_factorial.unsqueeze(1) - log_factorial.unsqueeze(0)
    if tau_t > 0:
        recursive_coeff_mat = recursive_coeff_mat - (recursive_depth_mat * math.log(tau_mul))
    signs = torch.where(recursive_depth_mat % 2 == 0, 1.0, -1.0)
    recursive_coeff_mat = (recursive_coeff_mat.exp() * signs).tril()

    return recursive_coeff_mat @ product_terms_factored
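
A quick numerical sanity check for the closed form above (a reviewer-style sketch, not part of the diff; the quadrature resolution and tolerance are my choices): the recursive coefficients should match brute-force quadrature of the factored integral.

import torch
from comfy.k_diffusion.sa_solver import compute_exponential_coeffs

s = torch.tensor(-1.2, dtype=torch.float64)
t = torch.tensor(-0.4, dtype=torch.float64)
tau_t, order = 0.7, 4
tau_mul = 1 + tau_t ** 2

coeffs = compute_exponential_coeffs(s, t, order, tau_t)

# Brute force: (1 + tau^2) * exp(-(1 + tau^2) * t) * integral of exp((1 + tau^2) * x) * x^p dx from s to t
x = torch.linspace(s.item(), t.item(), 100_001, dtype=torch.float64)
for p in range(order):
    quad = tau_mul * (-tau_mul * t).exp() * torch.trapezoid((tau_mul * x).exp() * x ** p, x)
    assert torch.allclose(coeffs[p], quad, atol=1e-9)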


def compute_simple_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, is_corrector_step: bool = False) -> torch.Tensor:
    """Compute the simple order-2 b coefficients from the SA-Solver paper (Appendix D, Implementation Details)."""
    tau_mul = 1 + tau_t ** 2
    h = lambda_t - lambda_s
    alpha_t = sigma_next * lambda_t.exp()
    if is_corrector_step:
        # Simplified 1-step (order-2) corrector
        b_1 = alpha_t * (0.5 * tau_mul * h)
        b_2 = alpha_t * (-h * tau_mul).expm1().neg() - b_1
    else:
        # Simplified 2-step predictor
        b_2 = alpha_t * (0.5 * tau_mul * h ** 2) / (curr_lambdas[-2] - lambda_s)
        b_1 = alpha_t * (-h * tau_mul).expm1().neg() - b_2
    return torch.stack([b_2, b_1])


def compute_stochastic_adams_b_coeffs(sigma_next: torch.Tensor, curr_lambdas: torch.Tensor, lambda_s: torch.Tensor, lambda_t: torch.Tensor, tau_t: float, simple_order_2: bool = False, is_corrector_step: bool = False) -> torch.Tensor:
    """Compute the b_i coefficients for SA-Solver (see eqs. 15 and 18 of the paper).

    The solver order corresponds to the number of input lambdas (half-logSNR points).

    Args:
        sigma_next: Sigma at end time t.
        curr_lambdas: Lambda time points used to construct the Lagrange basis, shape (N,).
        lambda_s: Lambda at start time s.
        lambda_t: Lambda at end time t.
        tau_t: Stochastic strength parameter in the SDE.
        simple_order_2: Whether to enable the simple order-2 scheme.
        is_corrector_step: Flag for the corrector step in simple order-2 mode.

    Returns:
        b_i coefficients for SA-Solver, shape (N,), where N is the solver order.
    """
    num_timesteps = curr_lambdas.shape[0]

    if simple_order_2 and num_timesteps == 2:
        return compute_simple_stochastic_adams_b_coeffs(sigma_next, curr_lambdas, lambda_s, lambda_t, tau_t, is_corrector_step)

    # Compute the coefficients by solving a linear system from Lagrange basis interpolation
    exp_integral_coeffs = compute_exponential_coeffs(lambda_s, lambda_t, num_timesteps, tau_t)
    vandermonde_matrix_T = torch.vander(curr_lambdas, num_timesteps, increasing=True).T
    lagrange_integrals = torch.linalg.solve(vandermonde_matrix_T, exp_integral_coeffs)

    # (sigma_t * exp(-tau^2 * lambda_t)) * exp((1 + tau^2) * lambda_t)
    #     = sigma_t * exp(lambda_t) = alpha_t
    # exp((1 + tau^2) * lambda_t) is the factor extracted from the integral
    alpha_t = sigma_next * lambda_t.exp()
    return alpha_t * lagrange_integrals
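
Two properties of this construction make a cheap sanity check (again my sketch, not part of the diff): with a single lambda point the coefficient collapses to the first-order exponential-integrator weight alpha_t * (1 - exp(-(1 + tau^2) * h)), and since the first row of the transposed Vandermonde system is all ones (Lagrange basis polynomials sum to one), the b_i at any order sum to that same value.

import torch
from comfy.k_diffusion.sa_solver import compute_stochastic_adams_b_coeffs

lambda_s = torch.tensor(0.3, dtype=torch.float64)
lambda_t = torch.tensor(0.9, dtype=torch.float64)
sigma_next = torch.tensor(0.5, dtype=torch.float64)
tau_t = 0.8
h = lambda_t - lambda_s
alpha_t = sigma_next * lambda_t.exp()
total = alpha_t * (-(1 + tau_t ** 2) * h).expm1().neg()

# Order 1: the single coefficient is the whole weight
b1 = compute_stochastic_adams_b_coeffs(sigma_next, lambda_s.unsqueeze(0), lambda_s, lambda_t, tau_t)
assert torch.allclose(b1.sum(), total)

# Order 3: the individual b_i differ, but their sum is preserved
prev_lambdas = torch.tensor([-0.2, 0.1, 0.3], dtype=torch.float64)
b3 = compute_stochastic_adams_b_coeffs(sigma_next, prev_lambdas, lambda_s, lambda_t, tau_t)
assert torch.allclose(b3.sum(), total)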


def get_tau_interval_func(start_sigma: float, end_sigma: float, eta: float = 1.0) -> Callable[[Union[torch.Tensor, float]], float]:
    """Return a function that controls the stochasticity of SA-Solver.

    When eta = 0, SA-Solver runs as an ODE. The official approach uses
    time t to determine the SDE interval, while here we use sigma instead.

    See:
        https://github.com/scxue/SA-Solver/blob/main/README.md
    """

    def tau_func(sigma: Union[torch.Tensor, float]) -> float:
        if eta <= 0:
            return 0.0  # ODE
        if isinstance(sigma, torch.Tensor):
            sigma = sigma.item()
        return eta if start_sigma >= sigma >= end_sigma else 0.0

    return tau_func
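
Usage sketch for the interval function (illustrative sigma values; sigmas decrease during sampling, so the interval is entered at start_sigma and left at end_sigma):

from comfy.k_diffusion.sa_solver import get_tau_interval_func

tau_func = get_tau_interval_func(start_sigma=10.0, end_sigma=0.5, eta=1.0)
print(tau_func(14.0))  # 0.0 -> before the interval: deterministic (ODE) step
print(tau_func(3.0))   # 1.0 -> inside [0.5, 10.0]: stochastic (SDE) step
print(tau_func(0.1))   # 0.0 -> past the interval: deterministic again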

comfy/k_diffusion/sampling.py

Lines changed: 111 additions & 0 deletions
@@ -9,6 +9,7 @@
 from . import utils
 from . import deis
+from . import sa_solver
 import comfy.model_patcher
 import comfy.model_sampling

@@ -1648,3 +1649,113 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None
        if inject_noise:
            x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
    return x


@torch.no_grad()
def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False):
    """Stochastic Adams Solver with the predictor-corrector method (NeurIPS 2023)."""
    if len(sigmas) <= 1:
        return x
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
    sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
    lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)

    if tau_func is None:
        # Use the default interval for stochastic sampling
        start_sigma = model_sampling.percent_to_sigma(0.2)
        end_sigma = model_sampling.percent_to_sigma(0.8)
        tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=1.0)

    max_used_order = max(predictor_order, corrector_order)
    x_pred = x  # x: current state, x_pred: predicted next state

    h = 0.0
    tau_t = 0.0
    noise = 0.0
    pred_list = []

    # Lower the order near the end to improve stability
    lower_order_to_end = sigmas[-1].item() == 0

    for i in trange(len(sigmas) - 1, disable=disable):
        # Evaluation
        denoised = model(x_pred, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
        pred_list.append(denoised)
        pred_list = pred_list[-max_used_order:]

        predictor_order_used = min(predictor_order, len(pred_list))
        if i == 0 or (sigmas[i + 1] == 0 and not use_pece):
            corrector_order_used = 0
        else:
            corrector_order_used = min(corrector_order, len(pred_list))

        if lower_order_to_end:
            predictor_order_used = min(predictor_order_used, len(sigmas) - 2 - i)
            corrector_order_used = min(corrector_order_used, len(sigmas) - 1 - i)

        # Corrector
        if corrector_order_used == 0:
            # Update by the predicted state
            x = x_pred
        else:
            curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1]
            b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
                sigmas[i],
                curr_lambdas,
                lambdas[i - 1],
                lambdas[i],
                tau_t,
                simple_order_2,
                is_corrector_step=True,
            )
            pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1)  # (B, K, ...)
            corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0]))  # (B, ...)
            x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res

            if tau_t > 0 and s_noise > 0:
                # The noise from the previous predictor step
                x = x + noise

        if use_pece:
            # Evaluate the corrected state
            denoised = model(x, sigmas[i] * s_in, **extra_args)
            pred_list[-1] = denoised

        # Predictor
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            tau_t = tau_func(sigmas[i + 1])
            curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
            b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs(
                sigmas[i + 1],
                curr_lambdas,
                lambdas[i],
                lambdas[i + 1],
                tau_t,
                simple_order_2,
                is_corrector_step=False,
            )
            pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1)  # (B, K, ...)
            pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0]))  # (B, ...)
            h = lambdas[i + 1] - lambdas[i]
            x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res

            if tau_t > 0 and s_noise > 0:
                noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
                x_pred = x_pred + noise
    return x


@torch.no_grad()
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
    """Stochastic Adams Solver in PECE (Predict-Evaluate-Correct-Evaluate) mode (NeurIPS 2023)."""
    return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
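
Both samplers are reached by name through ComfyUI's registry (see the comfy/samplers.py change below). A minimal construction sketch, assuming a standard ComfyUI environment, mirroring the options the SamplerSASolver node passes:

import comfy.samplers

sampler = comfy.samplers.ksampler(
    "sa_solver",
    {"predictor_order": 3, "corrector_order": 4, "use_pece": False, "simple_order_2": False},
)
# Leaving tau_func unset lets sample_sa_solver build the default SDE interval
# from percent_to_sigma(0.2) to percent_to_sigma(0.8) with eta = 1.0.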

comfy/samplers.py

Lines changed: 1 addition & 1 deletion
@@ -720,7 +720,7 @@ def max_denoise(self, model_wrap, sigmas):
     "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
     "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
     "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
-    "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
+    "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]

 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
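
With the two names appended, the samplers become selectable wherever ComfyUI accepts a sampler name; a sketch, assuming the usual registry helpers:

import comfy.samplers

assert "sa_solver" in comfy.samplers.KSampler.SAMPLERS
sampler = comfy.samplers.sampler_object("sa_solver_pece")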

comfy_extras/nodes_custom_sampler.py

Lines changed: 45 additions & 0 deletions
@@ -2,6 +2,7 @@
 import comfy.samplers
 import comfy.sample
 from comfy.k_diffusion import sampling as k_diffusion_sampling
+from comfy.k_diffusion import sa_solver
 from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
 import latent_preview
 import torch
@@ -521,6 +522,49 @@ def reverse_time_sde_noise_scaler(x):
        return (sampler,)


class SamplerSASolver(ComfyNodeABC):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypeDict:
        return {
            "required": {
                "model": (IO.MODEL, {}),
                "eta": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": False},),
                "sde_start_percent": (IO.FLOAT, {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001},),
                "sde_end_percent": (IO.FLOAT, {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001},),
                "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False},),
                "predictor_order": (IO.INT, {"default": 3, "min": 1, "max": 6}),
                "corrector_order": (IO.INT, {"default": 4, "min": 0, "max": 6}),
                "use_pece": (IO.BOOLEAN, {}),
                "simple_order_2": (IO.BOOLEAN, {}),
            }
        }

    RETURN_TYPES = (IO.SAMPLER,)
    CATEGORY = "sampling/custom_sampling/samplers"

    FUNCTION = "get_sampler"

    def get_sampler(self, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2):
        model_sampling = model.get_model_object("model_sampling")
        start_sigma = model_sampling.percent_to_sigma(sde_start_percent)
        end_sigma = model_sampling.percent_to_sigma(sde_end_percent)
        tau_func = sa_solver.get_tau_interval_func(start_sigma, end_sigma, eta=eta)

        sampler_name = "sa_solver"
        sampler = comfy.samplers.ksampler(
            sampler_name,
            {
                "tau_func": tau_func,
                "s_noise": s_noise,
                "predictor_order": predictor_order,
                "corrector_order": corrector_order,
                "use_pece": use_pece,
                "simple_order_2": simple_order_2,
            },
        )
        return (sampler,)


class Noise_EmptyNoise:
    def __init__(self):
        self.seed = 0
@@ -829,6 +873,7 @@ def add_noise(self, model, noise, sigmas, latent_image):
     "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
     "SamplerDPMAdaptative": SamplerDPMAdaptative,
     "SamplerER_SDE": SamplerER_SDE,
+    "SamplerSASolver": SamplerSASolver,
     "SplitSigmas": SplitSigmas,
     "SplitSigmasDenoise": SplitSigmasDenoise,
     "FlipSigmas": FlipSigmas,
