@@ -25,15 +25,13 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
 
 
 class ConvOut1d(nn.Module):
-    def __init__(
-        self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
-    ):
+    def __init__(self, channels: int, kernel_sizes: Sequence[int]):
         super().__init__()
-        mid_channels = in_channels * 16
+        mid_channels = channels * 16
 
         self.convs_in = nn.ModuleList(
             Conv1d(
-                in_channels=in_channels,
+                in_channels=channels,
                 out_channels=mid_channels,
                 kernel_size=kernel_size,
                 padding=(kernel_size - 1) // 2,
@@ -49,14 +47,15 @@ def __init__(
         )
 
         self.conv_out = Conv1d(
-            in_channels=mid_channels, out_channels=out_channels, kernel_size=1
+            in_channels=mid_channels, out_channels=channels, kernel_size=1
         )
 
     def forward(self, x: Tensor) -> Tensor:
+        skip = x
         xs = torch.stack([conv(x) for conv in self.convs_in])
-        x = reduce(xs, "n b c t -> b c t", "sum") + x
+        x = reduce(xs, "n b c t -> b c t", "sum")
         x = self.conv_mid(x)
-        x = self.conv_out(x)
+        x = self.conv_out(x) + skip
         return x
 
 
@@ -932,8 +931,7 @@ def __init__(
             ),
             Rearrange("b (c p) l -> b c (l p)", p=patch_size),
             ConvOut1d(
-                in_channels=out_channels,
-                out_channels=out_channels,
+                channels=out_channels,
                 kernel_sizes=kernel_sizes_out,
             )
             if exists(kernel_sizes_out)
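After this refactor, ConvOut1d acts as a full residual block: the skip added in forward now bypasses conv_mid and conv_out as well as the input convolutions, so the block's input and output channel counts must match, which is why a single channels argument replaces the old in_channels/out_channels pair. Below is a minimal, self-contained sketch of that shape constraint, not the repository's code: it assumes plain nn.Conv1d in place of the repo's Conv1d wrapper, a hypothetical 1x1 convolution standing in for conv_mid (whose definition lies outside this diff), and a plain tensor sum instead of einops.reduce.

from typing import Sequence

import torch
from torch import Tensor, nn


class ConvOut1dSketch(nn.Module):
    def __init__(self, channels: int, kernel_sizes: Sequence[int]):
        super().__init__()
        mid_channels = channels * 16
        # Parallel input convolutions, one per kernel size, each padded so the
        # sequence length is preserved (assumes odd kernel sizes).
        self.convs_in = nn.ModuleList(
            nn.Conv1d(channels, mid_channels, kernel_size=k, padding=(k - 1) // 2)
            for k in kernel_sizes
        )
        # Hypothetical stand-in for conv_mid; its real definition is not shown
        # in this diff.
        self.conv_mid = nn.Conv1d(mid_channels, mid_channels, kernel_size=1)
        self.conv_out = nn.Conv1d(mid_channels, channels, kernel_size=1)

    def forward(self, x: Tensor) -> Tensor:
        skip = x
        # Sum the parallel branches (plain sum instead of einops.reduce).
        x = torch.stack([conv(x) for conv in self.convs_in]).sum(dim=0)
        x = self.conv_mid(x)
        # The residual wraps the whole block, so out channels == in channels.
        return self.conv_out(x) + skip


block = ConvOut1dSketch(channels=2, kernel_sizes=[1, 3, 5])
y = block(torch.randn(4, 2, 128))
assert y.shape == (4, 2, 128)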