Commit 6c9b5d4

feat: add true v-diffusion, separate vk-diffusion
1 parent 531dcee

File tree: 5 files changed, +166 -50 lines

README.md

Lines changed: 3 additions & 2 deletions
@@ -171,7 +171,8 @@ y = unet(x, t) # [3, 1, 32768], compute 3 samples of ~1.5 seconds at 22050Hz wit
 
 #### Training
 ```python
-from audio_diffusion_pytorch import KDiffusion, VDiffusion, LogNormalDistribution, VDistribution
+from audio_diffusion_pytorch import KDiffusion, LogNormalDistribution
+from audio_diffusion_pytorch import VDiffusion, UniformDistribution
 
 # Either use KDiffusion
 diffusion = KDiffusion(
@@ -184,7 +185,7 @@ diffusion = KDiffusion(
 # Or use VDiffusion
 diffusion = VDiffusion(
     net=unet,
-    sigma_distribution=VDistribution()
+    sigma_distribution=UniformDistribution()
 )
 
 x = torch.randn(3, 1, 2 ** 18) # Batch of training audio samples
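
The training changes above pair with the new sampling components added in `diffusion.py` below. A minimal sampling sketch (not part of this diff; `unet` and `diffusion` are the objects from the README example, the `DiffusionSampler` argument layout is assumed from its usage in this commit, and `LinearSchedule` is imported from the submodule since this commit does not re-export it in `__init__.py`):

```python
import torch
from audio_diffusion_pytorch import VSampler
from audio_diffusion_pytorch.diffusion import DiffusionSampler, LinearSchedule

sampler = DiffusionSampler(
    diffusion,  # the VDiffusion instance from the training example
    sampler=VSampler(),  # new sampler, compatible only with VDiffusion
    sigma_schedule=LinearSchedule(),  # new schedule: sigmas fall linearly from 1 to 0
    num_steps=50,
)

noise = torch.randn(1, 1, 2 ** 18)
y = sampler(noise)  # denoised audio, same shape as noise
```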

audio_diffusion_pytorch/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -12,8 +12,11 @@
     Sampler,
     Schedule,
     SpanBySpanComposer,
+    UniformDistribution,
     VDiffusion,
-    VDistribution,
+    VKDiffusion,
+    VKDistribution,
+    VSampler,
 )
 from .model import (
     AudioDiffusionAutoencoder,

audio_diffusion_pytorch/diffusion.py

Lines changed: 142 additions & 33 deletions
@@ -1,5 +1,5 @@
-from math import atan, pi, sqrt
-from typing import Any, Callable, Optional, Tuple
+from math import atan, cos, pi, sin, sqrt
+from typing import Any, Callable, List, Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -33,7 +33,12 @@ def __call__(
         return normal.exp()
 
 
-class VDistribution(Distribution):
+class UniformDistribution(Distribution):
+    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
+        return torch.rand(num_samples, device=device)
+
+
+class VKDistribution(Distribution):
     def __init__(
         self,
         min_value: float = 0.0,
@@ -94,6 +99,8 @@ def to_batch(
 
 class Diffusion(nn.Module):
 
+    alias: str = ""
+
     """Base diffusion class"""
 
     def denoise_fn(
@@ -110,24 +117,19 @@ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
 
 
 class VDiffusion(Diffusion):
+
+    alias = "v"
+
     def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
         super().__init__()
         self.net = net
         self.sigma_distribution = sigma_distribution
 
-    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
-        sigma_data = 1.0
-        sigmas = rearrange(sigmas, "b -> b 1 1")
-        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
-        c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
-        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
-        return c_skip, c_out, c_in
-
-    def sigma_to_t(self, sigmas: Tensor) -> Tensor:
-        return sigmas.atan() / pi * 2
-
-    def t_to_sigma(self, t: Tensor) -> Tensor:
-        return (t * pi / 2).tan()
+    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
+        angle = sigmas * pi / 2
+        alpha = torch.cos(angle)
+        beta = torch.sin(angle)
+        return alpha, beta
 
     def denoise_fn(
         self,
@@ -138,12 +140,7 @@ def denoise_fn(
     ) -> Tensor:
         batch_size, device = x_noisy.shape[0], x_noisy.device
         sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
-
-        # Predict network output and add skip connection
-        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
-        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
-        x_denoised = c_skip * x_noisy + c_out * x_pred
-        return x_denoised
+        return self.net(x_noisy, sigmas, **kwargs)
 
     def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
         batch_size, device = x.shape[0], x.device
@@ -152,25 +149,24 @@ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
         sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
         sigmas_padded = rearrange(sigmas, "b -> b 1 1")
 
-        # Add noise to input
+        # Get noise
         noise = default(noise, lambda: torch.randn_like(x))
-        x_noisy = x + sigmas_padded * noise
 
-        # Compute model output
-        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
-        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
+        # Combine input and noise weighted by half-circle
+        alpha, beta = self.get_alpha_beta(sigmas_padded)
+        x_noisy = x * alpha + noise * beta
+        x_target = noise * alpha - x * beta
 
-        # Compute v-objective target
-        v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
-
-        # Compute loss
-        loss = F.mse_loss(x_pred, v_target)
-        return loss
+        # Denoise and return loss
+        x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
+        return F.mse_loss(x_denoised, x_target)
 
 
 class KDiffusion(Diffusion):
     """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
 
+    alias = "k"
+
     def __init__(
         self,
         net: nn.Module,
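
Context for the rewrite above: this is the "true" v-objective (Salimans & Ho, 2022). With angle sigma * pi / 2 we get alpha^2 + beta^2 = 1, so `x_noisy = x * alpha + noise * beta` keeps unit variance, and `(x, noise)` is recoverable from `(x_noisy, v)` by the inverse rotation. A standalone sanity check of these identities (not part of the commit):

```python
from math import pi

import torch

sigmas = torch.rand(4, 1, 1)
alpha, beta = torch.cos(sigmas * pi / 2), torch.sin(sigmas * pi / 2)
assert torch.allclose(alpha ** 2 + beta ** 2, torch.ones_like(alpha))

x = torch.randn(4, 1, 8)
noise = torch.randn_like(x)
x_noisy = x * alpha + noise * beta  # rotate the signal toward noise
v = noise * alpha - x * beta  # the training target

# Rotating back recovers the clean signal and the noise exactly.
assert torch.allclose(x_noisy * alpha - v * beta, x, atol=1e-5)
assert torch.allclose(x_noisy * beta + v * alpha, noise, atol=1e-5)
```
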
@@ -235,7 +231,68 @@ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
         losses = reduce(losses, "b ... -> b", "mean")
         losses = losses * self.loss_weight(sigmas)
         loss = losses.mean()
+        return loss
+
+
+class VKDiffusion(Diffusion):
+
+    alias = "vk"
+
+    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
+        super().__init__()
+        self.net = net
+        self.sigma_distribution = sigma_distribution
+
+    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
+        sigma_data = 1.0
+        sigmas = rearrange(sigmas, "b -> b 1 1")
+        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
+        c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
+        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
+        return c_skip, c_out, c_in
 
+    def sigma_to_t(self, sigmas: Tensor) -> Tensor:
+        return sigmas.atan() / pi * 2
+
+    def t_to_sigma(self, t: Tensor) -> Tensor:
+        return (t * pi / 2).tan()
+
+    def denoise_fn(
+        self,
+        x_noisy: Tensor,
+        sigmas: Optional[Tensor] = None,
+        sigma: Optional[float] = None,
+        **kwargs,
+    ) -> Tensor:
+        batch_size, device = x_noisy.shape[0], x_noisy.device
+        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
+
+        # Predict network output and add skip connection
+        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
+        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
+        x_denoised = c_skip * x_noisy + c_out * x_pred
+        return x_denoised
+
+    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
+        batch_size, device = x.shape[0], x.device
+
+        # Sample amount of noise to add for each batch element
+        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
+        sigmas_padded = rearrange(sigmas, "b -> b 1 1")
+
+        # Add noise to input
+        noise = default(noise, lambda: torch.randn_like(x))
+        x_noisy = x + sigmas_padded * noise
+
+        # Compute model output
+        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
+        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
+
+        # Compute v-objective target
+        v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
+
+        # Compute loss
+        loss = F.mse_loss(x_pred, v_target)
         return loss
 
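`VKDiffusion` is the previous `VDiffusion` objective moved unchanged under its own alias `"vk"`: Karras-style preconditioning plus the sigma = tan(t * pi / 2) warping. A standalone illustration (not part of the commit) of how the scale weights behave across noise levels:

```python
import torch

sigma_data = 1.0
sigmas = torch.tensor([0.01, 1.0, 100.0])
c_skip = sigma_data ** 2 / (sigmas ** 2 + sigma_data ** 2)  # skip weight
c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5  # input scaling

print(c_skip)  # ~ [1.0, 0.5, 0.0001]: skip path dominates at low sigma
print(c_in)    # ~ [1.0, 0.7071, 0.01]: keeps network input near unit variance
```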

@@ -253,6 +310,12 @@ def forward(self, num_steps: int, device: torch.device) -> Tensor:
         raise NotImplementedError()
 
 
+class LinearSchedule(Schedule):
+    def forward(self, num_steps: int, device: Any) -> Tensor:
+        sigmas = torch.linspace(1, 0, num_steps + 1)[:-1]
+        return sigmas
+
+
 class KarrasSchedule(Schedule):
     """https://arxiv.org/abs/2206.00364 equation 5"""
 
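
`LinearSchedule` complements `UniformDistribution` above: training draws sigma uniformly from [0, 1), and sampling walks sigma linearly from 1 down toward (but excluding) 0. For example:

```python
import torch

num_steps = 5
sigmas = torch.linspace(1, 0, num_steps + 1)[:-1]  # what LinearSchedule computes
print(sigmas)  # tensor([1.0000, 0.8000, 0.6000, 0.4000, 0.2000])
```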

@@ -278,6 +341,9 @@ def forward(self, num_steps: int, device: Any) -> Tensor:
 
 
 class Sampler(nn.Module):
+
+    diffusion_types: List[Type[Diffusion]] = []
+
     def forward(
         self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
     ) -> Tensor:
@@ -295,9 +361,41 @@ def inpaint(
         raise NotImplementedError("Inpainting not available with current sampler")
 
 
+class VSampler(Sampler):
+
+    diffusion_types = [VDiffusion]
+
+    def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
+        angle = sigma * pi / 2
+        alpha = cos(angle)
+        beta = sin(angle)
+        return alpha, beta
+
+    def forward(
+        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
+    ) -> Tensor:
+        x = sigmas[0] * noise
+        alpha, beta = self.get_alpha_beta(sigmas[0].item())
+
+        for i in range(num_steps - 1):
+            is_last = i == num_steps - 1
+
+            x_denoised = fn(x, sigma=sigmas[i])
+            x_pred = x * alpha - x_denoised * beta
+            x_eps = x * beta + x_denoised * alpha
+
+            if not is_last:
+                alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
+                x = x_pred * alpha + x_eps * beta
+
+        return x
+
+
 class KarrasSampler(Sampler):
     """https://arxiv.org/abs/2206.00364 algorithm 1"""
 
+    diffusion_types = [KDiffusion, VKDiffusion]
+
     def __init__(
         self,
         s_tmin: float = 0,
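
`VSampler` performs a DDIM-like update in the (alpha, beta) parameterization: each step splits the current `x` into a clean estimate `x_pred` and a noise estimate `x_eps` at the current angle, then recombines them at the next, smaller angle. One step in isolation (a standalone sketch, not part of the commit, with a stub in place of the trained denoise function):

```python
import torch
from math import cos, pi, sin

def fn(x, sigma):
    return torch.zeros_like(x)  # stub: the real fn predicts v

sigma, sigma_next = 1.0, 0.8
x = torch.randn(1, 1, 8)  # pure noise at sigma = 1

alpha, beta = cos(sigma * pi / 2), sin(sigma * pi / 2)
v = fn(x, sigma=sigma)
x_pred = x * alpha - v * beta  # estimate of the clean signal
x_eps = x * beta + v * alpha   # estimate of the noise

alpha, beta = cos(sigma_next * pi / 2), sin(sigma_next * pi / 2)
x = x_pred * alpha + x_eps * beta  # re-noised at the next level
```
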
@@ -351,6 +449,9 @@ def forward(
 
 
 class AEulerSampler(Sampler):
+
+    diffusion_types = [KDiffusion, VKDiffusion]
+
     def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
         sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
         sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
@@ -380,6 +481,8 @@ def forward(
 class ADPM2Sampler(Sampler):
     """https://www.desmos.com/calculator/jbxjlqd9mb"""
 
+    diffusion_types = [KDiffusion, VKDiffusion]
+
     def __init__(self, rho: float = 1.0):
         super().__init__()
         self.rho = rho
@@ -459,6 +562,12 @@ def __init__(
         self.sigma_schedule = sigma_schedule
         self.num_steps = num_steps
 
+        # Check sampler is compatible with diffusion type
+        sampler_class = sampler.__class__.__name__
+        diffusion_class = diffusion.__class__.__name__
+        message = f"{sampler_class} incompatible with {diffusion_class}"
+        assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
+
     @torch.no_grad()
     def forward(
         self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
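
With the `alias`/`diffusion_types` bookkeeping, mismatched pairs now fail fast at construction instead of silently sampling with the wrong parameterization. The check reduces to a membership test on aliases (standalone illustration, not part of the commit):

```python
from audio_diffusion_pytorch import KDiffusion, VSampler

compatible = [t.alias for t in VSampler.diffusion_types]
print(compatible)  # ['v']
print(KDiffusion.alias in compatible)  # False: pairing the two would raise
# AssertionError: VSampler incompatible with KDiffusion
```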

audio_diffusion_pytorch/model.py

Lines changed: 16 additions & 13 deletions
@@ -4,15 +4,15 @@
 from torch import Tensor, nn
 
 from .diffusion import (
-    AEulerSampler,
-    Diffusion,
     DiffusionSampler,
-    KarrasSchedule,
     KDiffusion,
+    LinearSchedule,
     Sampler,
     Schedule,
+    UniformDistribution,
     VDiffusion,
-    VDistribution,
+    VKDiffusion,
+    VSampler,
 )
 from .modules import (
     Bottleneck,
@@ -38,12 +38,15 @@ def __init__(
         UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
         self.unet = UNet(**kwargs)
 
-        if diffusion_type == "v":
-            self.diffusion: Diffusion = VDiffusion(net=self.unet, **diffusion_kwargs)
-        elif diffusion_type == "k":
-            self.diffusion = KDiffusion(net=self.unet, **diffusion_kwargs)
-        else:
-            raise ValueError(f"diffusion_type must be v or k, found {diffusion_type}")
+        # Check valid diffusion type
+        diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
+        aliases = [t.alias for t in diffusion_classes]  # type: ignore
+        message = f"diffusion_type='{diffusion_type}' must be one of {*aliases,}"
+        assert diffusion_type in aliases, message
+
+        for XDiffusion in diffusion_classes:
+            if XDiffusion.alias == diffusion_type:  # type: ignore
+                self.diffusion = XDiffusion(net=self.unet, **diffusion_kwargs)
 
     def forward(self, x: Tensor, **kwargs) -> Tensor:
         return self.diffusion(x, **kwargs)
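
The model constructor now selects the diffusion class by alias instead of a hard-coded if/elif chain, so the new `"vk"` type is accepted alongside `"v"` and `"k"`. The dispatch in isolation (standalone illustration, not part of the commit):

```python
from audio_diffusion_pytorch import KDiffusion, VDiffusion, VKDiffusion

diffusion_type = "vk"  # newly valid alongside "v" and "k"
for XDiffusion in [VDiffusion, KDiffusion, VKDiffusion]:
    if XDiffusion.alias == diffusion_type:
        print(f"selected {XDiffusion.__name__}")  # selected VKDiffusion
```
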
@@ -242,14 +245,14 @@ def get_default_model_kwargs():
         use_context_time=True,
         use_magnitude_channels=False,
         diffusion_type="v",
-        diffusion_sigma_distribution=VDistribution(),
+        diffusion_sigma_distribution=UniformDistribution(),
     )
 
 
 def get_default_sampling_kwargs():
     return dict(
-        sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
-        sampler=AEulerSampler(),
+        sigma_schedule=LinearSchedule(),
+        sampler=VSampler(),
     )
 
 
setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.66",
+    version="0.0.67",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
