Skip to content

Commit 7004f00

Browse files
fix: context channels error, update convout block
1 parent bfa7952 commit 7004f00

File tree

3 files changed

+42
-73
lines changed

3 files changed

+42
-73
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ y_long = composer(y, keep_start=True) # [1, 1, 98304]
205205
- [x] Add ancestral euler sampler `AEulerSampler`.
206206
- [x] Add diffusion autoencoder.
207207
- [x] Add autoencoder bottleneck option for quantization.
208+
- [x] Add option to provide context tokens (resnet cross attention).
208209

209210
## Appreciation
210211

audio_diffusion_pytorch/modules.py

Lines changed: 40 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55
import torch.nn as nn
66
from einops import rearrange, reduce
77
from einops.layers.torch import Rearrange
8-
from einops_exts import rearrange_many, repeat_many
8+
from einops_exts import rearrange_many
99
from einops_exts.torch import EinopsToAndFrom
1010
from torch import Tensor, einsum
11-
from torch.nn import functional as F
1211

1312
from .utils import default, exists
1413

@@ -25,6 +24,42 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
2524
return nn.ConvTranspose1d(*args, **kwargs)
2625

2726

27+
class ConvOut1d(nn.Module):
    """Multi-kernel convolutional output block with a residual connection.

    The input is fed through one convolution per entry of ``kernel_sizes``
    (each mapping ``in_channels`` to ``in_channels * 16`` channels). The
    branch outputs are summed, the input is added back as a residual, and the
    result goes through a kernel-3 convolution followed by a kernel-1
    projection to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the returned tensor.
        kernel_sizes: Kernel size of each parallel input convolution. Each
            branch uses ``padding=(kernel_size - 1) // 2``.
            NOTE(review): assuming the ``Conv1d`` helper uses default
            stride/dilation, only odd kernel sizes keep the time length
            unchanged; branches of differing length cannot be stacked.
    """

    def __init__(
        self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
    ):
        super().__init__()
        mid_channels = in_channels * 16

        # One parallel branch per requested kernel size.
        self.convs_in = nn.ModuleList(
            Conv1d(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
            )
            for kernel_size in kernel_sizes
        )

        self.conv_mid = nn.Conv1d(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=3,
            padding=1,
        )

        self.conv_out = Conv1d(
            in_channels=mid_channels, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to ``x`` and return the projected tensor.

        ``x`` must be batched with channels on dim 1 (the reduction pattern
        below is ``n b c t -> b c t``).
        """
        xs = torch.stack([conv(x) for conv in self.convs_in])
        summed = reduce(xs, "n b c t -> b c t", "sum")
        # The branches have in_channels * 16 channels while x has in_channels,
        # so a plain `summed + x` only broadcasts when in_channels == 1. Tile
        # x along the channel axis so the residual works for any channel
        # count; for a single channel this equals the broadcast result.
        residual = x.repeat(1, summed.shape[1] // x.shape[1], 1)
        x = summed + residual
        x = self.conv_mid(x)
        x = self.conv_out(x)
        return x
61+
62+
2863
def Downsample1d(
2964
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
3065
) -> nn.Module:
@@ -273,25 +308,6 @@ def forward(self, x: Tensor) -> Tensor:
273308
"""
274309

275310

276-
class InsertNullTokens(nn.Module):
277-
def __init__(self, head_features: int, num_heads: int):
278-
super().__init__()
279-
self.num_heads = num_heads
280-
self.tokens = nn.Parameter(torch.randn(2, head_features))
281-
282-
def forward(
283-
self, k: Tensor, v: Tensor, *, mask: Tensor = None
284-
) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
285-
b = k.shape[0]
286-
nk, nv = repeat_many(
287-
self.tokens.unbind(dim=-2), "d -> b h 1 d", h=self.num_heads, b=b
288-
)
289-
k = torch.cat((nk, k), dim=-2)
290-
v = torch.cat((nv, v), dim=-2)
291-
mask = F.pad(mask, pad=(1, 0), value=True) if exists(mask) else None
292-
return k, v, mask
293-
294-
295311
def FeedForward1d(channels: int, multiplier: int = 2):
296312
mid_channels = int(channels * multiplier)
297313
return nn.Sequential(
@@ -324,19 +340,14 @@ def __init__(
324340
*,
325341
head_features: int = 64,
326342
num_heads: int = 8,
327-
use_null_tokens: bool = True,
328343
out_features: Optional[int] = None,
329344
):
330345
super().__init__()
331346
self.scale = head_features ** -0.5
332347
self.num_heads = num_heads
333-
self.use_null_tokens = use_null_tokens
334348
mid_features = head_features * num_heads
335349
out_features = out_features if exists(out_features) else features
336350

337-
self.insert_null_tokens = InsertNullTokens(
338-
head_features=head_features, num_heads=num_heads
339-
)
340351
self.to_out = nn.Sequential(
341352
nn.Linear(in_features=mid_features, out_features=out_features, bias=False),
342353
LayerNorm(features=out_features, bias=False),
@@ -356,10 +367,6 @@ def forward(
356367
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
357368
q = q * self.scale
358369

359-
# Insert null tokens
360-
if self.use_null_tokens:
361-
k, v, mask = self.insert_null_tokens(k, v, mask=mask)
362-
363370
# Compute similarity matrix with bias and mask
364371
sim = einsum("... n d, ... m d -> ... n m", q, k)
365372
sim = sim + attention_bias if exists(attention_bias) else sim
@@ -402,7 +409,6 @@ def __init__(
402409
features,
403410
num_heads=num_heads,
404411
head_features=head_features,
405-
use_null_tokens=False,
406412
out_features=out_features,
407413
)
408414

@@ -436,10 +442,7 @@ def __init__(
436442
in_features=context_features, out_features=mid_features * 2, bias=False
437443
)
438444
self.attention = AttentionBase(
439-
features,
440-
num_heads=num_heads,
441-
head_features=head_features,
442-
use_null_tokens=False,
445+
features, num_heads=num_heads, head_features=head_features
443446
)
444447

445448
def forward(self, x: Tensor, context: Tensor, mask: Tensor = None) -> Tensor:
@@ -556,7 +559,7 @@ def __init__(
556559
self.blocks = nn.ModuleList(
557560
[
558561
ResnetBlock1d(
559-
in_channels=channels + (context_channels if i == 0 else 0),
562+
in_channels=channels + context_channels if i == 0 else channels,
560563
out_channels=channels,
561564
num_groups=num_groups,
562565
time_context_features=time_context_features,
@@ -790,41 +793,6 @@ def forward(
790793
"""
791794

792795

793-
class ConvOut1d(nn.Module):
794-
def __init__(
795-
self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
796-
):
797-
super().__init__()
798-
mid_channels = in_channels * 8
799-
800-
self.block1 = nn.ModuleList(
801-
Conv1d(
802-
in_channels=in_channels,
803-
out_channels=mid_channels,
804-
kernel_size=kernel_size,
805-
padding=(kernel_size - 1) // 2,
806-
)
807-
for kernel_size in kernel_sizes
808-
)
809-
810-
self.block2 = nn.ModuleList(
811-
Conv1d(
812-
in_channels=mid_channels,
813-
out_channels=out_channels,
814-
kernel_size=kernel_size,
815-
padding=(kernel_size - 1) // 2,
816-
)
817-
for kernel_size in kernel_sizes
818-
)
819-
820-
def forward(self, x: Tensor) -> Tensor:
821-
xs = torch.stack([conv(x) for conv in self.block1])
822-
x = reduce(xs, "n b c t -> b c t", "sum")
823-
xs = torch.stack([conv(x) for conv in self.block2])
824-
x = reduce(xs, "n b c t -> b c t", "sum")
825-
return x
826-
827-
828796
class UNet1d(nn.Module):
829797
def __init__(
830798
self,
@@ -953,7 +921,7 @@ def __init__(
953921

954922
self.to_out = nn.Sequential(
955923
ResnetBlock1d(
956-
in_channels=channels + context_channels[1],
924+
in_channels=channels,
957925
out_channels=channels,
958926
num_groups=resnet_groups,
959927
),

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.31",
6+
version="0.0.32",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)