
Commit a87fdd6

feat: change conv out block
Parent: 508de5b

2 files changed, +29 -41 lines

audio_diffusion_pytorch/modules.py

Lines changed: 28 additions & 40 deletions
@@ -3,6 +3,7 @@

 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange, reduce, repeat
 from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many
@@ -135,8 +136,11 @@ def __init__(
         in_channels: int,
         out_channels: int,
         *,
-        num_groups: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding: int = 1,
         dilation: int = 1,
+        num_groups: int,
         context_mapping_features: Optional[int] = None,
         context_embedding_features: Optional[int] = None,
         context_heads: Optional[int] = None,
@@ -150,8 +154,11 @@ def __init__(
         self.block1 = ConvBlock1d(
             in_channels=in_channels,
             out_channels=out_channels,
-            num_groups=num_groups,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
             dilation=dilation,
+            num_groups=num_groups,
         )

         if self.use_mapping:
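Note: for callers, the effect of the two hunks above is that ResnetBlock1d now exposes the first convolution's kernel_size, stride, and padding (keyword-only, with the defaults shown), while num_groups stays a required keyword argument. A minimal construction sketch, assuming the import path from the file shown in this diff; the values are illustrative, not from the commit:

# Sketch only: a dilated residual block built with the new arguments.
from audio_diffusion_pytorch.modules import ResnetBlock1d

block = ResnetBlock1d(
    in_channels=32,
    out_channels=64,
    kernel_size=3,   # default 3
    stride=1,        # default 1
    padding=9,       # padding == dilation keeps the length unchanged for kernel_size=3
    dilation=9,
    num_groups=1,    # still required, now listed after dilation
)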
@@ -211,51 +218,33 @@ def forward(
 class ConvOut1d(nn.Module):
     def __init__(
         self,
-        channels: int,
-        kernel_sizes: Sequence[int],
+        in_channels: int,
         context_mapping_features: Optional[int] = None,
     ):
         super().__init__()
-        mid_channels = channels * 16
-        self.use_mapping = exists(context_mapping_features)
+        mid_channels = in_channels * 32

-        if self.use_mapping:
-            assert exists(context_mapping_features)
-            self.to_scale_shift = MappingToScaleShift(
-                features=context_mapping_features, channels=mid_channels
-            )
-
-        self.convs_in = nn.ModuleList(
-            ConvBlock1d(
-                in_channels=channels,
+        self.layers = nn.ModuleList(
+            ResnetBlock1d(
+                in_channels=in_channels if i == 0 else mid_channels,
                 out_channels=mid_channels,
-                kernel_size=kernel_size,
-                padding=(kernel_size - 1) // 2,
+                kernel_size=3,
+                padding=3 ** (i + 1),
+                dilation=3 ** (i + 1),
                 num_groups=1,
+                context_mapping_features=context_mapping_features,
             )
-            for kernel_size in kernel_sizes
+            for i in range(3)
         )

-        self.conv_mid = ConvBlock1d(
-            in_channels=mid_channels,
-            out_channels=mid_channels,
-            kernel_size=3,
-            padding=1,
-            num_groups=8,
-        )
-
-        self.conv_out = Conv1d(
-            in_channels=mid_channels, out_channels=channels, kernel_size=1
+        self.to_out = nn.Conv1d(
+            in_channels=mid_channels, out_channels=in_channels, kernel_size=1
         )

     def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor:
-        scale_shift = None
-        if self.use_mapping:
-            scale_shift = self.to_scale_shift(mapping)
-        xs = torch.stack([conv(x) for conv in self.convs_in])
-        x = reduce(xs, "n b c t -> b c t", "sum")
-        x = self.conv_mid(x, scale_shift)
-        x = self.conv_out(x)
+        for layer in self.layers:
+            x = F.elu(layer(x, mapping))
+        x = self.to_out(x)
         return x

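Note: the rewritten ConvOut1d stacks three ResnetBlock1d layers with dilation (and matching padding) of 3, 9, and 27, applies F.elu between them, and projects back to the input channel count with a 1x1 convolution. Since padding equals dilation for kernel_size 3, each layer preserves the sequence length. A hedged smoke test, assuming the module layout at this commit; the shapes follow from the diff, not from documented API:

import torch
from audio_diffusion_pytorch.modules import ConvOut1d

conv_out = ConvOut1d(in_channels=2)   # mid_channels = 2 * 32 = 64
x = torch.randn(1, 2, 2 ** 15)        # (batch, channels, length), e.g. a stereo chunk
y = conv_out(x)                       # three dilated residual blocks + 1x1 projection
assert y.shape == x.shape             # padding == dilation, so length is preserved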

@@ -852,7 +841,7 @@ def __init__(
         context_features: Optional[int] = None,
         context_channels: Optional[Sequence[int]] = None,
         context_embedding_features: Optional[int] = None,
-        kernel_sizes_out: Optional[Sequence[int]] = None,
+        use_post_out_block: bool = False,
     ):
         super().__init__()

@@ -867,7 +856,7 @@ def __init__(
         self.use_context_time = use_context_time
         self.use_context_features = use_context_features
         self.use_context_channels = use_context_channels
-        self.use_post_out_block = exists(kernel_sizes_out)
+        self.use_post_out_block = use_post_out_block

         context_channels_pad_length = num_layers + 1 - len(context_channels)
         context_channels = context_channels + [0] * context_channels_pad_length
@@ -996,10 +985,9 @@ def __init__(
         )

         if self.use_post_out_block:
-            assert exists(kernel_sizes_out)
             self.to_post_out = ConvOut1d(
-                channels=out_channels,
-                kernel_sizes=kernel_sizes_out,
+                in_channels=out_channels,
+                context_mapping_features=context_mapping_features,
             )

     def get_channels(
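Note: at the model level the post-out block is no longer enabled by passing kernel_sizes_out; it is switched on with the use_post_out_block flag and built from out_channels plus the model's context_mapping_features, as the hunk above shows. A paraphrase of that gating logic as a standalone helper; the helper name is hypothetical, and only ConvOut1d's signature comes from this commit:

# Hypothetical helper mirroring the __init__ logic above.
from typing import Optional
import torch.nn as nn
from audio_diffusion_pytorch.modules import ConvOut1d

def build_post_out_block(
    out_channels: int,
    use_post_out_block: bool = False,
    context_mapping_features: Optional[int] = None,
) -> Optional[nn.Module]:
    # Previously gated on exists(kernel_sizes_out); now an explicit boolean flag.
    if not use_post_out_block:
        return None
    return ConvOut1d(
        in_channels=out_channels,
        context_mapping_features=context_mapping_features,
    )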

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.36",
+    version="0.0.37",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
