@@ -3,6 +3,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange, reduce, repeat
 from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many
@@ -135,8 +136,11 @@ def __init__(
         in_channels: int,
         out_channels: int,
         *,
-        num_groups: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding: int = 1,
         dilation: int = 1,
+        num_groups: int,
         context_mapping_features: Optional[int] = None,
         context_embedding_features: Optional[int] = None,
         context_heads: Optional[int] = None,
@@ -150,8 +154,11 @@ def __init__(
         self.block1 = ConvBlock1d(
             in_channels=in_channels,
             out_channels=out_channels,
-            num_groups=num_groups,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
             dilation=dilation,
+            num_groups=num_groups,
         )
 
         if self.use_mapping:
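Note: kernel_size, stride, and padding were previously fixed by ConvBlock1d's defaults; ResnetBlock1d now forwards them explicitly, and num_groups moves after them as a required keyword-only argument. A minimal sketch of the length arithmetic the new defaults rely on, using plain nn.Conv1d (standard PyTorch, not code from this diff):

import torch
import torch.nn as nn

# With kernel_size=3, stride=1, and padding equal to the dilation, a 1d
# convolution preserves sequence length:
#   L_out = L + 2*padding - dilation*(3 - 1) = L
x = torch.randn(1, 8, 100)
for dilation in (1, 3, 9, 27):
    conv = nn.Conv1d(8, 8, kernel_size=3, padding=dilation, dilation=dilation)
    assert conv(x).shape == x.shape  # length preserved for every dilation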
@@ -211,51 +218,33 @@ def forward(
 class ConvOut1d(nn.Module):
     def __init__(
         self,
-        channels: int,
-        kernel_sizes: Sequence[int],
+        in_channels: int,
         context_mapping_features: Optional[int] = None,
     ):
         super().__init__()
-        mid_channels = channels * 16
-        self.use_mapping = exists(context_mapping_features)
+        mid_channels = in_channels * 32
 
-        if self.use_mapping:
-            assert exists(context_mapping_features)
-            self.to_scale_shift = MappingToScaleShift(
-                features=context_mapping_features, channels=mid_channels
-            )
-
-        self.convs_in = nn.ModuleList(
-            ConvBlock1d(
-                in_channels=channels,
+        self.layers = nn.ModuleList(
+            ResnetBlock1d(
+                in_channels=in_channels if i == 0 else mid_channels,
                 out_channels=mid_channels,
-                kernel_size=kernel_size,
-                padding=(kernel_size - 1) // 2,
+                kernel_size=3,
+                padding=3 ** (i + 1),
+                dilation=3 ** (i + 1),
                 num_groups=1,
+                context_mapping_features=context_mapping_features,
             )
-            for kernel_size in kernel_sizes
+            for i in range(3)
         )
 
-        self.conv_mid = ConvBlock1d(
-            in_channels=mid_channels,
-            out_channels=mid_channels,
-            kernel_size=3,
-            padding=1,
-            num_groups=8,
-        )
-
-        self.conv_out = Conv1d(
-            in_channels=mid_channels, out_channels=channels, kernel_size=1
+        self.to_out = nn.Conv1d(
+            in_channels=mid_channels, out_channels=in_channels, kernel_size=1
         )
 
     def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor:
-        scale_shift = None
-        if self.use_mapping:
-            scale_shift = self.to_scale_shift(mapping)
-        xs = torch.stack([conv(x) for conv in self.convs_in])
-        x = reduce(xs, "n b c t -> b c t", "sum")
-        x = self.conv_mid(x, scale_shift)
-        x = self.conv_out(x)
+        for layer in self.layers:
+            x = F.elu(layer(x, mapping))
+        x = self.to_out(x)
         return x
 
 
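The rewritten ConvOut1d swaps the parallel multi-kernel branches (summed, then mixed by a mid block) for a serial stack of three dilated residual blocks with dilations 3, 9, and 27, each padded to preserve length, with ELU between blocks and a final 1x1 projection back to in_channels. A rough standalone sketch of the data flow, with plain nn.Conv1d standing in for ResnetBlock1d and the mapping conditioning omitted (shapes, dilations, and the width multiplier follow the diff; the rest is an approximation):

import torch
import torch.nn as nn
import torch.nn.functional as F

in_channels, mid_channels = 2, 2 * 32
layers = nn.ModuleList(
    nn.Conv1d(
        in_channels if i == 0 else mid_channels,
        mid_channels,
        kernel_size=3,
        padding=3 ** (i + 1),
        dilation=3 ** (i + 1),
    )
    for i in range(3)
)
to_out = nn.Conv1d(mid_channels, in_channels, kernel_size=1)

x = torch.randn(1, in_channels, 1024)
for layer in layers:  # dilations 3, 9, 27
    x = F.elu(layer(x))
x = to_out(x)
assert x.shape == (1, in_channels, 1024)  # channels and length restored

Since each kernel-3 block with dilation d widens the receptive field by 2*d, the stack sees 1 + 2*(3 + 9 + 27) = 79 samples per output position.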
@@ -852,7 +841,7 @@ def __init__(
         context_features: Optional[int] = None,
         context_channels: Optional[Sequence[int]] = None,
         context_embedding_features: Optional[int] = None,
-        kernel_sizes_out: Optional[Sequence[int]] = None,
+        use_post_out_block: bool = False,
     ):
         super().__init__()
 
@@ -867,7 +856,7 @@ def __init__(
         self.use_context_time = use_context_time
         self.use_context_features = use_context_features
         self.use_context_channels = use_context_channels
-        self.use_post_out_block = exists(kernel_sizes_out)
+        self.use_post_out_block = use_post_out_block
 
         context_channels_pad_length = num_layers + 1 - len(context_channels)
         context_channels = context_channels + [0] * context_channels_pad_length
@@ -996,10 +985,9 @@ def __init__(
         )
 
         if self.use_post_out_block:
-            assert exists(kernel_sizes_out)
             self.to_post_out = ConvOut1d(
-                channels=out_channels,
-                kernel_sizes=kernel_sizes_out,
+                in_channels=out_channels,
+                context_mapping_features=context_mapping_features,
             )
 
     def get_channels(
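With this change, the post-out block is switched on by an explicit boolean rather than being implied by kernel_sizes_out, and ConvOut1d now receives context_mapping_features so its residual blocks are conditioned on the same mapping as the rest of the network. A hypothetical toggle sketch (the class and names below are illustrative, not from the diff), assuming the same wiring as the hunk above:

import torch
import torch.nn as nn
from typing import Optional

class PostOutSketch(nn.Module):
    # Hypothetical stand-in: builds the post-out block only when the flag
    # is set, mirroring how UNet1d now instantiates ConvOut1d.
    def __init__(
        self,
        out_channels: int,
        use_post_out_block: bool = False,
        context_mapping_features: Optional[int] = None,
    ):
        super().__init__()
        self.use_post_out_block = use_post_out_block
        if self.use_post_out_block:
            # Stand-in for ConvOut1d(in_channels=out_channels,
            # context_mapping_features=context_mapping_features)
            self.to_post_out = nn.Conv1d(out_channels, out_channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.to_post_out(x) if self.use_post_out_block else x

x = torch.randn(1, 2, 64)
assert PostOutSketch(2, use_post_out_block=True)(x).shape == x.shape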