Skip to content

Commit d30eb1f

Browse files
feat: added audio diffusion upsampler
1 parent 03b9bb0 commit d30eb1f

File tree

4 files changed

+76
-3
lines changed

4 files changed

+76
-3
lines changed

README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ https://colab.research.google.com/gist/flavioschneider/39c6454bfc2d03dc7d0c5c9d8
1919

2020
## Usage
2121

22+
### Generation
2223
```py
2324
from audio_diffusion_pytorch import AudioDiffusionModel
2425

@@ -37,6 +38,28 @@ sampled = model.sample(
3738
) # [2, 1, 262144]
3839
```
3940

41+
### Upsampling
42+
```py
43+
from audio_diffusion_pytorch import AudioDiffusionUpsampler
44+
45+
upsampler = AudioDiffusionUpsampler(
46+
factor=4,
47+
in_channels=1
48+
)
49+
50+
# Train on high frequency data
51+
x = torch.randn(2, 1, 2 ** 18) # [batch, in_channels, samples]
52+
loss = upsampler(x)
53+
loss.backward()
54+
55+
# Given start undersampled source, samples upsampled source
56+
start = torch.randn(1, 1, 2 ** 16)
57+
sampled = upsampler.sample(
58+
start=start,
59+
num_steps=5 # Suggested range: 2-100
60+
)
61+
```
62+
4063
## Usage with Components
4164

4265
### UNet1d
@@ -131,6 +154,8 @@ y_long = composer(y, keep_start=True) # [1, 1, 98304]
131154
```
132155

133156

157+
158+
134159
## Experiments
135160

136161

audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@
1111
Schedule,
1212
SpanBySpanComposer,
1313
)
14-
from .model import AudioDiffusionModel, Model1d
14+
from .model import AudioDiffusionModel, AudioDiffusionUpsampler, Model1d
1515
from .modules import Encoder1d, UNet1d

audio_diffusion_pytorch/model.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional, Sequence
22

3+
import torch
34
from torch import Tensor, nn
45

56
from .diffusion import (
@@ -111,7 +112,7 @@ def __init__(self, *args, **kwargs):
111112
use_skip_scale=True,
112113
diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
113114
diffusion_sigma_data=0.1,
114-
diffusion_dynamic_threshold=0.95,
115+
diffusion_dynamic_threshold=0.0,
115116
)
116117

117118
super().__init__(*args, **{**default_kwargs, **kwargs})
@@ -122,3 +123,50 @@ def sample(self, *args, **kwargs):
122123
sampler=ADPM2Sampler(rho=1.0),
123124
)
124125
return super().sample(*args, **{**default_kwargs, **kwargs})
126+
127+
128+
class AudioDiffusionUpsampler(Model1d):
129+
def __init__(self, factor: int, in_channels: int = 1, *args, **kwargs):
130+
self.factor = factor
131+
132+
default_kwargs = dict(
133+
in_channels=in_channels,
134+
channels=128,
135+
patch_size=16,
136+
kernel_sizes_init=[1, 3, 7],
137+
multipliers=[1, 2, 4, 4, 4, 4, 4],
138+
factors=[4, 4, 4, 2, 2, 2],
139+
num_blocks=[2, 2, 2, 2, 2, 2],
140+
attentions=[False, False, False, True, True, True],
141+
attention_heads=8,
142+
attention_features=64,
143+
attention_multiplier=2,
144+
use_attention_bottleneck=True,
145+
resnet_groups=8,
146+
kernel_multiplier_downsample=2,
147+
use_nearest_upsample=False,
148+
use_skip_scale=True,
149+
diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
150+
diffusion_sigma_data=0.1,
151+
diffusion_dynamic_threshold=0.0,
152+
context_channels=[in_channels],
153+
)
154+
155+
super().__init__(*args, {**default_kwargs, **kwargs}) # type: ignore
156+
157+
def forward(self, x: Tensor, **kwargs) -> Tensor:
158+
# Downsample by picking every `factor` item
159+
downsampled = x[:, :, :: self.factor]
160+
# Upsample by interleaving to get context
161+
context = torch.repeat_interleave(downsampled, repeats=self.factor, dim=2)
162+
return self.diffusion(x, context=[context], **kwargs)
163+
164+
def sample(self, start: Tensor, *args, **kwargs): # type: ignore
165+
context = torch.repeat_interleave(start, repeats=self.factor, dim=2)
166+
noise = torch.randn_like(context)
167+
default_kwargs = dict(
168+
context=[context],
169+
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
170+
sampler=ADPM2Sampler(rho=1.0),
171+
)
172+
return super().sample(noise, *args, **{**default_kwargs, **kwargs}) # type: ignore # noqa

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.17",
6+
version="0.0.18",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)