11from math import sqrt
2- from typing import Any , Optional
2+ from typing import Any , Callable , Optional
33
44import torch
55import torch .nn as nn
99
1010from .utils import default , exists
1111
12- """ Samplers and sigma schedules """
12+ """ Distributions """
1313
1414
15- class SigmaSampler :
15+ class Distribution :
1616 def __call__ (self , num_samples : int , device : torch .device ):
1717 raise NotImplementedError ()
1818
1919
20- class LogNormalSampler ( SigmaSampler ):
20+ class LogNormalDistribution ( Distribution ):
2121 def __init__ (self , mean : float , std : float ):
2222 self .mean = mean
2323 self .std = std
@@ -29,15 +29,18 @@ def __call__(
2929 return normal .exp ()
3030
3131
32- class SigmaSchedule (nn .Module ):
33- """Interface used by different sampling sigma schedules"""
32+ """ Schedules """
33+
34+
35+ class Schedule (nn .Module ):
36+ """Interface used by different schedules"""
3437
3538 def forward (self , num_steps : int , device : torch .device ) -> Tensor :
3639 raise NotImplementedError ()
3740
3841
39- class KerrasSchedule ( SigmaSchedule ):
40- """https://arxiv.org/abs/2206.00364 eq. (5) """
42+ class KarrasSchedule ( Schedule ):
43+ """https://arxiv.org/abs/2206.00364 equation 5 """
4144
4245 def __init__ (self , sigma_min : float , sigma_max : float , rho : float = 7.0 ):
4346 super ().__init__ ()
@@ -57,21 +60,139 @@ def forward(self, num_steps: int, device: Any) -> Tensor:
5760 return sigmas
5861
5962
63+ """ Samplers """
64+
65+
66+ class Sampler (nn .Module ):
67+ def forward (
68+ self , noise : Tensor , fn : Callable , sigmas : Tensor , num_steps : int
69+ ) -> Tensor :
70+ raise NotImplementedError ()
71+
72+
73+ class KarrasSampler (Sampler ):
74+ """https://arxiv.org/abs/2206.00364 algorithm 1"""
75+
76+ def __init__ (
77+ self ,
78+ s_tmin : float = 0 ,
79+ s_tmax : float = float ("inf" ),
80+ s_churn : float = 0.0 ,
81+ s_noise : float = 1.0 ,
82+ ):
83+ super ().__init__ ()
84+ self .s_tmin = s_tmin
85+ self .s_tmax = s_tmax
86+ self .s_noise = s_noise
87+ self .s_churn = s_churn
88+
89+ def step (
90+ self ,
91+ x : Tensor ,
92+ fn : Callable ,
93+ sigma : float ,
94+ sigma_next : float ,
95+ gamma : float ,
96+ clamp : bool = True ,
97+ ) -> Tensor :
98+ """Algorithm 2 (step)"""
99+ # Select temporarily increased noise level
100+ sigma_hat = sigma + gamma * sigma
101+ # Add noise to move from sigma to sigma_hat
102+ epsilon = self .s_noise * torch .randn_like (x )
103+ x_hat = x + sqrt (sigma_hat ** 2 - sigma ** 2 ) * epsilon
104+ # Evaluate ∂x/∂sigma at sigma_hat
105+ d = (x_hat - fn (x_hat , sigma = sigma_hat , clamp = clamp )) / sigma_hat
106+ # Take euler step from sigma_hat to sigma_next
107+ x_next = x_hat + (sigma_next - sigma_hat ) * d
108+ # Second order correction
109+ if sigma_next != 0 :
110+ model_out_next = fn (x_next , sigma = sigma_next , clamp = clamp )
111+ d_prime = (x_next - model_out_next ) / sigma_next
112+ x_next = x_hat + 0.5 * (sigma - sigma_hat ) * (d + d_prime )
113+ return x_next
114+
115+ def forward (
116+ self , noise : Tensor , fn : Callable , sigmas : Tensor , num_steps : int
117+ ) -> Tensor :
118+ x = sigmas [0 ] * noise
119+ # Compute gammas
120+ gammas = torch .where (
121+ (sigmas >= self .s_tmin ) & (sigmas <= self .s_tmax ),
122+ min (self .s_churn / num_steps , sqrt (2 ) - 1 ),
123+ 0.0 ,
124+ )
125+ # Denoise to sample
126+ for i in range (num_steps - 1 ):
127+ x = self .step (
128+ x , fn = fn , sigma = sigmas [i ], sigma_next = sigmas [i + 1 ], gamma = gammas [i ] # type: ignore # noqa
129+ )
130+
131+ return x
132+
133+
134+ class ADPM2Sampler (Sampler ):
135+ """https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py"""
136+
137+ """ https://www.desmos.com/calculator/jbxjlqd9mb """
138+
139+ def __init__ (self , rho : float = 1.0 ):
140+ super ().__init__ ()
141+ self .rho = rho
142+
143+ def step (
144+ self ,
145+ x : Tensor ,
146+ fn : Callable ,
147+ sigma : float ,
148+ sigma_next : float ,
149+ clamp : bool = True ,
150+ ) -> Tensor :
151+ # Sigma steps
152+ r = self .rho
153+ sigma_up = sqrt (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2 ) / sigma ** 2 )
154+ sigma_down = sqrt (sigma_next ** 2 - sigma_up ** 2 )
155+ sigma_mid = ((sigma ** (1 / r ) + sigma_down ** (1 / r )) / 2 ) ** r
156+ # Derivative at sigma (∂x/∂sigma)
157+ d = (x - fn (x , sigma = sigma , clamp = clamp )) / sigma
158+ # Denoise to midpoint
159+ x_mid = x + d * (sigma_mid - sigma )
160+ # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
161+ d_mid = (x_mid - fn (x_mid , sigma = sigma_mid , clamp = clamp )) / sigma_mid
162+ # Denoise to next
163+ x = x + d_mid * (sigma_down - sigma )
164+ # Add randomness
165+ x_next = x + torch .randn_like (x ) * sigma_up
166+ return x_next
167+
168+ def forward (
169+ self , noise : Tensor , fn : Callable , sigmas : Tensor , num_steps : int
170+ ) -> Tensor :
171+ x = sigmas [0 ] * noise
172+ # Denoise to sample
173+ for i in range (num_steps - 1 ):
174+ x = self .step (x , fn = fn , sigma = sigmas [i ], sigma_next = sigmas [i + 1 ]) # type: ignore # noqa
175+ return x
176+
177+
178+ """ Diffusion Classes """
179+
180+
60181class Diffusion (nn .Module ):
61182 """Elucidated Diffusion: https://arxiv.org/abs/2206.00364"""
62183
63184 def __init__ (
64185 self ,
65186 net : nn .Module ,
66187 * ,
67- sigma_sampler : SigmaSampler ,
188+ sigma_distribution : Distribution ,
68189 sigma_data : float , # data distribution standard deviation
69190 ):
70191 super ().__init__ ()
71192
72193 self .net = net
73194 self .sigma_data = sigma_data
74- self .sigma_sampler = sigma_sampler
195+ self .sigma_distribution = sigma_distribution
75196
76197 def c_skip (self , sigmas : Tensor ) -> Tensor :
77198 return (self .sigma_data ** 2 ) / (sigmas ** 2 + self .sigma_data ** 2 )
@@ -121,7 +242,7 @@ def forward(self, x: Tensor, noise: Tensor = None) -> Tensor:
121242 batch , device = x .shape [0 ], x .device
122243
123244 # Sample amount of noise to add for each batch element
124- sigmas = self .sigma_sampler (num_samples = batch , device = device )
245+ sigmas = self .sigma_distribution (num_samples = batch , device = device )
125246 sigmas_padded = rearrange (sigmas , "b -> b 1 1" )
126247
127248 # Add noise to input
@@ -145,65 +266,25 @@ def __init__(
145266 self ,
146267 diffusion : Diffusion ,
147268 * ,
148- num_steps : int ,
149- sigma_schedule : SigmaSchedule ,
150- s_tmin : float = 0 ,
151- s_tmax : float = float ("inf" ),
152- s_churn : float = 0.0 ,
153- s_noise : float = 1.0 ,
269+ sampler : Sampler ,
270+ sigma_schedule : Schedule ,
271+ num_steps : Optional [int ] = None ,
154272 ):
155273 super ().__init__ ()
156274 self .denoise_fn = diffusion .denoise_fn
157- self .num_steps = num_steps
275+ self .sampler = sampler
158276 self .sigma_schedule = sigma_schedule
159- self .s_tmin = s_tmin
160- self .s_tmax = s_tmax
161- self .s_noise = s_noise
162- self .s_churn = s_churn
163-
164- def step (
165- self ,
166- x : Tensor ,
167- sigma : float ,
168- sigma_next : float ,
169- gamma : float ,
170- clamp : bool = True ,
171- ) -> Tensor :
172- """Algorithm 2 (step)"""
173- # Select temporarily increased noise level
174- sigma_hat = sigma + gamma * sigma
175- # Add noise to move from sigma to sigma_hat
176- epsilon = self .s_noise * torch .randn_like (x )
177- x_hat = x + sqrt (sigma_hat ** 2 - sigma ** 2 ) * epsilon
178- # Evaluate ∂x/∂sigma at sigma_hat
179- d = (x_hat - self .denoise_fn (x_hat , sigma = sigma_hat , clamp = clamp )) / sigma_hat
180- # Take euler step from sigma_hat to sigma_next
181- x_next = x_hat + (sigma_next - sigma_hat ) * d
182- # Second order correction
183- if sigma_next != 0 :
184- model_out_next = self .denoise_fn (x_next , sigma = sigma_next , clamp = clamp )
185- d_prime = (x_next - model_out_next ) / sigma_next
186- x_next = x_hat + 0.5 * (sigma - sigma_hat ) * (d + d_prime )
187- return x_next
277+ self .num_steps = num_steps
188278
189279 @torch .no_grad ()
190- def forward (self , x : Tensor , num_steps : int = None ) -> Tensor :
191- device = x .device
192- num_steps = default (num_steps , self .num_steps )
280+ def forward (self , noise : Tensor , num_steps : Optional [int ] = None ) -> Tensor :
281+ device = noise .device
282+ num_steps = default (num_steps , self .num_steps ) # type: ignore
283+ assert exists (num_steps ), "Parameter `num_steps` must be provided"
193284 # Compute sigmas using schedule
194285 sigmas = self .sigma_schedule (num_steps , device )
195- # Sample from first sigma distribution
196- x = sigmas [0 ] * x
197- # Compute gammas
198- gammas = torch .where (
199- (sigmas >= self .s_tmin ) & (sigmas <= self .s_tmax ),
200- min (self .s_churn / num_steps , sqrt (2 ) - 1 ),
201- 0.0 ,
202- )
203- # Denoise x
204- for i in range (num_steps - 1 ):
205- x = self .step (x , sigma = sigmas [i ], sigma_next = sigmas [i + 1 ], gamma = gammas [i ]) # type: ignore # noqa
206-
286+ # Sample using sampler
287+ x = self .sampler (noise , fn = self .denoise_fn , sigmas = sigmas , num_steps = num_steps )
207288 x = x .clamp (- 1.0 , 1.0 )
208289 return x
209290
@@ -217,7 +298,7 @@ def __init__(
217298 * ,
218299 num_steps : int ,
219300 num_resamples : int ,
220- sigma_schedule : SigmaSchedule ,
301+ sigma_schedule : Schedule ,
221302 s_tmin : float = 0 ,
222303 s_tmax : float = float ("inf" ),
223304 s_churn : float = 0.0 ,
0 commit comments