Skip to content

Commit f46557b

Browse files
feat: remove unsucessful convmean, provide context only during downsampling
1 parent d12456b commit f46557b

File tree

3 files changed

+3
-11
lines changed

3 files changed

+3
-11
lines changed

audio_diffusion_pytorch/model.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def __init__(
2323
channels: int,
2424
patch_size: int,
2525
kernel_sizes_init: Sequence[int],
26-
out_means: int,
2726
multipliers: Sequence[int],
2827
factors: Sequence[int],
2928
num_blocks: Sequence[int],
@@ -51,7 +50,6 @@ def __init__(
5150
resnet_groups=resnet_groups,
5251
kernel_multiplier_downsample=kernel_multiplier_downsample,
5352
kernel_sizes_init=kernel_sizes_init,
54-
out_means=out_means,
5553
multipliers=multipliers,
5654
factors=factors,
5755
num_blocks=num_blocks,
@@ -100,7 +98,6 @@ def __init__(self, *args, **kwargs):
10098
channels=128,
10199
patch_size=16,
102100
kernel_sizes_init=[1, 3, 7],
103-
out_means=4,
104101
multipliers=[1, 2, 4, 4, 4, 4, 4],
105102
factors=[4, 4, 4, 2, 2, 2],
106103
num_blocks=[2, 2, 2, 2, 2, 2],
@@ -136,7 +133,6 @@ def __init__(self, factor: int, in_channels: int = 1, *args, **kwargs):
136133
in_channels=in_channels,
137134
channels=128,
138135
patch_size=16,
139-
out_means=4,
140136
kernel_sizes_init=[1, 3, 7],
141137
multipliers=[1, 2, 4, 4, 4, 4, 4],
142138
factors=[4, 4, 4, 2, 2, 2],

audio_diffusion_pytorch/modules.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,6 @@ def __init__(
724724
use_skip_scale: bool,
725725
use_attention_bottleneck: bool,
726726
out_channels: Optional[int] = None,
727-
out_means: int = 1,
728727
context_channels: Optional[Sequence[int]] = None,
729728
):
730729
super().__init__()
@@ -802,11 +801,10 @@ def __init__(
802801
attention_features=attention_features,
803802
)
804803

805-
context_channels = context_channels + [0] # Upsample skips first context
806804
self.upsamples = nn.ModuleList(
807805
[
808806
UpsampleBlock1d(
809-
in_channels=channels * multipliers[i + 1] + context_channels[i + 2],
807+
in_channels=channels * multipliers[i + 1],
810808
out_channels=channels * multipliers[i],
811809
time_context_features=time_context_features,
812810
num_layers=num_blocks[i] + (1 if attentions[i] else 0),
@@ -833,8 +831,7 @@ def __init__(
833831
num_groups=resnet_groups,
834832
time_context_features=time_context_features,
835833
),
836-
ConvMean1d(
837-
num_means=out_means,
834+
Conv1d(
838835
in_channels=channels,
839836
out_channels=out_channels * patch_size,
840837
kernel_size=1,
@@ -889,7 +886,6 @@ def forward(
889886
for i, upsample in enumerate(self.upsamples):
890887
skips = skips_list.pop()
891888
x = upsample(x, skips, t)
892-
x = self.add_context(x, context, layer=len(self.upsamples) - i)
893889

894890
x = self.to_out(x) # t?
895891

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.21",
6+
version="0.0.22",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)