Skip to content

Commit c76e8a6

Browse files
feat: add additional skip connection
1 parent da667fc commit c76e8a6

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ from audio_diffusion_pytorch import UNet1d
125125
unet = UNet1d(
126126
in_channels=1,
127127
channels=128,
128-
patch_size=16,
128+
patch_blocks=4,
129129
kernel_sizes_init=[1, 3, 7],
130130
multipliers=[1, 2, 4, 4, 4, 4, 4],
131131
factors=[4, 4, 4, 2, 2, 2],

audio_diffusion_pytorch/modules.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,8 +1079,8 @@ def forward(
10791079
mapping = self.get_mapping(time, features)
10801080

10811081
x = self.to_in(x, mapping)
1082+
skips_list = [x]
10821083

1083-
skips_list = []
10841084
for i, downsample in enumerate(self.downsamples):
10851085
channels = self.get_channels(channels_list, layer=i + 1)
10861086
x, skips = downsample(
@@ -1094,6 +1094,7 @@ def forward(
10941094
skips = skips_list.pop()
10951095
x = upsample(x, skips, mapping=mapping, embedding=embedding)
10961096

1097+
x += skips_list.pop()
10971098
x = self.to_out(x, mapping)
10981099

10991100
return x

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.40",
6+
version="0.0.41",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)