Skip to content

Commit db6a21a

Browse files
feat: add option to change encoder num_blocks
1 parent d582296 commit db6a21a

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

audio_diffusion_pytorch/model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def __init__(
152152
encoder_channels: int,
153153
context_channels: int,
154154
bottleneck: Optional[Bottleneck] = None,
155+
encoder_num_blocks: Optional[Sequence[int]] = None,
155156
**kwargs
156157
):
157158
super().__init__(
@@ -172,14 +173,18 @@ def __init__(
172173
self.encoder_factor = patch_size * prod(factors[0:encoder_depth])
173174
self.bottleneck = bottleneck
174175

176+
encoder_num_blocks = default(encoder_num_blocks, num_blocks)
177+
assert_message = "The number of encoder_num_blocks must match encoder_depth"
178+
assert len(encoder_num_blocks) >= encoder_depth, assert_message
179+
175180
self.encoder = Encoder1d(
176181
in_channels=in_channels,
177182
channels=channels,
178183
patch_size=patch_size,
179184
kernel_sizes_init=kernel_sizes_init,
180185
multipliers=multipliers,
181186
factors=factors,
182-
num_blocks=num_blocks,
187+
num_blocks=encoder_num_blocks,
183188
resnet_groups=resnet_groups,
184189
kernel_multiplier_downsample=kernel_multiplier_downsample,
185190
extract_channels=[0] * (encoder_depth - 1) + [encoder_channels],

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.29",
6+
version="0.0.30",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)