File tree Expand file tree Collapse file tree 2 files changed +7
-2
lines changed Expand file tree Collapse file tree 2 files changed +7
-2
lines changed Original file line number Diff line number Diff line change @@ -152,6 +152,7 @@ def __init__(
152152 encoder_channels : int ,
153153 context_channels : int ,
154154 bottleneck : Optional [Bottleneck ] = None ,
155+ encoder_num_blocks : Optional [Sequence [int ]] = None ,
155156 ** kwargs
156157 ):
157158 super ().__init__ (
@@ -172,14 +173,18 @@ def __init__(
172173 self .encoder_factor = patch_size * prod (factors [0 :encoder_depth ])
173174 self .bottleneck = bottleneck
174175
176+ encoder_num_blocks = default (encoder_num_blocks , num_blocks )
177+ assert_message = "The number of encoder_num_blocks must match encoder_depth"
178+ assert len (encoder_num_blocks ) >= encoder_depth , assert_message
179+
175180 self .encoder = Encoder1d (
176181 in_channels = in_channels ,
177182 channels = channels ,
178183 patch_size = patch_size ,
179184 kernel_sizes_init = kernel_sizes_init ,
180185 multipliers = multipliers ,
181186 factors = factors ,
182- num_blocks = num_blocks ,
187+ num_blocks = encoder_num_blocks ,
183188 resnet_groups = resnet_groups ,
184189 kernel_multiplier_downsample = kernel_multiplier_downsample ,
185190 extract_channels = [0 ] * (encoder_depth - 1 ) + [encoder_channels ],
Original file line number Diff line number Diff line change 33setup (
44 name = "audio-diffusion-pytorch" ,
55 packages = find_packages (exclude = []),
6- version = "0.0.29 " ,
6+ version = "0.0.30 " ,
77 license = "MIT" ,
88 description = "Audio Diffusion - PyTorch" ,
99 long_description_content_type = "text/markdown" ,
You can’t perform that action at this time.
0 commit comments