
Commit 8b2ccc7

feat: add diffusion autoencoder
1 parent 4fdcb35 commit 8b2ccc7

3 files changed: 156 additions, 22 deletions


audio_diffusion_pytorch/__init__.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -12,5 +12,12 @@
     Schedule,
     SpanBySpanComposer,
 )
-from .model import AudioDiffusionModel, AudioDiffusionUpsampler, Model1d
+from .model import (
+    AudioDiffusionAutoencoder,
+    AudioDiffusionModel,
+    AudioDiffusionUpsampler,
+    DiffusionAutoencoder1d,
+    DiffusionUpsampler1d,
+    Model1d,
+)
 from .modules import Encoder1d, UNet1d
```
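After this change the autoencoder classes are importable from the package root alongside the existing models. A quick sketch (assuming the package is installed):

```python
# New public names added by this commit, imported from the package root.
from audio_diffusion_pytorch import (
    AudioDiffusionAutoencoder,  # audio-ready defaults
    DiffusionAutoencoder1d,     # generic 1d base class
    DiffusionUpsampler1d,       # generic base now backing AudioDiffusionUpsampler
)
```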

audio_diffusion_pytorch/model.py

Lines changed: 147 additions & 20 deletions
```diff
@@ -1,3 +1,4 @@
+from math import prod
 from typing import Optional, Sequence
 
 import torch
@@ -13,7 +14,9 @@
     Sampler,
     Schedule,
 )
-from .modules import UNet1d
+from .modules import Encoder1d, ResnetBlock1d, UNet1d
+
+""" Diffusion Classes (generic for 1d data) """
 
 
 class Model1d(nn.Module):
@@ -47,8 +50,6 @@ def __init__(
             in_channels=in_channels,
             channels=channels,
             patch_size=patch_size,
-            resnet_groups=resnet_groups,
-            kernel_multiplier_downsample=kernel_multiplier_downsample,
             kernel_sizes_init=kernel_sizes_init,
             multipliers=multipliers,
             factors=factors,
@@ -57,9 +58,11 @@ def __init__(
             attention_heads=attention_heads,
             attention_features=attention_features,
             attention_multiplier=attention_multiplier,
+            use_attention_bottleneck=use_attention_bottleneck,
+            resnet_groups=resnet_groups,
+            kernel_multiplier_downsample=kernel_multiplier_downsample,
             use_nearest_upsample=use_nearest_upsample,
             use_skip_scale=use_skip_scale,
-            use_attention_bottleneck=use_attention_bottleneck,
             out_channels=out_channels,
             context_channels=context_channels,
         )
@@ -91,10 +94,110 @@ def sample(
         return diffusion_sampler(noise, **kwargs)
 
 
+class DiffusionUpsampler1d(Model1d):
+    def __init__(self, factor: int, in_channels: int, *args, **kwargs):
+        self.factor = factor
+        default_kwargs = dict(
+            in_channels=in_channels,
+            context_channels=[in_channels],
+        )
+        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
+
+    def forward(self, x: Tensor, **kwargs) -> Tensor:
+        # Downsample by picking every `factor` item
+        downsampled = x[:, :, :: self.factor]
+        # Upsample by interleaving to get context
+        context = torch.repeat_interleave(downsampled, repeats=self.factor, dim=2)
+        return self.diffusion(x, context=[context], **kwargs)
+
+    def sample(self, undersampled: Tensor, *args, **kwargs):  # type: ignore
+        # Upsample context by interleaving
+        context = torch.repeat_interleave(undersampled, repeats=self.factor, dim=2)
+        noise = torch.randn_like(context)
+        default_kwargs = dict(context=[context])
+        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
+
+
+class DiffusionAutoencoder1d(Model1d):
+    def __init__(
+        self,
+        in_channels: int,
+        channels: int,
+        patch_size: int,
+        kernel_sizes_init: Sequence[int],
+        multipliers: Sequence[int],
+        factors: Sequence[int],
+        num_blocks: Sequence[int],
+        resnet_groups: int,
+        kernel_multiplier_downsample: int,
+        encoder_depth: int,
+        encoder_channels: int,
+        context_channels: int,
+        **kwargs
+    ):
+        super().__init__(
+            in_channels=in_channels,
+            channels=channels,
+            patch_size=patch_size,
+            kernel_sizes_init=kernel_sizes_init,
+            multipliers=multipliers,
+            factors=factors,
+            num_blocks=num_blocks,
+            resnet_groups=resnet_groups,
+            kernel_multiplier_downsample=kernel_multiplier_downsample,
+            context_channels=[0] * encoder_depth + [context_channels],
+            **kwargs
+        )
+
+        self.in_channels = in_channels
+        self.encoder_factor = patch_size * prod(factors[0:encoder_depth])
+
+        self.encoder = Encoder1d(
+            in_channels=in_channels,
+            channels=channels,
+            patch_size=patch_size,
+            kernel_sizes_init=kernel_sizes_init,
+            multipliers=multipliers,
+            factors=factors,
+            num_blocks=num_blocks,
+            resnet_groups=resnet_groups,
+            kernel_multiplier_downsample=kernel_multiplier_downsample,
+            extract_channels=[0] * (encoder_depth - 1) + [encoder_channels],
+        )
+
+        self.to_context = ResnetBlock1d(
+            in_channels=encoder_channels,
+            out_channels=context_channels,
+            num_groups=resnet_groups,
+        )
+
+    def forward(self, x: Tensor, **kwargs) -> Tensor:
+        latent = self.encode(x)
+        context = self.to_context(latent)
+        return self.diffusion(x, context=[context], **kwargs)
+
+    def encode(self, x: Tensor) -> Tensor:
+        x = self.encoder(x)[-1]
+        latent = torch.tanh(x)
+        return latent
+
+    def decode(self, latent: Tensor, **kwargs) -> Tensor:
+        b, length = latent.shape[0], latent.shape[2] * self.encoder_factor
+        # Compute noise by inferring shape from latent length
+        noise = torch.randn(b, self.in_channels, length).to(latent)
+        # Compute context from latent
+        context = self.to_context(latent)
+        default_kwargs = dict(context=[context])
+        # Decode by sampling while conditioning on latent context
+        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
+
+
+""" Audio Diffusion Classes (specific for 1d audio data) """
+
+
 class AudioDiffusionModel(Model1d):
     def __init__(self, *args, **kwargs):
         default_kwargs = dict(
-            in_channels=1,
             channels=128,
             patch_size=16,
             kernel_sizes_init=[1, 3, 7],
```
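`DiffusionUpsampler1d` conditions the diffusion on a nearest-neighbour-style stretch of the low-rate signal: it strides to downsample, then repeats each kept sample `factor` times so the context matches the target length. A standalone illustration of that trick in plain PyTorch (nothing repo-specific assumed):

```python
import torch

factor = 4
x = torch.randn(1, 1, 16)        # (batch, channels, length)

# Downsample by keeping every `factor`-th sample along the time axis
downsampled = x[:, :, ::factor]  # length 16 -> 4

# Stretch back to the original length by repeating each sample
# `factor` times; this is the conditioning context fed to the UNet
context = torch.repeat_interleave(downsampled, repeats=factor, dim=2)

print(downsampled.shape)  # torch.Size([1, 1, 4])
print(context.shape)      # torch.Size([1, 1, 16])
```

The remaining hunks rebase the audio-specific classes onto these new generics and add `AudioDiffusionAutoencoder` with full defaults: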
```diff
@@ -125,10 +228,8 @@ def sample(self, *args, **kwargs):
         return super().sample(*args, **{**default_kwargs, **kwargs})
 
 
-class AudioDiffusionUpsampler(Model1d):
-    def __init__(self, factor: int, in_channels: int = 1, *args, **kwargs):
-        self.factor = factor
-
+class AudioDiffusionUpsampler(DiffusionUpsampler1d):
+    def __init__(self, in_channels: int, *args, **kwargs):
         default_kwargs = dict(
             in_channels=in_channels,
             channels=128,
@@ -154,19 +255,45 @@ def __init__(self, factor: int, in_channels: int = 1, *args, **kwargs):
 
         super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
 
-    def forward(self, x: Tensor, **kwargs) -> Tensor:
-        # Downsample by picking every `factor` item
-        downsampled = x[:, :, :: self.factor]
-        # Upsample by interleaving to get context
-        context = torch.repeat_interleave(downsampled, repeats=self.factor, dim=2)
-        return self.diffusion(x, context=[context], **kwargs)
+    def sample(self, *args, **kwargs):
+        default_kwargs = dict(
+            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
+            sampler=ADPM2Sampler(rho=1.0),
+        )
+        return super().sample(*args, **{**default_kwargs, **kwargs})
 
-    def sample(self, start: Tensor, *args, **kwargs):  # type: ignore
-        context = torch.repeat_interleave(start, repeats=self.factor, dim=2)
-        noise = torch.randn_like(context)
+
+class AudioDiffusionAutoencoder(DiffusionAutoencoder1d):
+    def __init__(self, *args, **kwargs):
+        default_kwargs = dict(
+            channels=128,
+            patch_size=16,
+            kernel_sizes_init=[1, 3, 7],
+            multipliers=[1, 2, 4, 4, 4, 4, 4],
+            factors=[4, 4, 4, 2, 2, 2],
+            num_blocks=[2, 2, 2, 2, 2, 2],
+            attentions=[False, False, False, True, True, True],
+            attention_heads=8,
+            attention_features=64,
+            attention_multiplier=2,
+            use_attention_bottleneck=True,
+            resnet_groups=8,
+            kernel_multiplier_downsample=2,
+            use_nearest_upsample=False,
+            use_skip_scale=True,
+            encoder_depth=4,
+            encoder_channels=32,
+            context_channels=512,
+            diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
+            diffusion_sigma_data=0.1,
+            diffusion_dynamic_threshold=0.0,
+        )
+
+        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
+
+    def decode(self, *args, **kwargs):
         default_kwargs = dict(
-            context=[context],
             sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
             sampler=ADPM2Sampler(rho=1.0),
         )
-        return super().sample(noise, *args, **{**default_kwargs, **kwargs})  # type: ignore # noqa
+        return super().decode(*args, **{**default_kwargs, **kwargs})
```
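With these defaults and `encoder_depth=4`, the encoder compresses the time axis by `patch_size * prod(factors[:4])` = 16 × (4·4·4·2) = 2048, with `encoder_channels=32` latent channels. A hedged round-trip sketch (the `num_steps` keyword is assumed to be forwarded to this library's sampler; shapes follow the defaults above):

```python
import torch
from audio_diffusion_pytorch import AudioDiffusionAutoencoder

# in_channels is not defaulted, so it must be passed (1 = mono audio)
autoencoder = AudioDiffusionAutoencoder(in_channels=1)

x = torch.randn(1, 1, 2 ** 18)  # (batch, channels, length)
loss = autoencoder(x)           # denoising loss, per this library's Model1d API
loss.backward()

# Encode: 262144 samples / 2048 -> expected latent shape (1, 32, 128)
latent = autoencoder.encode(x)

# Decode: draws noise of the inferred full length and samples while
# conditioning on the latent; returns a tensor shaped like x
reconstruction = autoencoder.decode(latent, num_steps=50)
```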

setup.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.23",
+    version="0.0.24",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
```
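Once released, the new classes should be available via `pip install audio-diffusion-pytorch==0.0.24` (assuming the usual PyPI release flow for this package).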
