22from comfy .ldm .modules .diffusionmodules .util import make_beta_schedule
33import math
44
5+ def rescale_zero_terminal_snr_sigmas (sigmas ):
6+ alphas_cumprod = 1 / ((sigmas * sigmas ) + 1 )
7+ alphas_bar_sqrt = alphas_cumprod .sqrt ()
8+
9+ # Store old values.
10+ alphas_bar_sqrt_0 = alphas_bar_sqrt [0 ].clone ()
11+ alphas_bar_sqrt_T = alphas_bar_sqrt [- 1 ].clone ()
12+
13+ # Shift so the last timestep is zero.
14+ alphas_bar_sqrt -= (alphas_bar_sqrt_T )
15+
16+ # Scale so the first timestep is back to the old value.
17+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T )
18+
19+ # Convert alphas_bar_sqrt to betas
20+ alphas_bar = alphas_bar_sqrt ** 2 # Revert sqrt
21+ alphas_bar [- 1 ] = 4.8973451890853435e-08
22+ return ((1 - alphas_bar ) / alphas_bar ) ** 0.5
23+
524class EPS :
625 def calculate_input (self , sigma , noise ):
726 sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (noise .ndim - 1 ))
@@ -48,7 +67,7 @@ def inverse_noise_scaling(self, sigma, latent):
4867 return latent / (1.0 - sigma )
4968
5069class ModelSamplingDiscrete (torch .nn .Module ):
51- def __init__ (self , model_config = None ):
70+ def __init__ (self , model_config = None , zsnr = None ):
5271 super ().__init__ ()
5372
5473 if model_config is not None :
@@ -61,11 +80,14 @@ def __init__(self, model_config=None):
6180 linear_end = sampling_settings .get ("linear_end" , 0.012 )
6281 timesteps = sampling_settings .get ("timesteps" , 1000 )
6382
64- self ._register_schedule (given_betas = None , beta_schedule = beta_schedule , timesteps = timesteps , linear_start = linear_start , linear_end = linear_end , cosine_s = 8e-3 )
83+ if zsnr is None :
84+ zsnr = sampling_settings .get ("zsnr" , False )
85+
86+ self ._register_schedule (given_betas = None , beta_schedule = beta_schedule , timesteps = timesteps , linear_start = linear_start , linear_end = linear_end , cosine_s = 8e-3 , zsnr = zsnr )
6587 self .sigma_data = 1.0
6688
6789 def _register_schedule (self , given_betas = None , beta_schedule = "linear" , timesteps = 1000 ,
68- linear_start = 1e-4 , linear_end = 2e-2 , cosine_s = 8e-3 ):
90+ linear_start = 1e-4 , linear_end = 2e-2 , cosine_s = 8e-3 , zsnr = False ):
6991 if given_betas is not None :
7092 betas = given_betas
7193 else :
@@ -83,6 +105,9 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
83105 # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
84106
85107 sigmas = ((1 - alphas_cumprod ) / alphas_cumprod ) ** 0.5
108+ if zsnr :
109+ sigmas = rescale_zero_terminal_snr_sigmas (sigmas )
110+
86111 self .set_sigmas (sigmas )
87112
88113 def set_sigmas (self , sigmas ):
0 commit comments