Commit 4ec697e

feat: add audio autoencoder head option
1 parent 31e2c27 · commit 4ec697e

File tree

- README.md
- audio_diffusion_pytorch/__init__.py
- audio_diffusion_pytorch/model.py
- setup.py

4 files changed: +71 -15 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ y_long = composer(y, keep_start=True) # [1, 1, 98304]
 - [x] Add elucidated diffusion.
 - [x] Add ancestral DPM2 sampler.
 - [x] Add dynamic thresholding.
-- [ ] Add support with (variational) autoencoder to compress audio before diffusion.
+- [x] Add (variational) autoencoder option to compress audio before diffusion.
 - [ ] Fix inpainting and make it work with ADPM2 sampler.

 ## Appreciation

audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -11,5 +11,5 @@
     Schedule,
     SpanBySpanComposer,
 )
-from .model import AudioDiffusionModel, Model1d
+from .model import AudioAutoEncoderModel, AudioDiffusionModel, Model1d
 from .modules import AutoEncoder1d, UNet1d
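
After this change the new model is exported at the package root. A minimal import check, assuming the package at this revision is installed:

```python
# All of these should now resolve from the package root;
# AudioAutoEncoderModel is the addition from this commit.
from audio_diffusion_pytorch import (
    AudioAutoEncoderModel,
    AudioDiffusionModel,
    AutoEncoder1d,
    Model1d,
)
```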

audio_diffusion_pytorch/model.py

Lines changed: 68 additions & 12 deletions
@@ -12,7 +12,8 @@
     Sampler,
     Schedule,
 )
-from .modules import UNet1d
+from .modules import AutoEncoder1d, UNet1d
+from .utils import exists


 class Model1d(nn.Module):
@@ -39,9 +40,19 @@ def __init__(
         diffusion_sigma_data: int,
         diffusion_dynamic_threshold: float,
         out_channels: Optional[int] = None,
+        use_autoencoder: bool = False,
+        autoencoder: Optional[AutoEncoder1d] = None,
+        autoencoder_scale: float = 1.0,
     ):
         super().__init__()

+        self.use_autoencoder = use_autoencoder
+
+        if use_autoencoder:
+            assert exists(autoencoder)
+            self.autoencoder_scale = autoencoder_scale
+            self.autoencoder = autoencoder
+
         self.unet = UNet1d(
             in_channels=in_channels,
             channels=channels,
@@ -71,6 +82,8 @@ def __init__(
         )

     def forward(self, x: Tensor) -> Tensor:
+        if self.use_autoencoder:
+            x = self.autoencoder_scale * self.autoencoder.encode(x)  # type: ignore
         return self.diffusion(x)

     def sample(
@@ -82,19 +95,35 @@ def sample(
             sigma_schedule=sigma_schedule,
             num_steps=num_steps,
         )
-        return diffusion_sampler(noise)
+        x = diffusion_sampler(noise)

+        if self.use_autoencoder:
+            x = (1.0 / self.autoencoder_scale) * self.autoencoder.decode(x)

-class AudioDiffusionModel(Model1d):
+        return x
+
+
+class AudioAutoEncoderModel(AutoEncoder1d):
     def __init__(self, *args, **kwargs):
         default_kwargs = dict(
             in_channels=1,
+            bottleneck_channels=128,
             channels=128,
             patch_size=16,
-            multipliers=[1, 2, 4, 4, 4, 4, 4],
-            factors=[4, 4, 4, 2, 2, 2],
-            num_blocks=[2, 2, 2, 2, 2, 2],
-            attentions=[False, False, False, True, True, True],
+            multipliers=[1, 1, 1, 1, 1],
+            factors=[1, 4, 4, 4],
+            num_blocks=[2, 2, 2, 2],
+            resnet_groups=8,
+            kernel_multiplier_downsample=2,
+            loss_kl_weight=1e-8,
+        )
+        super().__init__(*args, **{**default_kwargs, **kwargs})
+
+
+class AudioDiffusionModel(Model1d):
+    def __init__(self, *args, **kwargs):
+        default_kwargs = dict(
+            channels=128,
             attention_heads=8,
             attention_features=64,
             attention_multiplier=2,
@@ -106,14 +135,41 @@ def __init__(self, *args, **kwargs):
             use_attention_bottleneck=True,
             use_learned_time_embedding=True,
             diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
-            diffusion_sigma_data=0.1,
-            diffusion_dynamic_threshold=0.95,
         )
-        super().__init__(*args, **{**default_kwargs, **kwargs})
+
+        model_kwargs = None
+
+        if "autoencoder" in kwargs:
+            sigma_data = 0.2
+            model_kwargs = dict(
+                in_channels=128,
+                patch_size=1,
+                multipliers=[1, 4, 4, 4],
+                factors=[2, 2, 2],
+                num_blocks=[2, 2, 2],
+                attentions=[True, True, True],
+                diffusion_sigma_data=sigma_data,
+                diffusion_dynamic_threshold=0.0,
+                use_autoencoder=True,
+                autoencoder_scale=sigma_data,
+            )
+        else:
+            model_kwargs = dict(
+                in_channels=1,
+                patch_size=16,
+                multipliers=[1, 2, 4, 4, 4, 4, 4],
+                factors=[4, 4, 4, 2, 2, 2],
+                num_blocks=[2, 2, 2, 2, 2, 2],
+                attentions=[False, False, False, True, True, True],
+                diffusion_sigma_data=0.1,
+                diffusion_dynamic_threshold=0.95,
+                use_autoencoder=False,
+            )
+        super().__init__(*args, **{**default_kwargs, **model_kwargs, **kwargs})

     def sample(self, *args, **kwargs):
         default_kwargs = dict(
-            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3, rho=9.0),
-            sampler=ADPM2Sampler(rho=1),
+            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
+            sampler=ADPM2Sampler(rho=1.0),
         )
         return super().sample(*args, **{**default_kwargs, **kwargs})
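
Taken together, these changes add an optional latent-diffusion path: passing an `autoencoder` to `AudioDiffusionModel` switches the constructor to the 128-channel latent settings (`in_channels=128`, `patch_size=1`, `autoencoder_scale=sigma_data=0.2`), `Model1d.forward` encodes the input before computing the diffusion objective, and `Model1d.sample` decodes the sampled latent back to audio. A minimal usage sketch based only on the signatures visible in this diff; batch size, audio length, the latent shape, and the exact keyword arguments of `sample` (whose full signature is not shown above) are assumptions:

```python
import torch

from audio_diffusion_pytorch import AudioAutoEncoderModel, AudioDiffusionModel

# Waveform autoencoder with the defaults from this commit: 1 input channel,
# a 128-channel bottleneck, and (assumed) 16 * 1 * 4 * 4 * 4 = 1024x total
# time downsampling from patch_size=16 and factors=[1, 4, 4, 4].
autoencoder = AudioAutoEncoderModel()

# Passing `autoencoder` selects the latent branch of AudioDiffusionModel:
# in_channels=128, patch_size=1, use_autoencoder=True, autoencoder_scale=0.2.
model = AudioDiffusionModel(autoencoder=autoencoder)

# Training: forward() encodes x, scales the latent by autoencoder_scale,
# and returns the diffusion objective (treated as a scalar loss here).
x = torch.randn(2, 1, 2**18)  # [batch, channels, samples] -- illustrative shape
loss = model(x)
loss.backward()

# Sampling: noise lives in the 128-channel latent space; the sampled latent
# is decoded and the result rescaled by 1 / autoencoder_scale, per the hunk above.
noise = torch.randn(2, 128, 2**18 // 1024)  # latent shape is an assumption
audio = model.sample(noise=noise, num_steps=50)  # expected ~[2, 1, 2**18]
```

Setting `autoencoder_scale` equal to `diffusion_sigma_data` (both 0.2) suggests the latents are rescaled so their magnitude matches the `sigma_data` the diffusion parameterization assumes, and `diffusion_dynamic_threshold=0.0` disables dynamic thresholding in latent space, where values are not bounded the way waveforms are. The autoencoder is presumably pretrained on its own reconstruction objective with a very light KL penalty (`loss_kl_weight=1e-8`); this commit does not show its training loop.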

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.12",
+    version="0.0.13",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
