
Commit df197bf

feat: add new sampler, refactor diffusion, v0.0.11
1 parent ebc9021 commit df197bf

6 files changed: +207 −113 lines changed


README.md

Lines changed: 29 additions & 11 deletions
@@ -14,6 +14,25 @@ pip install audio-diffusion-pytorch
 
 ## Usage
 
+```py
+import torch
+from audio_diffusion_pytorch import AudioDiffusionModel
+
+model = AudioDiffusionModel()
+
+# Train model with audio sources [batch, channels, samples]
+x = torch.randn(2, 1, 2 ** 18)
+loss = model(x)
+loss.backward()
+
+# Sample given start noise
+noise = torch.randn(2, 1, 2 ** 18)
+sampled = model.sample(
+    noise=noise,
+    num_steps=5 # Range 1-100
+) # [2, 1, 2**18]
+```
+
+## Usage with Components
 
 ### UNet1d
 ```py
@@ -50,15 +69,15 @@ y = unet(x, t) # [2, 1, 32768], 2 samples of ~1.5 seconds of generated audio at
 
 #### Training
 ```python
-from audio_diffusion_pytorch import Diffusion, LogNormalSampler
+from audio_diffusion_pytorch import Diffusion, LogNormalDistribution
 
 diffusion = Diffusion(
     net=unet,
-    sigma_sampler=LogNormalSampler(mean=-3.0, std=1.0),
+    sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
     sigma_data=0.1
 )
 
-x = torch.randn(3, 1, 2 ** 16) # Batch of training audio samples
+x = torch.randn(3, 1, 2 ** 18) # Batch of training audio samples
 loss = diffusion(x)
 loss.backward() # Do this many times
 ```
@@ -69,22 +88,21 @@ from audio_diffusion_pytorch import DiffusionSampler, KerrasSchedule
 
 sampler = DiffusionSampler(
     diffusion,
-    num_steps=50, # Range 32-1000, higher for better quality
-    sigma_schedule=KerrasSchedule(
+    num_steps=5, # Range 1-100, higher for better quality but slower
+    sampler=ADPM2Sampler(rho=1),
+    sigma_schedule=KarrasSchedule(
         sigma_min=0.002,
         sigma_max=1
-    ),
-    s_tmin=0,
-    s_tmax=10,
-    s_churn=40,
-    s_noise=1.003
+    )
 )
 # Generate a sample starting from the provided noise
-y = sampler(x=torch.randn(1, 1, 2 ** 15))
+y = sampler(noise=torch.randn(1, 1, 2 ** 18))
 ```
 
 #### Inpainting
 
+Note: this example still uses the old API and needs to be updated.
+
 ```py
 from audio_diffusion_pytorch import DiffusionInpainter, KerrasSchedule
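The churn-style arguments removed from `DiffusionSampler` above now live on the new `KarrasSampler`. As a minimal sketch (not part of this commit's README), here is how it could be swapped in for `ADPM2Sampler`, reusing the parameter values from the old README snippet; the `num_steps` value is an illustrative assumption:

```py
from audio_diffusion_pytorch import DiffusionSampler, KarrasSampler, KarrasSchedule

# Stochastic sampler from https://arxiv.org/abs/2206.00364 (algorithm 2),
# taking over the s_* parameters that DiffusionSampler used to own
sampler = DiffusionSampler(
    diffusion,
    num_steps=50,  # Assumed; this sampler was previously documented with 32-1000 steps
    sampler=KarrasSampler(s_tmin=0, s_tmax=10, s_churn=40, s_noise=1.003),
    sigma_schedule=KarrasSchedule(sigma_min=0.002, sigma_max=1),
)
y = sampler(noise=torch.randn(1, 1, 2 ** 18))
```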

audio_diffusion_pytorch/__init__.py

Lines changed: 7 additions & 4 deletions
@@ -1,11 +1,14 @@
 from .diffusion import (
+    ADPM2Sampler,
     Diffusion,
     DiffusionInpainter,
     DiffusionSampler,
-    KerrasSchedule,
-    LogNormalSampler,
-    SigmaSampler,
-    SigmaSchedule,
+    Distribution,
+    KarrasSampler,
+    KarrasSchedule,
+    LogNormalDistribution,
+    Sampler,
+    Schedule,
     SpanBySpanComposer,
 )
 from .model import AudioDiffusionModel, Model1d
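For downstream code, the import renames in this hunk imply a small migration; a short sketch using only names that appear above:

```py
# Renames in v0.0.11 (old -> new):
#   SigmaSampler     -> Distribution
#   LogNormalSampler -> LogNormalDistribution
#   SigmaSchedule    -> Schedule
#   KerrasSchedule   -> KarrasSchedule (spelling fixed)
from audio_diffusion_pytorch import (
    ADPM2Sampler,  # new in this commit
    KarrasSampler,  # new in this commit
    KarrasSchedule,
    LogNormalDistribution,
)
```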

audio_diffusion_pytorch/diffusion.py

Lines changed: 144 additions & 63 deletions
@@ -1,5 +1,5 @@
 from math import sqrt
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 import torch
 import torch.nn as nn
@@ -9,15 +9,15 @@
 
 from .utils import default, exists
 
-""" Samplers and sigma schedules """
+""" Distributions """
 
 
-class SigmaSampler:
+class Distribution:
     def __call__(self, num_samples: int, device: torch.device):
         raise NotImplementedError()
 
 
-class LogNormalSampler(SigmaSampler):
+class LogNormalDistribution(Distribution):
     def __init__(self, mean: float, std: float):
         self.mean = mean
         self.std = std
@@ -29,15 +29,18 @@ def __call__(
         return normal.exp()
 
 
-class SigmaSchedule(nn.Module):
-    """Interface used by different sampling sigma schedules"""
+""" Schedules """
+
+
+class Schedule(nn.Module):
+    """Interface used by different schedules"""
 
     def forward(self, num_steps: int, device: torch.device) -> Tensor:
         raise NotImplementedError()
 
 
-class KerrasSchedule(SigmaSchedule):
-    """https://arxiv.org/abs/2206.00364 eq. (5)"""
+class KarrasSchedule(Schedule):
+    """https://arxiv.org/abs/2206.00364 equation 5"""
 
     def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
         super().__init__()
@@ -57,21 +60,139 @@ def forward(self, num_steps: int, device: Any) -> Tensor:
         return sigmas
 
 
+""" Samplers """
+
+
+class Sampler(nn.Module):
+    def forward(
+        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
+    ) -> Tensor:
+        raise NotImplementedError()
+
+
+class KarrasSampler(Sampler):
+    """https://arxiv.org/abs/2206.00364 algorithm 1"""
+
+    def __init__(
+        self,
+        s_tmin: float = 0,
+        s_tmax: float = float("inf"),
+        s_churn: float = 0.0,
+        s_noise: float = 1.0,
+    ):
+        super().__init__()
+        self.s_tmin = s_tmin
+        self.s_tmax = s_tmax
+        self.s_noise = s_noise
+        self.s_churn = s_churn
+
+    def step(
+        self,
+        x: Tensor,
+        fn: Callable,
+        sigma: float,
+        sigma_next: float,
+        gamma: float,
+        clamp: bool = True,
+    ) -> Tensor:
+        """Algorithm 2 (step)"""
+        # Select temporarily increased noise level
+        sigma_hat = sigma + gamma * sigma
+        # Add noise to move from sigma to sigma_hat
+        epsilon = self.s_noise * torch.randn_like(x)
+        x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
+        # Evaluate ∂x/∂sigma at sigma_hat
+        d = (x_hat - fn(x_hat, sigma=sigma_hat, clamp=clamp)) / sigma_hat
+        # Take euler step from sigma_hat to sigma_next
+        x_next = x_hat + (sigma_next - sigma_hat) * d
+        # Second order correction
+        if sigma_next != 0:
+            model_out_next = fn(x_next, sigma=sigma_next, clamp=clamp)
+            d_prime = (x_next - model_out_next) / sigma_next
+            x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
+        return x_next
+
+    def forward(
+        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
+    ) -> Tensor:
+        x = sigmas[0] * noise
+        # Compute gammas
+        gammas = torch.where(
+            (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
+            min(self.s_churn / num_steps, sqrt(2) - 1),
+            0.0,
+        )
+        # Denoise to sample
+        for i in range(num_steps - 1):
+            x = self.step(
+                x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i]  # type: ignore # noqa
+            )
+
+        return x
+
+
+class ADPM2Sampler(Sampler):
+    """https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py"""
+
+    """ https://www.desmos.com/calculator/jbxjlqd9mb """
+
+    def __init__(self, rho: float = 1.0):
+        super().__init__()
+        self.rho = rho
+
+    def step(
+        self,
+        x: Tensor,
+        fn: Callable,
+        sigma: float,
+        sigma_next: float,
+        clamp: bool = True,
+    ) -> Tensor:
+        # Sigma steps
+        r = self.rho
+        sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
+        sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
+        sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
+        # Derivative at sigma (∂x/∂sigma)
+        d = (x - fn(x, sigma=sigma, clamp=clamp)) / sigma
+        # Denoise to midpoint
+        x_mid = x + d * (sigma_mid - sigma)
+        # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
+        d_mid = (x_mid - fn(x_mid, sigma=sigma_mid, clamp=clamp)) / sigma_mid
+        # Denoise to next
+        x = x + d_mid * (sigma_down - sigma)
+        # Add randomness
+        x_next = x + torch.randn_like(x) * sigma_up
+        return x_next
+
+    def forward(
+        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
+    ) -> Tensor:
+        x = sigmas[0] * noise
+        # Denoise to sample
+        for i in range(num_steps - 1):
+            x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])  # type: ignore # noqa
+        return x
+
+
+""" Diffusion Classes """
+
+
 class Diffusion(nn.Module):
     """Elucidated Diffusion: https://arxiv.org/abs/2206.00364"""
 
     def __init__(
         self,
         net: nn.Module,
         *,
-        sigma_sampler: SigmaSampler,
+        sigma_distribution: Distribution,
         sigma_data: float, # data distribution standard deviation
     ):
         super().__init__()
 
         self.net = net
         self.sigma_data = sigma_data
-        self.sigma_sampler = sigma_sampler
+        self.sigma_distribution = sigma_distribution
 
     def c_skip(self, sigmas: Tensor) -> Tensor:
         return (self.sigma_data ** 2) / (sigmas ** 2 + self.sigma_data ** 2)
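An aside on the gamma computation in `KarrasSampler.forward` above (not part of the commit): gamma is nonzero only while sigma lies within `[s_tmin, s_tmax]`, and is capped at `sqrt(2) - 1` so that `sigma_hat <= sqrt(2) * sigma`. A standalone sketch, with values assumed from the old README defaults removed in this commit:

```py
from math import sqrt

s_churn, num_steps = 40, 50  # Assumed values
gamma = min(s_churn / num_steps, sqrt(2) - 1)  # min(0.8, 0.414...) -> 0.414...

sigma = 1.0
sigma_hat = sigma + gamma * sigma  # Temporarily increased noise level, ~1.414
```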
@@ -121,7 +242,7 @@ def forward(self, x: Tensor, noise: Tensor = None) -> Tensor:
         batch, device = x.shape[0], x.device
 
         # Sample amount of noise to add for each batch element
-        sigmas = self.sigma_sampler(num_samples=batch, device=device)
+        sigmas = self.sigma_distribution(num_samples=batch, device=device)
         sigmas_padded = rearrange(sigmas, "b -> b 1 1")
 
         # Add noise to input
@@ -145,65 +266,25 @@ def __init__(
         self,
         diffusion: Diffusion,
         *,
-        num_steps: int,
-        sigma_schedule: SigmaSchedule,
-        s_tmin: float = 0,
-        s_tmax: float = float("inf"),
-        s_churn: float = 0.0,
-        s_noise: float = 1.0,
+        sampler: Sampler,
+        sigma_schedule: Schedule,
+        num_steps: Optional[int] = None,
     ):
         super().__init__()
         self.denoise_fn = diffusion.denoise_fn
-        self.num_steps = num_steps
+        self.sampler = sampler
         self.sigma_schedule = sigma_schedule
-        self.s_tmin = s_tmin
-        self.s_tmax = s_tmax
-        self.s_noise = s_noise
-        self.s_churn = s_churn
-
-    def step(
-        self,
-        x: Tensor,
-        sigma: float,
-        sigma_next: float,
-        gamma: float,
-        clamp: bool = True,
-    ) -> Tensor:
-        """Algorithm 2 (step)"""
-        # Select temporarily increased noise level
-        sigma_hat = sigma + gamma * sigma
-        # Add noise to move from sigma to sigma_hat
-        epsilon = self.s_noise * torch.randn_like(x)
-        x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
-        # Evaluate ∂x/∂sigma at sigma_hat
-        d = (x_hat - self.denoise_fn(x_hat, sigma=sigma_hat, clamp=clamp)) / sigma_hat
-        # Take euler step from sigma_hat to sigma_next
-        x_next = x_hat + (sigma_next - sigma_hat) * d
-        # Second order correction
-        if sigma_next != 0:
-            model_out_next = self.denoise_fn(x_next, sigma=sigma_next, clamp=clamp)
-            d_prime = (x_next - model_out_next) / sigma_next
-            x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
-        return x_next
+        self.num_steps = num_steps
 
     @torch.no_grad()
-    def forward(self, x: Tensor, num_steps: int = None) -> Tensor:
-        device = x.device
-        num_steps = default(num_steps, self.num_steps)
+    def forward(self, noise: Tensor, num_steps: Optional[int] = None) -> Tensor:
+        device = noise.device
+        num_steps = default(num_steps, self.num_steps)  # type: ignore
+        assert exists(num_steps), "Parameter `num_steps` must be provided"
         # Compute sigmas using schedule
         sigmas = self.sigma_schedule(num_steps, device)
-        # Sample from first sigma distribution
-        x = sigmas[0] * x
-        # Compute gammas
-        gammas = torch.where(
-            (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
-            min(self.s_churn / num_steps, sqrt(2) - 1),
-            0.0,
-        )
-        # Denoise x
-        for i in range(num_steps - 1):
-            x = self.step(x, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i])  # type: ignore # noqa
-
+        # Sample using sampler
+        x = self.sampler(noise, fn=self.denoise_fn, sigmas=sigmas, num_steps=num_steps)
         x = x.clamp(-1.0, 1.0)
         return x

@@ -217,7 +298,7 @@ def __init__(
         *,
         num_steps: int,
         num_resamples: int,
-        sigma_schedule: SigmaSchedule,
+        sigma_schedule: Schedule,
         s_tmin: float = 0,
         s_tmax: float = float("inf"),
         s_churn: float = 0.0,
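As a closing aside (not part of the commit), the sigma split in `ADPM2Sampler.step` decomposes each noise decrement into a deterministic part (`sigma_down`) and a stochastic part (`sigma_up`) such that `sigma_down ** 2 + sigma_up ** 2 == sigma_next ** 2`. A runnable sketch with illustrative values:

```py
from math import isclose, sqrt

sigma, sigma_next, rho = 1.0, 0.5, 1.0  # Illustrative values

# Same formulas as ADPM2Sampler.step
sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
sigma_mid = ((sigma ** (1 / rho) + sigma_down ** (1 / rho)) / 2) ** rho

# Deterministic and stochastic parts recombine to the target noise level
assert isclose(sigma_up ** 2 + sigma_down ** 2, sigma_next ** 2)
print(sigma_up, sigma_down, sigma_mid)  # ~0.433, 0.25, 0.625
```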
