Skip to content

Commit 4e6ee2c

Browse files
fix: convout dual channel, add skip connection
1 parent 7004f00 commit 4e6ee2c

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

audio_diffusion_pytorch/modules.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,13 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
2525

2626

2727
class ConvOut1d(nn.Module):
28-
def __init__(
29-
self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
30-
):
28+
def __init__(self, channels: int, kernel_sizes: Sequence[int]):
3129
super().__init__()
32-
mid_channels = in_channels * 16
30+
mid_channels = channels * 16
3331

3432
self.convs_in = nn.ModuleList(
3533
Conv1d(
36-
in_channels=in_channels,
34+
in_channels=channels,
3735
out_channels=mid_channels,
3836
kernel_size=kernel_size,
3937
padding=(kernel_size - 1) // 2,
@@ -49,14 +47,15 @@ def __init__(
4947
)
5048

5149
self.conv_out = Conv1d(
52-
in_channels=mid_channels, out_channels=out_channels, kernel_size=1
50+
in_channels=mid_channels, out_channels=channels, kernel_size=1
5351
)
5452

5553
def forward(self, x: Tensor) -> Tensor:
54+
skip = x
5655
xs = torch.stack([conv(x) for conv in self.convs_in])
57-
x = reduce(xs, "n b c t -> b c t", "sum") + x
56+
x = reduce(xs, "n b c t -> b c t", "sum")
5857
x = self.conv_mid(x)
59-
x = self.conv_out(x)
58+
x = self.conv_out(x) + skip
6059
return x
6160

6261

@@ -932,8 +931,7 @@ def __init__(
932931
),
933932
Rearrange("b (c p) l -> b c (l p)", p=patch_size),
934933
ConvOut1d(
935-
in_channels=out_channels,
936-
out_channels=out_channels,
934+
channels=out_channels,
937935
kernel_sizes=kernel_sizes_out,
938936
)
939937
if exists(kernel_sizes_out)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.32",
6+
version="0.0.33",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)