Skip to content

Commit e0ba36c

Browse files
feat: add convout kernels option
1 parent 28ec933 commit e0ba36c

File tree

3 files changed

+45
-13
lines changed

3 files changed

+45
-13
lines changed

audio_diffusion_pytorch/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
diffusion_dynamic_threshold: float,
4545
out_channels: Optional[int] = None,
4646
context_channels: Optional[Sequence[int]] = None,
47+
**kwargs
4748
):
4849
super().__init__()
4950

@@ -66,6 +67,7 @@ def __init__(
6667
use_skip_scale=use_skip_scale,
6768
out_channels=out_channels,
6869
context_channels=context_channels,
70+
**kwargs
6971
)
7072

7173
self.diffusion = Diffusion(

audio_diffusion_pytorch/modules.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,6 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
2525
return nn.ConvTranspose1d(*args, **kwargs)
2626

2727

28-
class ConvMean1d(nn.Module):
    """Ensemble of parallel 1d convolutions averaged together.

    Builds `num_means` identical `Conv1d` layers (all sharing the same
    constructor arguments); the forward pass applies every layer to the
    input and returns the element-wise mean of their outputs.
    """

    def __init__(self, num_means: int, *args, **kwargs):
        super().__init__()
        # One independently-initialized conv per ensemble member.
        self.convs = nn.ModuleList(
            Conv1d(*args, **kwargs) for _ in range(num_means)
        )

    def forward(self, x: Tensor) -> Tensor:
        outputs = torch.stack([conv(x) for conv in self.convs])
        # Collapse the ensemble axis "n" by averaging.
        return reduce(outputs, "n b c t -> b c t", "mean")
3928
def Downsample1d(
4029
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
4130
) -> nn.Module:
@@ -709,6 +698,40 @@ def forward(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
709698
"""
710699

711700

701+
class ConvOut1d(nn.Module):
    """Two stages of parallel multi-kernel convolutions with residual sums.

    Each stage holds one `Conv1d` per entry of `kernel_sizes` (with
    "same"-style padding `(k - 1) // 2`; NOTE(review): this preserves
    length only for odd kernel sizes — confirm callers pass odd sizes).
    The stage output is the element-wise sum of the raw input and every
    conv output, and the second stage consumes the first stage's result.

    Because the forward pass stacks the raw input together with the conv
    outputs, the module only supports `in_channels == out_channels`; we
    validate this up front instead of failing with an opaque shape error
    inside `torch.stack` at forward time.

    Args:
        in_channels: channels of the input tensor (b, c, t).
        out_channels: channels of the output tensor; must equal
            `in_channels` (see above).
        kernel_sizes: one convolution per kernel size, per stage.
    """

    def __init__(
        self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
    ):
        super().__init__()

        if in_channels != out_channels:
            # Fail fast: the residual sum in forward() requires matching
            # channel counts (torch.stack needs identical shapes).
            raise ValueError(
                "ConvOut1d requires in_channels == out_channels, "
                f"got {in_channels} and {out_channels}"
            )

        def make_convs() -> nn.ModuleList:
            # One "same"-padded conv per kernel size; both stages use the
            # same configuration (channels are equal, so this is also
            # correct for the second stage's input).
            return nn.ModuleList(
                Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=(kernel_size - 1) // 2,
                )
                for kernel_size in kernel_sizes
            )

        self.block1 = make_convs()
        self.block2 = make_convs()

    def forward(self, x: Tensor) -> Tensor:
        # Stage 1: input + every conv output, summed over the stack axis.
        xs = torch.stack([x] + [conv(x) for conv in self.block1])
        x = reduce(xs, "n b c t -> b c t", "sum")
        # Stage 2: same residual pattern applied to the stage-1 result.
        xs = torch.stack([x] + [conv(x) for conv in self.block2])
        x = reduce(xs, "n b c t -> b c t", "sum")
        return x
712735
class UNet1d(nn.Module):
713736
def __init__(
714737
self,
@@ -730,6 +753,7 @@ def __init__(
730753
use_attention_bottleneck: bool,
731754
out_channels: Optional[int] = None,
732755
context_channels: Optional[Sequence[int]] = None,
756+
kernel_sizes_out: Optional[Sequence[int]] = None,
733757
):
734758
super().__init__()
735759

@@ -835,14 +859,20 @@ def __init__(
835859
in_channels=channels + context_channels[1],
836860
out_channels=channels,
837861
num_groups=resnet_groups,
838-
time_context_features=time_context_features,
839862
),
840863
Conv1d(
841864
in_channels=channels,
842865
out_channels=out_channels * patch_size,
843866
kernel_size=1,
844867
),
845868
Rearrange("b (c p) l -> b c (l p)", p=patch_size),
869+
ConvOut1d(
870+
in_channels=out_channels,
871+
out_channels=out_channels,
872+
kernel_sizes=kernel_sizes_out,
873+
)
874+
if exists(kernel_sizes_out)
875+
else nn.Identity(),
846876
)
847877

848878
def get_context(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.26",
6+
version="0.0.27",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments (0)