Commit c937e49

feat: add diffusion ar, remove multiblock patching, move pretrained models

1 parent: c0020d5

5 files changed (+118, −103 lines)

README.md

Lines changed: 1 addition & 22 deletions

@@ -1,7 +1,7 @@
 <img src="./LOGO.png"></img>
 
 Unconditional audio generation using diffusion models, in PyTorch. The goal of this repository is to explore different architectures and diffusion models to generate audio (speech and music) directly from/to the waveform.
-Progress will be documented in the [experiments](#experiments) section. You can use the [`audio-diffusion-pytorch-trainer`](https://github.com/archinetai/audio-diffusion-pytorch-trainer) to run your own experiments – please share your findings in the [discussions](https://github.com/archinetai/audio-diffusion-pytorch/discussions) page!
+Progress will be documented in the [experiments](#experiments) section. You can use the [`audio-diffusion-pytorch-trainer`](https://github.com/archinetai/audio-diffusion-pytorch-trainer) to run your own experiments – please share your findings in the [discussions](https://github.com/archinetai/audio-diffusion-pytorch/discussions) page! Pretrained models can be found at [`archisound`](https://github.com/archinetai/archisound).
 
 ## Install
 
@@ -241,27 +241,6 @@ composer = SpanBySpanComposer(
 y_long = composer(y, keep_start=True) # [1, 1, 98304]
 ```
 
-## Pretrained Models
-
-### Diffusion (Magnitude) AutoEncoder ([`dmae1d-ATC64-v1`](https://huggingface.co/archinetai/dmae1d-ATC64-v1/tree/main))
-```py
-from audio_diffusion_pytorch import AudioModel
-
-autoencoder = AudioModel.from_pretrained("dmae1d-ATC64-v1")
-
-x = torch.randn(1, 2, 2**18)
-z = autoencoder.encode(x) # [1, 32, 256]
-y = autoencoder.decode(z, num_steps=20) # [1, 2, 262144]
-```
-
-| Info | |
-| ------------- | ------------- |
-| Input type | Audio (stereo @ 48kHz) |
-| Number of parameters | 234.2M |
-| Compression Factor | 64x |
-| Downsampling Factor | 1024x |
-| Bottleneck Type | Tanh |
-
 
 ## Experiments
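Note: the snippet removed above now lives outside this repository. A minimal sketch of the equivalent call through `archisound`, assuming that package exposes a `from_pretrained`-style loader for the `dmae1d-ATC64-v1` checkpoint (the `ArchiSound` class name is an assumption, not part of this diff):

```py
import torch
from archisound import ArchiSound  # assumed entry point of the archisound package

# Hypothetical replacement for the removed AudioModel.from_pretrained call
autoencoder = ArchiSound.from_pretrained("dmae1d-ATC64-v1")

x = torch.randn(1, 2, 2**18)             # stereo @ 48kHz, as in the removed example
z = autoencoder.encode(x)                # [1, 32, 256], 64x compression
y = autoencoder.decode(z, num_steps=20)  # [1, 2, 262144]
```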
audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@
     AudioDiffusionUpphaser,
     AudioDiffusionUpsampler,
     AudioDiffusionVocoder,
-    AudioModel,
+    DiffusionAR1d,
     DiffusionAutoencoder1d,
     DiffusionMAE1d,
     DiffusionUpphaser1d,
audio_diffusion_pytorch/model.py

Lines changed: 103 additions & 17 deletions

@@ -1,13 +1,15 @@
 from math import pi
+from random import randint
 from typing import Any, Optional, Sequence, Tuple, Union
 
 import torch
 from audio_encoders_pytorch import Bottleneck, Encoder1d
 from einops import rearrange
 from torch import Tensor, nn
+from tqdm import tqdm
 
 from .diffusion import LinearSchedule, UniformDistribution, VSampler, XDiffusion
-from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d
+from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d, rand_bool
 from .utils import (
     closest_power_2,
     default,
@@ -355,6 +357,105 @@ def forward(self, x: Tensor, **kwargs) -> Tensor:
         return self.diffusion(x, channels_list=[resampled], features=features, **kwargs)
 
 
+class DiffusionAR1d(Model1d):
+    def __init__(
+        self,
+        in_channels: int,
+        chunk_length: int,
+        upsample: int = 0,
+        dropout: float = 0.05,
+        verbose: int = 0,
+        **kwargs,
+    ):
+        self.in_channels = in_channels
+        self.chunk_length = chunk_length
+        self.dropout = dropout
+        self.upsample = upsample
+        self.verbose = verbose
+        super().__init__(
+            in_channels=in_channels,
+            context_channels=[in_channels * (2 if upsample > 0 else 1)],
+            **kwargs,
+        )
+
+    def reupsample(self, x: Tensor) -> Tensor:
+        x = x.clone()
+        x = downsample(x, factor=self.upsample)
+        x = upsample(x, factor=self.upsample)
+        return x
+
+    def forward(self, x: Tensor, **kwargs) -> Tensor:
+        b, _, t, device = *x.shape, x.device
+        cl, num_chunks = self.chunk_length, t // self.chunk_length
+        assert num_chunks >= 2, "Input tensor length must be >= chunk_length * 2"
+
+        # Get prev and current target chunks
+        chunk_index = randint(0, num_chunks - 2)
+        chunk_pos = cl * (chunk_index + 1)
+        chunk_prev = x[:, :, cl * chunk_index : chunk_pos]
+        chunk_curr = x[:, :, chunk_pos : cl * (chunk_index + 2)]
+
+        # Randomly drop out source chunks to allow for zero AR start
+        if self.dropout > 0:
+            batch_mask = rand_bool(shape=(b, 1, 1), proba=self.dropout, device=device)
+            chunk_zeros = torch.zeros_like(chunk_prev)
+            chunk_prev = torch.where(batch_mask, chunk_zeros, chunk_prev)
+
+        # Condition on previous chunk and reupsampled current if required
+        if self.upsample > 0:
+            chunk_reupsampled = self.reupsample(chunk_curr)
+            channels_list = [torch.cat([chunk_prev, chunk_reupsampled], dim=1)]
+        else:
+            channels_list = [chunk_prev]
+
+        # Diffuse current chunk
+        return self.diffusion(chunk_curr, channels_list=channels_list, **kwargs)
+
+    def sample(self, x: Tensor, start: Optional[Tensor] = None, **kwargs) -> Tensor:  # type: ignore # noqa
+        noise = x
+
+        if self.upsample > 0:
+            # In this case we assume that x is the downsampled audio instead of noise
+            upsampled = upsample(x, factor=self.upsample)
+            noise = torch.randn_like(upsampled)
+
+        b, c, t, device = *noise.shape, noise.device
+        cl, num_chunks = self.chunk_length, t // self.chunk_length
+        assert c == self.in_channels
+        assert t % cl == 0, "noise length must be divisible by chunk_length"
+
+        # Initialize previous chunk
+        if exists(start):
+            chunk_prev = start[:, :, -cl:]
+        else:
+            chunk_prev = torch.zeros(b, c, cl).to(device)
+
+        # Computed chunks
+        chunks = []
+
+        for i in tqdm(range(num_chunks), disable=(self.verbose == 0)):
+            # Chunk noise
+            chunk_start, chunk_end = cl * i, cl * (i + 1)
+            noise_curr = noise[:, :, chunk_start:chunk_end]
+
+            # Condition on previous chunk and artificially upsampled current if required
+            if self.upsample > 0:
+                chunk_upsampled = upsampled[:, :, chunk_start:chunk_end]
+                channels_list = [torch.cat([chunk_prev, chunk_upsampled], dim=1)]
+            else:
+                channels_list = [chunk_prev]
+            default_kwargs = dict(channels_list=channels_list)
+
+            # Sample current chunk
+            chunk_curr = super().sample(noise_curr, **{**default_kwargs, **kwargs})
+
+            # Save chunk and use current as prev
+            chunks += [chunk_curr]
+            chunk_prev = chunk_curr
+
+        return rearrange(chunks, "l b c t -> b c (l t)")
+
+
 """
 Audio Diffusion Classes (specific for 1d audio data)
 """
@@ -363,7 +464,7 @@ def forward(self, x: Tensor, **kwargs) -> Tensor:
 def get_default_model_kwargs():
     return dict(
         channels=128,
-        patch_factor=16,
+        patch_size=16,
         multipliers=[1, 2, 4, 4, 4, 4, 4],
         factors=[4, 4, 4, 2, 2, 2],
         num_blocks=[2, 2, 2, 2, 2, 2],
@@ -500,18 +601,3 @@ def __init__(self, in_channels: int, **kwargs):
 
     def sample(self, *args, **kwargs):
         return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
-
-
-""" Pretrained Models Helper """
-
-REVISION = {"dmae1d-ATC64-v1": "07885065867977af43b460bb9c1422bdc90c29a0"}
-
-
-class AudioModel:
-    @staticmethod
-    def from_pretrained(name: str) -> nn.Module:
-        from transformers import AutoModel
-
-        return AutoModel.from_pretrained(
-            f"archinetai/{name}", trust_remote_code=True, revision=REVISION[name]
-        )
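A usage sketch for the new `DiffusionAR1d`. The UNet kwargs mirror `get_default_model_kwargs` from this commit; `chunk_length`, `attentions`, the `diffusion_*` settings, the sampling kwargs, and the package-root exports of the sampler/schedule classes are illustrative assumptions rather than values fixed by the diff:

```py
import torch
from audio_diffusion_pytorch import (  # sampler/schedule exports assumed
    DiffusionAR1d,
    LinearSchedule,
    UniformDistribution,
    VSampler,
)

model = DiffusionAR1d(
    in_channels=1,
    chunk_length=2**15,                  # samples per autoregressive chunk (assumed)
    channels=128,                        # UNet kwargs as in get_default_model_kwargs
    patch_size=16,
    multipliers=[1, 2, 4, 4, 4, 4, 4],
    factors=[4, 4, 4, 2, 2, 2],
    num_blocks=[2, 2, 2, 2, 2, 2],
    attentions=[0, 0, 0, 0, 0, 0],       # assumed; length must match the UNet layers
    diffusion_type="v",                  # assumed
    diffusion_sigma_distribution=UniformDistribution(),  # assumed
)

# Training: x must span at least two chunks; forward() picks a random
# (previous, current) pair and diffuses the current chunk given the previous one
x = torch.randn(2, 1, 2**17)
loss = model(x)
loss.backward()

# Sampling: noise length must be a multiple of chunk_length; chunks are generated
# left to right, each conditioned on the chunk produced just before it
noise = torch.randn(1, 1, 2**17)
y = model.sample(
    noise,
    num_steps=50,                        # assumed sampling kwargs
    sampler=VSampler(),
    sigma_schedule=LinearSchedule(),
)  # [1, 1, 2**17]
```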

audio_diffusion_pytorch/modules.py

Lines changed: 11 additions & 62 deletions

@@ -207,12 +207,12 @@ def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor:
         return h + self.to_out(x)
 
 
-class PatchBlock(nn.Module):
+class Patcher(nn.Module):
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        patch_size: int = 2,
+        patch_size: int,
         context_mapping_features: Optional[int] = None,
     ):
         super().__init__()
@@ -223,7 +223,7 @@ def __init__(
         self.block = ResnetBlock1d(
             in_channels=in_channels,
             out_channels=out_channels // patch_size,
-            num_groups=min(patch_size, in_channels),
+            num_groups=1,
             context_mapping_features=context_mapping_features,
         )
 
@@ -233,12 +233,12 @@ def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor:
         return x
 
 
-class UnpatchBlock(nn.Module):
+class Unpatcher(nn.Module):
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        patch_size: int = 2,
+        patch_size: int,
         context_mapping_features: Optional[int] = None,
     ):
         super().__init__()
@@ -249,7 +249,7 @@ def __init__(
         self.block = ResnetBlock1d(
             in_channels=in_channels // patch_size,
             out_channels=out_channels,
-            num_groups=min(patch_size, out_channels),
+            num_groups=1,
             context_mapping_features=context_mapping_features,
         )
 
@@ -259,56 +259,6 @@ def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor:
         return x
 
 
-class Patcher(ConditionedSequential):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        blocks: int,
-        factor: int,
-        context_mapping_features: Optional[int] = None,
-    ):
-        channels_pre = [in_channels * (factor ** i) for i in range(blocks)]
-        channels_post = [in_channels * (factor ** (i + 1)) for i in range(blocks - 1)]
-        channels_post += [out_channels]
-
-        super().__init__(
-            PatchBlock(
-                in_channels=channels_pre[i],
-                out_channels=channels_post[i],
-                patch_size=factor,
-                context_mapping_features=context_mapping_features,
-            )
-            for i in range(blocks)
-        )
-
-
-class Unpatcher(ConditionedSequential):
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        blocks: int,
-        factor: int,
-        context_mapping_features: Optional[int] = None,
-    ):
-        channels_pre = [in_channels]
-        channels_pre += [
-            out_channels * (factor ** (i + 1)) for i in reversed(range(blocks - 1))
-        ]
-        channels_post = [out_channels * (factor ** i) for i in reversed(range(blocks))]
-
-        super().__init__(
-            UnpatchBlock(
-                in_channels=channels_pre[i],
-                out_channels=channels_post[i],
-                patch_size=factor,
-                context_mapping_features=context_mapping_features,
-            )
-            for i in range(blocks)
-        )
-
-
 """
 Attention Components
 """
@@ -927,8 +877,7 @@ def __init__(
         factors: Sequence[int],
         num_blocks: Sequence[int],
         attentions: Sequence[int],
-        patch_blocks: int = 1,
-        patch_factor: int = 1,
+        patch_size: int = 1,
         resnet_groups: int = 8,
         use_context_time: bool = True,
         kernel_multiplier_downsample: int = 2,
@@ -1013,11 +962,12 @@ def __init__(
             assert exists(in_channels) and exists(out_channels)
             self.stft = STFT(**stft_kwargs)
 
+        assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
+
         self.to_in = Patcher(
             in_channels=in_channels + context_channels[0],
             out_channels=channels * multipliers[0],
-            blocks=patch_blocks,
-            factor=patch_factor,
+            patch_size=patch_size,
             context_mapping_features=context_mapping_features,
        )
 
@@ -1076,8 +1026,7 @@ def __init__(
         self.to_out = Unpatcher(
             in_channels=channels * multipliers[0],
             out_channels=out_channels,
-            blocks=patch_blocks,
-            factor=patch_factor,
+            patch_size=patch_size,
             context_mapping_features=context_mapping_features,
         )
 
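For intuition about the simplified patching: `Patcher`/`Unpatcher` now wrap a single `ResnetBlock1d` around an invertible fold of time into channels, instead of a stack of blocks. A standalone sketch of that reshape with einops (the exact rearrange patterns are inferred from the `out_channels // patch_size` arithmetic above, so treat them as an assumption):

```py
import torch
from einops import rearrange

patch_size = 16
x = torch.randn(1, 2, 2**18)  # [batch, channels, time]

# Patching: fold every patch_size timesteps into channels, which is why the
# Patcher's inner block only needs to produce out_channels // patch_size
patched = rearrange(x, "b c (l p) -> b (c p) l", p=patch_size)
print(patched.shape)  # torch.Size([1, 32, 16384])

# Unpatching: the inverse fold, matching the Unpatcher's in_channels // patch_size
restored = rearrange(patched, "b (c p) l -> b c (l p)", p=patch_size)
assert torch.equal(x, restored)
```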
setup.py

Lines changed: 2 additions & 1 deletion

@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.92",
+    version="0.0.93",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
@@ -12,6 +12,7 @@
     url="https://github.com/archinetai/audio-diffusion-pytorch",
     keywords=["artificial intelligence", "deep learning", "audio generation"],
     install_requires=[
+        "tqdm",
         "torch>=1.6",
         "data-science-types>=0.2",
         "einops>=0.4",
