@@ -62,6 +62,8 @@ def forward(self, num_steps: int, device: Any) -> Tensor:
6262
6363""" Samplers """
6464
65+ """ Many methods inspired by https://github.com/crowsonkb/k-diffusion/ """
66+
6567
6668class Sampler (nn .Module ):
6769 def forward (
@@ -136,9 +138,35 @@ def forward(
136138 return x
137139
138140
139- class ADPM2Sampler (Sampler ):
140- """https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py"""
141+ class AEulerSampler (Sampler ):
142+ def get_sigmas (self , sigma : float , sigma_next : float ) -> Tuple [float , float ]:
143+ sigma_up = sqrt (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2 ) / sigma ** 2 )
144+ sigma_down = sqrt (sigma_next ** 2 - sigma_up ** 2 )
145+ return sigma_up , sigma_down
146+
147+ def step (self , x : Tensor , fn : Callable , sigma : float , sigma_next : float ) -> Tensor :
148+ # Sigma steps
149+ sigma_up , sigma_down = self .get_sigmas (sigma , sigma_next )
150+ # Derivative at sigma (∂x/∂sigma)
151+ d = (x - fn (x , sigma = sigma )) / sigma
152+ # Euler method
153+ x_next = x + d * (sigma_down - sigma )
154+ # Add randomness
155+ x_next = x_next + torch .randn_like (x ) * sigma_up
156+ print (sigma_up )
157+ return x_next
141158
159+ def forward (
160+ self , noise : Tensor , fn : Callable , sigmas : Tensor , num_steps : int
161+ ) -> Tensor :
162+ x = sigmas [0 ] * noise
163+ # Denoise to sample
164+ for i in range (num_steps - 1 ):
165+ x = self .step (x , fn = fn , sigma = sigmas [i ], sigma_next = sigmas [i + 1 ]) # type: ignore # noqa
166+ return x
167+
168+
169+ class ADPM2Sampler (Sampler ):
142170 """https://www.desmos.com/calculator/jbxjlqd9mb"""
143171
144172 def __init__ (self , rho : float = 1.0 ):
0 commit comments