Skip to content

Commit 2a18e98

Browse files
Refactor so that zsnr can be set in the sampling_settings.
1 parent 8a52810 commit 2a18e98

File tree

2 files changed

+29
-25
lines changed

2 files changed

+29
-25
lines changed

comfy/model_sampling.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,25 @@
22
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
33
import math
44

5+
def rescale_zero_terminal_snr_sigmas(sigmas):
6+
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
7+
alphas_bar_sqrt = alphas_cumprod.sqrt()
8+
9+
# Store old values.
10+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
11+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
12+
13+
# Shift so the last timestep is zero.
14+
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
15+
16+
# Scale so the first timestep is back to the old value.
17+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
18+
19+
# Convert alphas_bar_sqrt to betas
20+
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
21+
alphas_bar[-1] = 4.8973451890853435e-08
22+
return ((1 - alphas_bar) / alphas_bar) ** 0.5
23+
524
class EPS:
625
def calculate_input(self, sigma, noise):
726
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
@@ -48,7 +67,7 @@ def inverse_noise_scaling(self, sigma, latent):
4867
return latent / (1.0 - sigma)
4968

5069
class ModelSamplingDiscrete(torch.nn.Module):
51-
def __init__(self, model_config=None):
70+
def __init__(self, model_config=None, zsnr=None):
5271
super().__init__()
5372

5473
if model_config is not None:
@@ -61,11 +80,14 @@ def __init__(self, model_config=None):
6180
linear_end = sampling_settings.get("linear_end", 0.012)
6281
timesteps = sampling_settings.get("timesteps", 1000)
6382

64-
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
83+
if zsnr is None:
84+
zsnr = sampling_settings.get("zsnr", False)
85+
86+
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3, zsnr=zsnr)
6587
self.sigma_data = 1.0
6688

6789
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
68-
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
90+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zsnr=False):
6991
if given_betas is not None:
7092
betas = given_betas
7193
else:
@@ -83,6 +105,9 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
83105
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
84106

85107
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
108+
if zsnr:
109+
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
110+
86111
self.set_sigmas(sigmas)
87112

88113
def set_sigmas(self, sigmas):

comfy_extras/nodes_model_advanced.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,25 +51,6 @@ def sigma(self, timestep):
5151
return log_sigma.exp().to(timestep.device)
5252

5353

54-
def rescale_zero_terminal_snr_sigmas(sigmas):
55-
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
56-
alphas_bar_sqrt = alphas_cumprod.sqrt()
57-
58-
# Store old values.
59-
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
60-
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
61-
62-
# Shift so the last timestep is zero.
63-
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
64-
65-
# Scale so the first timestep is back to the old value.
66-
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
67-
68-
# Convert alphas_bar_sqrt to betas
69-
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
70-
alphas_bar[-1] = 4.8973451890853435e-08
71-
return ((1 - alphas_bar) / alphas_bar) ** 0.5
72-
7354
class ModelSamplingDiscrete:
7455
@classmethod
7556
def INPUT_TYPES(s):
@@ -100,9 +81,7 @@ def patch(self, model, sampling, zsnr):
10081
class ModelSamplingAdvanced(sampling_base, sampling_type):
10182
pass
10283

103-
model_sampling = ModelSamplingAdvanced(model.model.model_config)
104-
if zsnr:
105-
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
84+
model_sampling = ModelSamplingAdvanced(model.model.model_config, zsnr=zsnr)
10685

10786
m.add_object_patch("model_sampling", model_sampling)
10887
return (m, )

0 commit comments

Comments
 (0)