-from math import atan, pi, sqrt
-from typing import Any, Callable, Optional, Tuple
+from math import atan, cos, pi, sin, sqrt
+from typing import Any, Callable, List, Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -33,7 +33,12 @@ def __call__(
         return normal.exp()
 
 
-class VDistribution(Distribution):
+class UniformDistribution(Distribution):
+    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
+        return torch.rand(num_samples, device=device)
+
+
+class VKDistribution(Distribution):
     def __init__(
         self,
         min_value: float = 0.0,
@@ -94,6 +99,8 @@ def to_batch(
 
 class Diffusion(nn.Module):
 
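+    # Alias used by samplers to check compatibility with this diffusion type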
+    alias: str = ""
+
     """Base diffusion class"""
 
     def denoise_fn(
@@ -110,24 +117,19 @@ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
 
 
 class VDiffusion(Diffusion):
+
+    alias = "v"
+
     def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
         super().__init__()
         self.net = net
         self.sigma_distribution = sigma_distribution
 
-    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
-        sigma_data = 1.0
-        sigmas = rearrange(sigmas, "b -> b 1 1")
-        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
-        c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
-        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
-        return c_skip, c_out, c_in
-
-    def sigma_to_t(self, sigmas: Tensor) -> Tensor:
-        return sigmas.atan() / pi * 2
-
-    def t_to_sigma(self, t: Tensor) -> Tensor:
-        return (t * pi / 2).tan()
+    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
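+        # Map sigma in [0, 1] to a quarter-circle angle, so alpha ** 2 + beta ** 2 == 1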
+        angle = sigmas * pi / 2
+        alpha = torch.cos(angle)
+        beta = torch.sin(angle)
+        return alpha, beta
 
     def denoise_fn(
         self,
@@ -138,12 +140,7 @@ def denoise_fn(
     ) -> Tensor:
         batch_size, device = x_noisy.shape[0], x_noisy.device
         sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
-
-        # Predict network output and add skip connection
-        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
-        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
-        x_denoised = c_skip * x_noisy + c_out * x_pred
-        return x_denoised
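+        # The network is trained to predict the v-target directly, so no preconditioning is applied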
+        return self.net(x_noisy, sigmas, **kwargs)
 
     def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
         batch_size, device = x.shape[0], x.device
@@ -152,25 +149,24 @@ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
         sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
         sigmas_padded = rearrange(sigmas, "b -> b 1 1")
 
-        # Add noise to input
+        # Get noise
         noise = default(noise, lambda: torch.randn_like(x))
-        x_noisy = x + sigmas_padded * noise
 
-        # Compute model output
-        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
-        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
+        # Combine input and noise weighted by half-circle
+        alpha, beta = self.get_alpha_beta(sigmas_padded)
+        x_noisy = x * alpha + noise * beta
+        x_target = noise * alpha - x * beta
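+        # x_target is the v-objective (Salimans & Ho 2022): v = alpha * noise - beta * x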
 
-        # Compute v-objective target
-        v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
-
-        # Compute loss
-        loss = F.mse_loss(x_pred, v_target)
-        return loss
+        # Denoise and return loss
+        x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
+        return F.mse_loss(x_denoised, x_target)
 
 
 class KDiffusion(Diffusion):
     """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
 
+    alias = "k"
+
     def __init__(
         self,
         net: nn.Module,
@@ -235,7 +231,68 @@ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
         losses = reduce(losses, "b ... -> b", "mean")
         losses = losses * self.loss_weight(sigmas)
         loss = losses.mean()
+        return loss
+
+
+class VKDiffusion(Diffusion):
+
+    alias = "vk"
+
+    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
+        super().__init__()
+        self.net = net
+        self.sigma_distribution = sigma_distribution
+
+    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
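+        # Preconditioning weights from Karras et al. 2022 (Table 1), with sigma_data = 1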
+        sigma_data = 1.0
+        sigmas = rearrange(sigmas, "b -> b 1 1")
+        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
+        c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
+        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
+        return c_skip, c_out, c_in
 
+    def sigma_to_t(self, sigmas: Tensor) -> Tensor:
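+        # atan maps sigma in [0, inf) to t in [0, 1); t_to_sigma inverts this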
+        return sigmas.atan() / pi * 2
+
+    def t_to_sigma(self, t: Tensor) -> Tensor:
+        return (t * pi / 2).tan()
+
+    def denoise_fn(
+        self,
+        x_noisy: Tensor,
+        sigmas: Optional[Tensor] = None,
+        sigma: Optional[float] = None,
+        **kwargs,
+    ) -> Tensor:
+        batch_size, device = x_noisy.shape[0], x_noisy.device
+        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
+
+        # Predict network output and add skip connection
+        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
+        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
+        x_denoised = c_skip * x_noisy + c_out * x_pred
+        return x_denoised
+
+    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
+        batch_size, device = x.shape[0], x.device
+
+        # Sample amount of noise to add for each batch element
+        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
+        sigmas_padded = rearrange(sigmas, "b -> b 1 1")
+
+        # Add noise to input
+        noise = default(noise, lambda: torch.randn_like(x))
+        x_noisy = x + sigmas_padded * noise
+
+        # Compute model output
+        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
+        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
+
+        # Compute v-objective target
+        v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
+
+        # Compute loss
+        loss = F.mse_loss(x_pred, v_target)
         return loss
 
 
@@ -253,6 +310,12 @@ def forward(self, num_steps: int, device: torch.device) -> Tensor:
         raise NotImplementedError()
 
 
+class LinearSchedule(Schedule):
+    def forward(self, num_steps: int, device: Any) -> Tensor:
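+        # Evenly spaced sigmas from 1 to 0, excluding the final zero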
+        sigmas = torch.linspace(1, 0, num_steps + 1, device=device)[:-1]
+        return sigmas
+
+
 class KarrasSchedule(Schedule):
     """https://arxiv.org/abs/2206.00364 equation 5"""
 
@@ -278,6 +341,9 @@ def forward(self, num_steps: int, device: Any) -> Tensor:
 
 
 class Sampler(nn.Module):
+
+    diffusion_types: List[Type[Diffusion]] = []
+
     def forward(
         self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
     ) -> Tensor:
@@ -295,9 +361,41 @@ def inpaint(
         raise NotImplementedError("Inpainting not available with current sampler")
 
 
+class VSampler(Sampler):
+
+    diffusion_types = [VDiffusion]
+
+    def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
+        angle = sigma * pi / 2
+        alpha = cos(angle)
+        beta = sin(angle)
+        return alpha, beta
+
+    def forward(
+        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
+    ) -> Tensor:
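+        # Start from noise scaled by the initial (largest) sigma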
+        x = sigmas[0] * noise
+        alpha, beta = self.get_alpha_beta(sigmas[0].item())
+
+        for i in range(num_steps - 1):
+            is_last = i == num_steps - 1
+
+            x_denoised = fn(x, sigma=sigmas[i])
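+            # Recover the signal (x_pred) and noise (x_eps) estimates from the v-prediction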
+            x_pred = x * alpha - x_denoised * beta
+            x_eps = x * beta + x_denoised * alpha
+
+            if not is_last:
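+                # Re-noise the signal estimate toward the next sigma level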
+                alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
+                x = x_pred * alpha + x_eps * beta
+
+        return x
+
+
 class KarrasSampler(Sampler):
     """https://arxiv.org/abs/2206.00364 algorithm 1"""
 
+    diffusion_types = [KDiffusion, VKDiffusion]
+
     def __init__(
         self,
         s_tmin: float = 0,
@@ -351,6 +449,9 @@ def forward(
 
 
 class AEulerSampler(Sampler):
+
+    diffusion_types = [KDiffusion, VKDiffusion]
+
     def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
         sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
         sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
@@ -380,6 +481,8 @@ def forward(
 class ADPM2Sampler(Sampler):
     """https://www.desmos.com/calculator/jbxjlqd9mb"""
 
+    diffusion_types = [KDiffusion, VKDiffusion]
+
     def __init__(self, rho: float = 1.0):
         super().__init__()
         self.rho = rho
@@ -459,6 +562,12 @@ def __init__(
         self.sigma_schedule = sigma_schedule
         self.num_steps = num_steps
 
+        # Check sampler is compatible with diffusion type
+        sampler_class = sampler.__class__.__name__
+        diffusion_class = diffusion.__class__.__name__
+        message = f"{sampler_class} incompatible with {diffusion_class}"
+        assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
+
     @torch.no_grad()
     def forward(
         self, noise: Tensor, num_steps: Optional[int] = None, **kwargs