Skip to content

Commit d16dfa7

Browse files
feat: add variational bottleneck, add option to change diffusion type
1 parent 68da808 commit d16dfa7

File tree

5 files changed

+95
-22
lines changed

5 files changed

+95
-22
lines changed

audio_diffusion_pytorch/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,10 @@
2424
DiffusionUpsampler1d,
2525
Model1d,
2626
)
27-
from .modules import AutoEncoder1d, MultiEncoder1d, UNet1d, UNetConditional1d
27+
from .modules import (
28+
AutoEncoder1d,
29+
MultiEncoder1d,
30+
UNet1d,
31+
UNetConditional1d,
32+
Variational,
33+
)

audio_diffusion_pytorch/model.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55

66
from .diffusion import (
77
AEulerSampler,
8+
Diffusion,
89
DiffusionSampler,
9-
Distribution,
1010
KarrasSchedule,
11+
KDiffusion,
1112
Sampler,
1213
Schedule,
1314
VDiffusion,
@@ -20,7 +21,7 @@
2021
UNet1d,
2122
UNetConditional1d,
2223
)
23-
from .utils import default, downsample, exists, to_list, upsample
24+
from .utils import default, downsample, exists, groupby_kwargs_prefix, to_list, upsample
2425

2526
"""
2627
Diffusion Classes (generic for 1d data)
@@ -29,20 +30,20 @@
2930

3031
class Model1d(nn.Module):
3132
def __init__(
    self, diffusion_type: str, use_classifier_free_guidance: bool = False, **kwargs
):
    """Create the inner UNet and attach the requested diffusion objective.

    Kwargs prefixed with ``diffusion_`` are routed (prefix-stripped) to the
    diffusion wrapper; every remaining kwarg goes to the UNet constructor.
    """
    super().__init__()
    # Separate the diffusion-specific kwargs from the UNet kwargs.
    diffusion_kwargs, kwargs = groupby_kwargs_prefix("diffusion_", kwargs)

    unet_class = UNetConditional1d if use_classifier_free_guidance else UNet1d
    self.unet = unet_class(**kwargs)

    diffusion_by_type = {"v": VDiffusion, "k": KDiffusion}
    if diffusion_type not in diffusion_by_type:
        raise ValueError(f"diffusion_type must be v or k, found {diffusion_type}")
    self.diffusion: Diffusion = diffusion_by_type[diffusion_type](
        net=self.unet, **diffusion_kwargs
    )
4647

4748
def forward(self, x: Tensor, **kwargs) -> Tensor:
    """Delegate to the wrapped diffusion objective (returns its loss/output)."""
    out = self.diffusion(x, **kwargs)
    return out
@@ -53,7 +54,7 @@ def sample(
5354
num_steps: int,
5455
sigma_schedule: Schedule,
5556
sampler: Sampler,
56-
**kwargs
57+
**kwargs,
5758
) -> Tensor:
5859
diffusion_sampler = DiffusionSampler(
5960
diffusion=self.diffusion,
@@ -71,7 +72,7 @@ def __init__(
7172
factor: Union[int, Sequence[int]],
7273
factor_features: Optional[int] = None,
7374
*args,
74-
**kwargs
75+
**kwargs,
7576
):
7677
self.factors = to_list(factor)
7778
self.use_conditioning = exists(factor_features)
@@ -144,7 +145,7 @@ def __init__(
144145
bottleneck: Optional[Bottleneck] = None,
145146
encoder_num_blocks: Optional[Sequence[int]] = None,
146147
encoder_out_layers: int = 0,
147-
**kwargs
148+
**kwargs,
148149
):
149150
self.in_channels = in_channels
150151
encoder_num_blocks = default(encoder_num_blocks, num_blocks)
@@ -240,6 +241,7 @@ def get_default_model_kwargs():
240241
use_skip_scale=True,
241242
use_context_time=True,
242243
use_magnitude_channels=False,
244+
diffusion_type="v",
243245
diffusion_sigma_distribution=VDistribution(),
244246
)
245247

@@ -289,7 +291,7 @@ def __init__(
289291
embedding_features: int,
290292
embedding_max_length: int,
291293
embedding_mask_proba: float = 0.1,
292-
**kwargs
294+
**kwargs,
293295
):
294296
self.embedding_mask_proba = embedding_mask_proba
295297
default_kwargs = dict(

audio_diffusion_pytorch/modules.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,10 +1138,46 @@ def forward( # type: ignore
11381138
class Bottleneck(nn.Module):
    """Bottleneck interface (subclass can be provided to (Diffusion)Autoencoder1d)"""

    # Subclasses must implement forward; with_info=True additionally returns an
    # info dict (e.g. auxiliary losses) alongside the transformed tensor.
    def forward(
        self, x: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        raise NotImplementedError()
11431145

11441146

1147+
def gaussian_sample(mean: Tensor, logvar: Tensor) -> Tensor:
    """Reparameterized sample from N(mean, exp(logvar))."""
    sigma = torch.exp(0.5 * logvar)
    noise = torch.randn_like(sigma)
    return mean + sigma * noise
1152+
1153+
1154+
def kl_loss(mean: Tensor, logvar: Tensor) -> Tensor:
    """KL-style penalty of N(mean, exp(logvar)) against the standard normal.

    Returns a single-element tensor (shape (1,)) averaged over the batch and
    all remaining dimensions. The result deliberately stays a Tensor — the
    previous `.item()` call detached the loss from the autograd graph, so the
    KL term could never backpropagate into the encoder (and it contradicted
    the declared `-> Tensor` return type).
    """
    losses = mean ** 2 + logvar.exp() - logvar - 1
    # Mean over every element, kept as shape (1,) to match the old
    # einops reduce(losses, "b ... -> 1", "mean") output.
    loss = losses.mean().unsqueeze(0)
    return loss
1158+
1159+
1160+
class Variational(Bottleneck):
    """Variational (VAE-style) bottleneck.

    Projects the latent to mean/log-variance with a 1x1 conv, draws a
    reparameterized sample, and reports a weighted KL loss via the info dict.
    """

    def __init__(self, channels: int, loss_weight: float = 1.0):
        super().__init__()
        self.loss_weight = loss_weight
        # 1x1 conv doubling the channels: first half = mean, second = logvar.
        self.to_mean_and_logvar = Conv1d(
            in_channels=channels,
            out_channels=channels * 2,
            kernel_size=1,
        )

    def forward(
        self, x: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        mean, logvar = torch.chunk(self.to_mean_and_logvar(x), chunks=2, dim=1)
        # Clamp logvar for numerical stability of exp().
        logvar = logvar.clamp(-30.0, 20.0)
        out = gaussian_sample(mean, logvar)
        loss = kl_loss(mean, logvar) * self.loss_weight
        if with_info:
            return out, dict(loss=loss)
        return out
1179+
1180+
11451181
class AutoEncoder1d(nn.Module):
11461182
def __init__(
11471183
self,
@@ -1208,18 +1244,28 @@ def __init__(
12081244
factor=patch_factor,
12091245
)
12101246

1211-
def encode(
1247+
def forward(
    self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
    """Autoencode x; optionally also return the encoder/bottleneck info dict."""
    latent, info = self.encode(x, with_info=True)
    reconstruction = self.decode(latent)
    if with_info:
        return reconstruction, info
    return reconstruction
12141253

1254+
def encode(
    self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
    """Run the downsampling encoder.

    The info dict carries every intermediate stage output under "xs", merged
    with any info produced by the (optional) bottleneck.
    """
    stage_outputs = []
    h = self.to_in(x)
    for down in self.downsamples:
        h = down(h)
        stage_outputs.append(h)
    info = dict(xs=stage_outputs)

    if exists(self.bottleneck):
        h, bottleneck_info = self.bottleneck(h, with_info=True)
        info = {**info, **bottleneck_info}

    return (h, info) if with_info else h
12231269

12241270
def decode(self, x: Tensor) -> Tensor:
12251271
for upsample in self.upsamples:

audio_diffusion_pytorch/utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
from functools import reduce
33
from inspect import isfunction
4-
from typing import Callable, List, Optional, Sequence, TypeVar, Union
4+
from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
55

66
import torch
77
import torch.nn.functional as F
@@ -42,6 +42,25 @@ def prod(vals: Sequence[int]) -> int:
4242
return reduce(lambda x, y: x * y, vals)
4343

4444

45+
"""
46+
Kwargs Utils
47+
"""
48+
49+
50+
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Split d into (items whose keys start with prefix, all other items)."""
    with_prefix: Dict = {}
    without_prefix: Dict = {}
    for key, value in d.items():
        if key.startswith(prefix):
            with_prefix[key] = value
        else:
            without_prefix[key] = value
    return with_prefix, without_prefix
56+
57+
58+
def groupby_kwargs_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Extract kwargs starting with prefix (prefix stripped) from the rest.

    Returns (prefix-stripped matching kwargs, remaining kwargs).
    """
    stripped = {
        key[len(prefix):]: value
        for key, value in d.items()
        if key.startswith(prefix)
    }
    remaining = {
        key: value for key, value in d.items() if not key.startswith(prefix)
    }
    return stripped, remaining
62+
63+
4564
"""
4665
DSP Utils
4766
"""

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.60",
6+
version="0.0.61",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)