Skip to content

Commit 01015bf

Browse files
authored
Add er_sde sampler (#7187)
1 parent 2330754 commit 01015bf

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

comfy/k_diffusion/sampling.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,3 +1366,59 @@ def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None,
13661366
x = x + d_bar * dt
13671367
old_d = d
13681368
return x
1369+
1370+
@torch.no_grad()
1371+
def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
1372+
"""
1373+
Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
1374+
Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
1375+
"""
1376+
extra_args = {} if extra_args is None else extra_args
1377+
seed = extra_args.get("seed", None)
1378+
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
1379+
s_in = x.new_ones([x.shape[0]])
1380+
1381+
def default_noise_scaler(sigma):
1382+
return sigma * ((sigma ** 0.3).exp() + 10.0)
1383+
noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
1384+
num_integration_points = 200.0
1385+
point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
1386+
1387+
old_denoised = None
1388+
old_denoised_d = None
1389+
1390+
for i in trange(len(sigmas) - 1, disable=disable):
1391+
denoised = model(x, sigmas[i] * s_in, **extra_args)
1392+
if callback is not None:
1393+
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1394+
stage_used = min(max_stage, i + 1)
1395+
if sigmas[i + 1] == 0:
1396+
x = denoised
1397+
elif stage_used == 1:
1398+
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
1399+
x = r * x + (1 - r) * denoised
1400+
else:
1401+
r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
1402+
x = r * x + (1 - r) * denoised
1403+
1404+
dt = sigmas[i + 1] - sigmas[i]
1405+
sigma_step_size = -dt / num_integration_points
1406+
sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
1407+
scaled_pos = noise_scaler(sigma_pos)
1408+
1409+
# Stage 2
1410+
s = torch.sum(1 / scaled_pos) * sigma_step_size
1411+
denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
1412+
x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
1413+
1414+
if stage_used >= 3:
1415+
# Stage 3
1416+
s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
1417+
denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
1418+
x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
1419+
old_denoised_d = denoised_d
1420+
1421+
if s_noise != 0 and sigmas[i + 1] > 0:
1422+
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt()
1423+
old_denoised = denoised
1424+
return x

comfy/samplers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,7 @@ def max_denoise(self, model_wrap, sigmas):
710710
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
711711
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
712712
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
713-
"gradient_estimation"]
713+
"gradient_estimation", "er_sde"]
714714

715715
class KSAMPLER(Sampler):
716716
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

0 commit comments

Comments
 (0)