Skip to content

Commit 825498b

Browse files
feat: added ancestral euler sampler AEulerSampler
1 parent 786c4fc commit 825498b

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .diffusion import (
22
ADPM2Sampler,
3+
AEulerSampler,
34
Diffusion,
45
DiffusionInpainter,
56
DiffusionSampler,

audio_diffusion_pytorch/diffusion.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def forward(self, num_steps: int, device: Any) -> Tensor:
6262

6363
""" Samplers """
6464

65+
""" Many methods inspired by https://github.com/crowsonkb/k-diffusion/ """
66+
6567

6668
class Sampler(nn.Module):
6769
def forward(
@@ -136,9 +138,35 @@ def forward(
136138
return x
137139

138140

139-
class ADPM2Sampler(Sampler):
140-
"""https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py"""
141+
class AEulerSampler(Sampler):
142+
def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
143+
sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
144+
sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
145+
return sigma_up, sigma_down
146+
147+
def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
148+
# Sigma steps
149+
sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
150+
# Derivative at sigma (∂x/∂sigma)
151+
d = (x - fn(x, sigma=sigma)) / sigma
152+
# Euler method
153+
x_next = x + d * (sigma_down - sigma)
154+
# Add randomness
155+
x_next = x_next + torch.randn_like(x) * sigma_up
156+
print(sigma_up)
157+
return x_next
141158

159+
def forward(
160+
self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
161+
) -> Tensor:
162+
x = sigmas[0] * noise
163+
# Denoise to sample
164+
for i in range(num_steps - 1):
165+
x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
166+
return x
167+
168+
169+
class ADPM2Sampler(Sampler):
142170
"""https://www.desmos.com/calculator/jbxjlqd9mb"""
143171

144172
def __init__(self, rho: float = 1.0):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.19",
6+
version="0.0.20",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)