Skip to content

Commit f98c25e

Browse files
feat: add option to insert quantizer as diffusion autoencoder bottleneck
1 parent e0ba36c commit f98c25e

File tree

4 files changed

+26
-8
lines changed

4 files changed

+26
-8
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ y_long = composer(y, keep_start=True) # [1, 1, 98304]
203203
- [x] Add trainer with experiments.
204204
- [x] Add diffusion upsampler.
205205
- [x] Add ancestral euler sampler `AEulerSampler`.
206+
- [x] Add diffusion autoencoder.
207+
- [x] Add autoencoder bottleneck option for quantization.
206208

207209
## Appreciation
208210

audio_diffusion_pytorch/model.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import random
2-
from typing import Optional, Sequence, Union
2+
from typing import Any, Optional, Sequence, Tuple, Union
33

44
import torch
55
from torch import Tensor, nn
@@ -15,7 +15,7 @@
1515
Schedule,
1616
)
1717
from .modules import Encoder1d, ResnetBlock1d, UNet1d
18-
from .utils import default, prod, to_list
18+
from .utils import default, exists, prod, to_list
1919

2020
""" Diffusion Classes (generic for 1d data) """
2121

@@ -129,6 +129,13 @@ def sample( # type: ignore
129129
return super().sample(noise, **{**default_kwargs, **kwargs}) # type: ignore
130130

131131

132+
class Bottleneck(nn.Module):
    """Interface for a latent bottleneck (e.g. a quantization module).

    A subclass instance can be passed to DiffusionAutoencoder1d through its
    `bottleneck` argument; it is then called on the latent inside `encode`.
    """

    def forward(self, x: Tensor) -> Tuple[Tensor, Any]:
        """Map `x` to a `(tensor, info)` pair.

        Subclasses must override this method. `info` may be any extra data
        the bottleneck wants to hand back next to the transformed tensor.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
137+
138+
132139
class DiffusionAutoencoder1d(Model1d):
133140
def __init__(
134141
self,
@@ -144,6 +151,7 @@ def __init__(
144151
encoder_depth: int,
145152
encoder_channels: int,
146153
context_channels: int,
154+
bottleneck: Optional[Bottleneck] = None,
147155
**kwargs
148156
):
149157
super().__init__(
@@ -162,6 +170,7 @@ def __init__(
162170

163171
self.in_channels = in_channels
164172
self.encoder_factor = patch_size * prod(factors[0:encoder_depth])
173+
self.bottleneck = bottleneck
165174

166175
self.encoder = Encoder1d(
167176
in_channels=in_channels,
@@ -187,9 +196,15 @@ def forward(self, x: Tensor, **kwargs) -> Tensor:
187196
context = self.to_context(latent)
188197
return self.diffusion(x, context=[context], **kwargs)
189198

190-
def encode(self, x: Tensor) -> Tensor:
199+
def encode(
    self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
    """Encode `x` into a tanh-squashed latent, optionally via the bottleneck.

    Args:
        x: Input tensor passed to `self.encoder`.
        with_info: When True, return a `(latent, info)` tuple instead of the
            bare latent.

    Returns:
        The latent tensor, or `(latent, info)` when `with_info` is True.
        `info` is whatever the bottleneck returned next to the latent, and
        `None` when no bottleneck is configured.
    """
    # The encoder output is indexable; only its last element is used.
    latent = torch.tanh(self.encoder(x)[-1])
    info: Any = None
    # Apply bottleneck if provided (e.g. quantization module)
    if exists(self.bottleneck):
        latent, info = self.bottleneck(latent)
    # Honour `with_info` on both paths: previously, without a bottleneck, a
    # bare tensor was returned even when the caller asked for a tuple.
    return (latent, info) if with_info else latent
194209

195210
def decode(self, latent: Tensor, **kwargs) -> Tensor:

audio_diffusion_pytorch/modules.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -703,11 +703,12 @@ def __init__(
703703
self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
704704
):
705705
super().__init__()
706+
mid_channels = in_channels * 8
706707

707708
self.block1 = nn.ModuleList(
708709
Conv1d(
709710
in_channels=in_channels,
710-
out_channels=out_channels,
711+
out_channels=mid_channels,
711712
kernel_size=kernel_size,
712713
padding=(kernel_size - 1) // 2,
713714
)
@@ -716,7 +717,7 @@ def __init__(
716717

717718
self.block2 = nn.ModuleList(
718719
Conv1d(
719-
in_channels=in_channels,
720+
in_channels=mid_channels,
720721
out_channels=out_channels,
721722
kernel_size=kernel_size,
722723
padding=(kernel_size - 1) // 2,
@@ -725,9 +726,9 @@ def __init__(
725726
)
726727

727728
def forward(self, x: Tensor) -> Tensor:
    """Run `x` through both banks of parallel convolutions.

    For `self.block1` and then `self.block2`, every conv in the bank is
    applied to the current tensor, the results are stacked on a new leading
    axis, and that axis is summed away again.
    """
    for bank in (self.block1, self.block2):
        # One output per conv in the bank, stacked as (n, b, c, t).
        branch_outputs = torch.stack([branch(x) for branch in bank])
        # Collapse the per-conv axis by summation.
        x = reduce(branch_outputs, "n b c t -> b c t", "sum")
    return x
733734

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.27",
6+
version="0.0.28",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)