
Commit 4590e0c

feat: add conditional model, classifier-free guidance, refactor context names
1 parent 4e6ee2c commit 4590e0c

5 files changed: 275 additions, 212 deletions

README.md

Lines changed: 30 additions & 0 deletions

````diff
@@ -86,6 +86,35 @@ decoded = autoencoder.decode(
 ) # [2, 32, 2**18]
 ```

+
+### Conditional Generation
+```py
+from audio_diffusion_pytorch import AudioDiffusionConditional
+
+model = AudioDiffusionConditional(
+    in_channels=1,
+    embedding_max_length=512,
+    embedding_features=768,
+    embedding_mask_proba=0.1 # Conditional dropout of batch elements
+)
+
+# Train on pairs of audio and embedding data (e.g. from a transformer output)
+x = torch.randn(2, 1, 2 ** 18)
+embedding = torch.randn(2, 512, 768)
+loss = model(x, embedding=embedding)
+loss.backward()
+
+# Given start embedding and noise sample new source
+embedding = torch.randn(1, 512, 768)
+noise = torch.randn(1, 1, 2 ** 18)
+sampled = model.sample(
+    noise,
+    embedding=embedding,
+    embedding_scale=5.0, # Classifier-free guidance scale
+    num_steps=5
+) # [1, 1, 2 ** 18]
+```
+
 ## Usage with Components

 ### UNet1d
@@ -206,6 +235,7 @@ y_long = composer(y, keep_start=True) # [1, 1, 98304]
 - [x] Add diffusion autoencoder.
 - [x] Add autoencoder bottleneck option for quantization.
 - [x] Add option to provide context tokens (resnet cross attention).
+- [x] Add conditional model with classifier-free guidance.

 ## Appreciation

````
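The README snippet above uses `embedding_scale` without spelling out what classifier-free guidance does with it. Below is a minimal sketch of the guided prediction at sampling time, not the library's internals: `denoise_fn`, `x_noisy`, and `sigma` are hypothetical placeholders, and passing `embedding=None` stands in for however the UNet masks the embedding on its unconditional branch.

```py
import torch

def guided_denoise(denoise_fn, x_noisy, sigma, embedding, embedding_scale=5.0):
    # Unconditional branch: the embedding is masked out, the same path the model
    # learns when embedding_mask_proba drops an item's embedding during training.
    out_uncond = denoise_fn(x_noisy, sigma, embedding=None)
    # Conditional branch: the embedding is attended to via cross attention.
    out_cond = denoise_fn(x_noisy, sigma, embedding=embedding)
    # Guidance: push the prediction from the unconditional output toward the
    # conditional one; a scale of 1.0 recovers plain conditional sampling.
    return out_uncond + (out_cond - out_uncond) * embedding_scale
```

Larger scales trade diversity for closer adherence to the embedding; the new `AudioDiffusionConditional.sample` defaults to a scale of 5.0.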

audio_diffusion_pytorch/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -14,10 +14,11 @@
 )
 from .model import (
     AudioDiffusionAutoencoder,
+    AudioDiffusionConditional,
     AudioDiffusionModel,
     AudioDiffusionUpsampler,
     DiffusionAutoencoder1d,
     DiffusionUpsampler1d,
     Model1d,
 )
-from .modules import Encoder1d, UNet1d
+from .modules import Encoder1d, UNet1d, UNetConditional1d
```

audio_diffusion_pytorch/model.py

Lines changed: 90 additions & 131 deletions

```diff
@@ -14,61 +14,28 @@
     Sampler,
     Schedule,
 )
-from .modules import Encoder1d, ResnetBlock1d, UNet1d
+from .modules import Encoder1d, ResnetBlock1d, UNet1d, UNetConditional1d
 from .utils import default, exists, prod, to_list

-""" Diffusion Classes (generic for 1d data) """
+"""
+Diffusion Classes (generic for 1d data)
+"""


 class Model1d(nn.Module):
     def __init__(
         self,
-        in_channels: int,
-        channels: int,
-        patch_size: int,
-        kernel_sizes_init: Sequence[int],
-        multipliers: Sequence[int],
-        factors: Sequence[int],
-        num_blocks: Sequence[int],
-        attentions: Sequence[bool],
-        attention_heads: int,
-        attention_features: int,
-        attention_multiplier: int,
-        use_attention_bottleneck: bool,
-        resnet_groups: int,
-        kernel_multiplier_downsample: int,
-        use_nearest_upsample: bool,
-        use_skip_scale: bool,
         diffusion_sigma_distribution: Distribution,
         diffusion_sigma_data: int,
         diffusion_dynamic_threshold: float,
-        out_channels: Optional[int] = None,
-        context_channels: Optional[Sequence[int]] = None,
+        use_classifier_free_guidance: bool = False,
         **kwargs
     ):
         super().__init__()

-        self.unet = UNet1d(
-            in_channels=in_channels,
-            channels=channels,
-            patch_size=patch_size,
-            kernel_sizes_init=kernel_sizes_init,
-            multipliers=multipliers,
-            factors=factors,
-            num_blocks=num_blocks,
-            attentions=attentions,
-            attention_heads=attention_heads,
-            attention_features=attention_features,
-            attention_multiplier=attention_multiplier,
-            use_attention_bottleneck=use_attention_bottleneck,
-            resnet_groups=resnet_groups,
-            kernel_multiplier_downsample=kernel_multiplier_downsample,
-            use_nearest_upsample=use_nearest_upsample,
-            use_skip_scale=use_skip_scale,
-            out_channels=out_channels,
-            context_channels=context_channels,
-            **kwargs
-        )
+        UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
+
+        self.unet = UNet(**kwargs)

         self.diffusion = Diffusion(
             net=self.unet,
@@ -114,18 +81,18 @@ def forward(self, x: Tensor, factor: Optional[int] = None, **kwargs) -> Tensor:
         # Downsample by picking every `factor` item
         downsampled = x[:, :, ::factor]
         # Upsample by interleaving to get context
-        context = torch.repeat_interleave(downsampled, repeats=factor, dim=2)
-        return self.diffusion(x, context=[context], **kwargs)
+        channels = torch.repeat_interleave(downsampled, repeats=factor, dim=2)
+        return self.diffusion(x, channels_list=[channels], **kwargs)

     def sample(  # type: ignore
         self, undersampled: Tensor, factor: Optional[int] = None, *args, **kwargs
     ):
         # Either user provides factor or we pick the first
         factor = default(factor, self.factor[0])
-        # Upsample context by interleaving
-        context = torch.repeat_interleave(undersampled, repeats=factor, dim=2)
-        noise = torch.randn_like(context)
-        default_kwargs = dict(context=[context])
+        # Upsample channels by interleaving
+        channels = torch.repeat_interleave(undersampled, repeats=factor, dim=2)
+        noise = torch.randn_like(channels)
+        default_kwargs = dict(channels_list=[channels])
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore


@@ -166,7 +133,7 @@ def __init__(
             resnet_groups=resnet_groups,
             kernel_multiplier_downsample=kernel_multiplier_downsample,
             context_channels=[0] * encoder_depth + [context_channels],
-            **kwargs
+            **kwargs,
         )

         self.in_channels = in_channels
@@ -190,7 +157,7 @@ def __init__(
             extract_channels=[0] * (encoder_depth - 1) + [encoder_channels],
         )

-        self.to_context = ResnetBlock1d(
+        self.to_context_channels = ResnetBlock1d(
             in_channels=encoder_channels,
             out_channels=context_channels,
             num_groups=resnet_groups,
@@ -204,8 +171,8 @@ def forward(  # type: ignore
         else:
             latent = self.encode(x)

-        context = self.to_context(latent)
-        loss = self.diffusion(x, context=[context], **kwargs)
+        channels = self.to_context_channels(latent)
+        loss = self.diffusion(x, channels_list=[channels], **kwargs)
         return (loss, info) if with_info else loss

     def encode(
@@ -224,114 +191,106 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor:
         # Compute noise by inferring shape from latent length
         noise = torch.randn(b, self.in_channels, length).to(latent)
         # Compute context form latent
-        context = self.to_context(latent)
-        default_kwargs = dict(context=[context])
-        # Decode by sampling while conditioning on latent context
+        channels = self.to_context_channels(latent)
+        default_kwargs = dict(channels_list=[channels])
+        # Decode by sampling while conditioning on latent channels
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore


-""" Audio Diffusion Classes (specific for 1d audio data) """
+"""
+Audio Diffusion Classes (specific for 1d audio data)
+"""
+
+
+def get_default_model_kwargs():
+    return dict(
+        channels=128,
+        patch_size=16,
+        kernel_sizes_init=[1, 3, 7],
+        multipliers=[1, 2, 4, 4, 4, 4, 4],
+        factors=[4, 4, 4, 2, 2, 2],
+        num_blocks=[2, 2, 2, 2, 2, 2],
+        attentions=[False, False, False, True, True, True],
+        attention_heads=8,
+        attention_features=64,
+        attention_multiplier=2,
+        use_attention_bottleneck=True,
+        resnet_groups=8,
+        kernel_multiplier_downsample=2,
+        use_nearest_upsample=False,
+        use_skip_scale=True,
+        diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
+        diffusion_sigma_data=0.1,
+        diffusion_dynamic_threshold=0.0,
+    )
+
+
+def get_default_sampling_kwargs():
+    return dict(
+        sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
+        sampler=ADPM2Sampler(rho=1.0),
+    )


 class AudioDiffusionModel(Model1d):
-    def __init__(self, *args, **kwargs):
-        default_kwargs = dict(
-            channels=128,
-            patch_size=16,
-            kernel_sizes_init=[1, 3, 7],
-            multipliers=[1, 2, 4, 4, 4, 4, 4],
-            factors=[4, 4, 4, 2, 2, 2],
-            num_blocks=[2, 2, 2, 2, 2, 2],
-            attentions=[False, False, False, True, True, True],
-            attention_heads=8,
-            attention_features=64,
-            attention_multiplier=2,
-            use_attention_bottleneck=True,
-            resnet_groups=8,
-            kernel_multiplier_downsample=2,
-            use_nearest_upsample=False,
-            use_skip_scale=True,
-            diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
-            diffusion_sigma_data=0.1,
-            diffusion_dynamic_threshold=0.0,
-        )
-
-        super().__init__(*args, **{**default_kwargs, **kwargs})
+    def __init__(self, **kwargs):
+        super().__init__(**{**get_default_model_kwargs(), **kwargs})

     def sample(self, *args, **kwargs):
-        default_kwargs = dict(
-            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
-            sampler=ADPM2Sampler(rho=1.0),
-        )
-        return super().sample(*args, **{**default_kwargs, **kwargs})
+        return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})


 class AudioDiffusionUpsampler(DiffusionUpsampler1d):
-    def __init__(self, in_channels: int, *args, **kwargs):
+    def __init__(self, in_channels: int, **kwargs):
         default_kwargs = dict(
+            **get_default_model_kwargs(),
             in_channels=in_channels,
-            channels=128,
-            patch_size=16,
-            kernel_sizes_init=[1, 3, 7],
-            multipliers=[1, 2, 4, 4, 4, 4, 4],
-            factors=[4, 4, 4, 2, 2, 2],
-            num_blocks=[2, 2, 2, 2, 2, 2],
-            attentions=[False, False, False, True, True, True],
-            attention_heads=8,
-            attention_features=64,
-            attention_multiplier=2,
-            use_attention_bottleneck=True,
-            resnet_groups=8,
-            kernel_multiplier_downsample=2,
-            use_nearest_upsample=False,
-            use_skip_scale=True,
-            diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
-            diffusion_sigma_data=0.1,
-            diffusion_dynamic_threshold=0.0,
             context_channels=[in_channels],
         )
-
-        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
+        super().__init__(**{**default_kwargs, **kwargs})  # type: ignore

     def sample(self, *args, **kwargs):
-        default_kwargs = dict(
-            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
-            sampler=ADPM2Sampler(rho=1.0),
-        )
-        return super().sample(*args, **{**default_kwargs, **kwargs})
+        return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})


 class AudioDiffusionAutoencoder(DiffusionAutoencoder1d):
     def __init__(self, *args, **kwargs):
         default_kwargs = dict(
-            channels=128,
-            patch_size=16,
-            kernel_sizes_init=[1, 3, 7],
-            multipliers=[1, 2, 4, 4, 4, 4, 4],
-            factors=[4, 4, 4, 2, 2, 2],
-            num_blocks=[2, 2, 2, 2, 2, 2],
-            attentions=[False, False, False, True, True, True],
-            attention_heads=8,
-            attention_features=64,
-            attention_multiplier=2,
-            use_attention_bottleneck=True,
-            resnet_groups=8,
-            kernel_multiplier_downsample=2,
-            use_nearest_upsample=False,
-            use_skip_scale=True,
+            **get_default_model_kwargs(),
             encoder_depth=4,
             encoder_channels=32,
             context_channels=512,
-            diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
-            diffusion_sigma_data=0.1,
-            diffusion_dynamic_threshold=0.0,
         )
-
         super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore

     def decode(self, *args, **kwargs):
+        return super().decode(*args, **{**get_default_sampling_kwargs(), **kwargs})
+
+
+class AudioDiffusionConditional(Model1d):
+    def __init__(
+        self,
+        embedding_features: int,
+        embedding_max_length: int,
+        embedding_mask_proba: float = 0.1,
+        **kwargs
+    ):
+        self.embedding_mask_proba = embedding_mask_proba
         default_kwargs = dict(
-            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
-            sampler=ADPM2Sampler(rho=1.0),
+            **get_default_model_kwargs(),
+            context_embedding_features=embedding_features,
+            context_embedding_max_length=embedding_max_length,
+            use_classifier_free_guidance=True,
         )
-        return super().decode(*args, **{**default_kwargs, **kwargs})
+        super().__init__(**{**default_kwargs, **kwargs})
+
+    def forward(self, *args, **kwargs):
+        default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
+        return super().forward(*args, **{**default_kwargs, **kwargs})
+
+    def sample(self, *args, **kwargs):
+        default_kwargs = dict(
+            **get_default_sampling_kwargs(),
+            embedding_scale=5.0,
+        )
+        return super().sample(*args, **{**default_kwargs, **kwargs})
```
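For reference, a rough sketch of the conditional dropout that `embedding_mask_proba` controls during training follows. The zero tensor is only a placeholder for whatever learned "empty" embedding `UNetConditional1d` substitutes internally, so treat the names as illustrative rather than the actual implementation.

```py
import torch

def drop_embedding(embedding: torch.Tensor, proba: float) -> torch.Tensor:
    # embedding: [batch, length, features]
    batch = embedding.shape[0]
    # Decide per batch element whether to train it unconditionally this step
    drop = torch.rand(batch, device=embedding.device) < proba
    # Placeholder for the learned mask embedding used by the real UNet
    masked = torch.zeros_like(embedding)
    return torch.where(drop[:, None, None], masked, embedding)
```

Training a fraction of items with the embedding removed is what gives the model a usable unconditional branch, which the guidance scale then blends against the conditional prediction at sampling time.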
