Skip to content

Commit d12456b

Browse files
feat: add convmean tail
1 parent c83ba38 commit d12456b

File tree

4 files changed

+20
-4
lines changed

4 files changed

+20
-4
lines changed

audio_diffusion_pytorch/diffusion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tens
153153
x_next = x + d * (sigma_down - sigma)
154154
# Add randomness
155155
x_next = x_next + torch.randn_like(x) * sigma_up
156-
print(sigma_up)
157156
return x_next
158157

159158
def forward(

audio_diffusion_pytorch/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
channels: int,
2424
patch_size: int,
2525
kernel_sizes_init: Sequence[int],
26+
out_means: int,
2627
multipliers: Sequence[int],
2728
factors: Sequence[int],
2829
num_blocks: Sequence[int],
@@ -50,6 +51,7 @@ def __init__(
5051
resnet_groups=resnet_groups,
5152
kernel_multiplier_downsample=kernel_multiplier_downsample,
5253
kernel_sizes_init=kernel_sizes_init,
54+
out_means=out_means,
5355
multipliers=multipliers,
5456
factors=factors,
5557
num_blocks=num_blocks,
@@ -98,6 +100,7 @@ def __init__(self, *args, **kwargs):
98100
channels=128,
99101
patch_size=16,
100102
kernel_sizes_init=[1, 3, 7],
103+
out_means=4,
101104
multipliers=[1, 2, 4, 4, 4, 4, 4],
102105
factors=[4, 4, 4, 2, 2, 2],
103106
num_blocks=[2, 2, 2, 2, 2, 2],
@@ -133,6 +136,7 @@ def __init__(self, factor: int, in_channels: int = 1, *args, **kwargs):
133136
in_channels=in_channels,
134137
channels=128,
135138
patch_size=16,
139+
out_means=4,
136140
kernel_sizes_init=[1, 3, 7],
137141
multipliers=[1, 2, 4, 4, 4, 4, 4],
138142
factors=[4, 4, 4, 2, 2, 2],

audio_diffusion_pytorch/modules.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55
import torch.nn as nn
6-
from einops import rearrange
6+
from einops import rearrange, reduce
77
from einops.layers.torch import Rearrange
88
from einops_exts import rearrange_many, repeat_many
99
from einops_exts.torch import EinopsToAndFrom
@@ -25,6 +25,17 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
2525
return nn.ConvTranspose1d(*args, **kwargs)
2626

2727

28+
class ConvMean1d(nn.Module):
29+
def __init__(self, num_means: int, *args, **kwargs):
30+
super().__init__()
31+
self.convs = nn.ModuleList([Conv1d(*args, **kwargs) for _ in range(num_means)])
32+
33+
def forward(self, x: Tensor) -> Tensor:
34+
xs = torch.stack([conv(x) for conv in self.convs])
35+
x = reduce(xs, "n b c t -> b c t", "mean")
36+
return x
37+
38+
2839
def Downsample1d(
2940
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
3041
) -> nn.Module:
@@ -713,6 +724,7 @@ def __init__(
713724
use_skip_scale: bool,
714725
use_attention_bottleneck: bool,
715726
out_channels: Optional[int] = None,
727+
out_means: int = 1,
716728
context_channels: Optional[Sequence[int]] = None,
717729
):
718730
super().__init__()
@@ -821,7 +833,8 @@ def __init__(
821833
num_groups=resnet_groups,
822834
time_context_features=time_context_features,
823835
),
824-
Conv1d(
836+
ConvMean1d(
837+
num_means=out_means,
825838
in_channels=channels,
826839
out_channels=out_channels * patch_size,
827840
kernel_size=1,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.20",
6+
version="0.0.21",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)