11import torch
2- import numpy as np
32from ldm_patched .ldm .modules .diffusionmodules .util import make_beta_schedule
43import math
54
@@ -12,12 +11,28 @@ def calculate_denoised(self, sigma, model_output, model_input):
1211 sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (model_output .ndim - 1 ))
1312 return model_input - model_output * sigma
1413
14+ def noise_scaling (self , sigma , noise , latent_image , max_denoise = False ):
15+ if max_denoise :
16+ noise = noise * torch .sqrt (1.0 + sigma ** 2.0 )
17+ else :
18+ noise = noise * sigma
19+
20+ noise += latent_image
21+ return noise
22+
23+ def inverse_noise_scaling (self , sigma , latent ):
24+ return latent
1525
1626class V_PREDICTION (EPS ):
1727 def calculate_denoised (self , sigma , model_output , model_input ):
1828 sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (model_output .ndim - 1 ))
1929 return model_input * self .sigma_data ** 2 / (sigma ** 2 + self .sigma_data ** 2 ) - model_output * sigma * self .sigma_data / (sigma ** 2 + self .sigma_data ** 2 ) ** 0.5
2030
31+ class EDM (V_PREDICTION ):
32+ def calculate_denoised (self , sigma , model_output , model_input ):
33+ sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (model_output .ndim - 1 ))
34+ return model_input * self .sigma_data ** 2 / (sigma ** 2 + self .sigma_data ** 2 ) + model_output * sigma * self .sigma_data / (sigma ** 2 + self .sigma_data ** 2 ) ** 0.5
35+
2136
2237class ModelSamplingDiscrete (torch .nn .Module ):
2338 def __init__ (self , model_config = None ):
@@ -42,24 +57,23 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
4257 else :
4358 betas = make_beta_schedule (beta_schedule , timesteps , linear_start = linear_start , linear_end = linear_end , cosine_s = cosine_s )
4459 alphas = 1. - betas
45- alphas_cumprod = torch .tensor (np .cumprod (alphas , axis = 0 ), dtype = torch .float32 )
46- # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
60+ alphas_cumprod = torch .cumprod (alphas , dim = 0 )
4761
4862 timesteps , = betas .shape
4963 self .num_timesteps = int (timesteps )
5064 self .linear_start = linear_start
5165 self .linear_end = linear_end
5266
67+ # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
68+ # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
69+ # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
70+
5371 sigmas = ((1 - alphas_cumprod ) / alphas_cumprod ) ** 0.5
5472 self .set_sigmas (sigmas )
55- self .set_alphas_cumprod (alphas_cumprod .float ())
5673
5774 def set_sigmas (self , sigmas ):
58- self .register_buffer ('sigmas' , sigmas )
59- self .register_buffer ('log_sigmas' , sigmas .log ())
60-
61- def set_alphas_cumprod (self , alphas_cumprod ):
62- self .register_buffer ("alphas_cumprod" , alphas_cumprod .float ())
75+ self .register_buffer ('sigmas' , sigmas .float ())
76+ self .register_buffer ('log_sigmas' , sigmas .log ().float ())
6377
6478 @property
6579 def sigma_min (self ):
@@ -94,18 +108,18 @@ def percent_to_sigma(self, percent):
94108class ModelSamplingContinuousEDM (torch .nn .Module ):
95109 def __init__ (self , model_config = None ):
96110 super ().__init__ ()
97- self .sigma_data = 1.0
98-
99111 if model_config is not None :
100112 sampling_settings = model_config .sampling_settings
101113 else :
102114 sampling_settings = {}
103115
104116 sigma_min = sampling_settings .get ("sigma_min" , 0.002 )
105117 sigma_max = sampling_settings .get ("sigma_max" , 120.0 )
106- self .set_sigma_range (sigma_min , sigma_max )
118+ sigma_data = sampling_settings .get ("sigma_data" , 1.0 )
119+ self .set_parameters (sigma_min , sigma_max , sigma_data )
107120
108- def set_sigma_range (self , sigma_min , sigma_max ):
121+ def set_parameters (self , sigma_min , sigma_max , sigma_data ):
122+ self .sigma_data = sigma_data
109123 sigmas = torch .linspace (math .log (sigma_min ), math .log (sigma_max ), 1000 ).exp ()
110124
111125 self .register_buffer ('sigmas' , sigmas ) #for compatibility with some schedulers
@@ -134,3 +148,56 @@ def percent_to_sigma(self, percent):
134148
135149 log_sigma_min = math .log (self .sigma_min )
136150 return math .exp ((math .log (self .sigma_max ) - log_sigma_min ) * percent + log_sigma_min )
151+
152+ class StableCascadeSampling (ModelSamplingDiscrete ):
153+ def __init__ (self , model_config = None ):
154+ super ().__init__ ()
155+
156+ if model_config is not None :
157+ sampling_settings = model_config .sampling_settings
158+ else :
159+ sampling_settings = {}
160+
161+ self .set_parameters (sampling_settings .get ("shift" , 1.0 ))
162+
163+ def set_parameters (self , shift = 1.0 , cosine_s = 8e-3 ):
164+ self .shift = shift
165+ self .cosine_s = torch .tensor (cosine_s )
166+ self ._init_alpha_cumprod = torch .cos (self .cosine_s / (1 + self .cosine_s ) * torch .pi * 0.5 ) ** 2
167+
168+ #This part is just for compatibility with some schedulers in the codebase
169+ self .num_timesteps = 10000
170+ sigmas = torch .empty ((self .num_timesteps ), dtype = torch .float32 )
171+ for x in range (self .num_timesteps ):
172+ t = (x + 1 ) / self .num_timesteps
173+ sigmas [x ] = self .sigma (t )
174+
175+ self .set_sigmas (sigmas )
176+
177+ def sigma (self , timestep ):
178+ alpha_cumprod = (torch .cos ((timestep + self .cosine_s ) / (1 + self .cosine_s ) * torch .pi * 0.5 ) ** 2 / self ._init_alpha_cumprod )
179+
180+ if self .shift != 1.0 :
181+ var = alpha_cumprod
182+ logSNR = (var / (1 - var )).log ()
183+ logSNR += 2 * torch .log (1.0 / torch .tensor (self .shift ))
184+ alpha_cumprod = logSNR .sigmoid ()
185+
186+ alpha_cumprod = alpha_cumprod .clamp (0.0001 , 0.9999 )
187+ return ((1 - alpha_cumprod ) / alpha_cumprod ) ** 0.5
188+
189+ def timestep (self , sigma ):
190+ var = 1 / ((sigma * sigma ) + 1 )
191+ var = var .clamp (0 , 1.0 )
192+ s , min_var = self .cosine_s .to (var .device ), self ._init_alpha_cumprod .to (var .device )
193+ t = (((var * min_var ) ** 0.5 ).acos () / (torch .pi * 0.5 )) * (1 + s ) - s
194+ return t
195+
196+ def percent_to_sigma (self , percent ):
197+ if percent <= 0.0 :
198+ return 999999999.9
199+ if percent >= 1.0 :
200+ return 0.0
201+
202+ percent = 1.0 - percent
203+ return self .sigma (torch .tensor (percent ))
0 commit comments