@@ -25,17 +25,6 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
2525 return nn .ConvTranspose1d (* args , ** kwargs )
2626
2727
class ConvMean1d(nn.Module):
    """Ensemble of parallel Conv1d layers whose outputs are averaged.

    Builds ``num_means`` independently-initialized convolutions from the
    same ``*args``/``**kwargs``; the forward pass applies every branch to
    the input and returns the element-wise mean of the results.
    """

    def __init__(self, num_means: int, *args, **kwargs):
        super().__init__()
        # One independently-parameterized convolution per averaging branch.
        self.convs = nn.ModuleList(
            [Conv1d(*args, **kwargs) for _ in range(num_means)]
        )

    def forward(self, x: Tensor) -> Tensor:
        # Stack branch outputs on a new leading dim, then average it away:
        # equivalent to reduce(xs, "n b c t -> b c t", "mean").
        stacked = torch.stack([branch(x) for branch in self.convs])
        return stacked.mean(dim=0)
37-
38-
3928def Downsample1d (
4029 in_channels : int , out_channels : int , factor : int , kernel_multiplier : int = 2
4130) -> nn .Module :
@@ -709,6 +698,40 @@ def forward(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
709698"""
710699
711700
class ConvOut1d(nn.Module):
    """Two sequential banks of parallel length-preserving convolutions,
    each applied as a residual sum.

    Every bank holds one ``Conv1d`` per entry of ``kernel_sizes``, padded
    with ``(k - 1) // 2`` so the temporal length is unchanged. A bank's
    output is the input plus the sum of all its convolutions' outputs;
    two banks are applied back to back.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels. Must equal
            ``in_channels`` because ``forward`` sums the raw input with
            each convolution's output.
        kernel_sizes: one odd kernel size per parallel convolution.

    Raises:
        ValueError: if ``in_channels != out_channels`` or any kernel size
            is even — either would make the residual sum in ``forward``
            fail with a shape mismatch.
    """

    def __init__(
        self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
    ):
        super().__init__()

        # forward() stacks the raw input together with every conv output,
        # so channel counts and temporal lengths must all match. Fail
        # loudly at construction instead of with an opaque stack error.
        if in_channels != out_channels:
            raise ValueError(
                "ConvOut1d requires in_channels == out_channels, "
                f"got {in_channels} and {out_channels}"
            )
        if any(kernel_size % 2 == 0 for kernel_size in kernel_sizes):
            raise ValueError(
                "ConvOut1d kernel_sizes must all be odd to preserve "
                f"length with (k - 1) // 2 padding, got {list(kernel_sizes)}"
            )

        # Two structurally identical but independently parameterized banks.
        self.block1 = self._conv_bank(in_channels, out_channels, kernel_sizes)
        self.block2 = self._conv_bank(in_channels, out_channels, kernel_sizes)

    @staticmethod
    def _conv_bank(
        in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
    ) -> nn.ModuleList:
        # One length-preserving convolution per requested kernel size.
        return nn.ModuleList(
            Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
            )
            for kernel_size in kernel_sizes
        )

    def forward(self, x: Tensor) -> Tensor:
        # Residual sum over the first bank, then over the second.
        xs = torch.stack([x] + [conv(x) for conv in self.block1])
        x = reduce(xs, "n b c t -> b c t", "sum")
        xs = torch.stack([x] + [conv(x) for conv in self.block2])
        x = reduce(xs, "n b c t -> b c t", "sum")
        return x
733+
734+
712735class UNet1d (nn .Module ):
713736 def __init__ (
714737 self ,
@@ -730,6 +753,7 @@ def __init__(
730753 use_attention_bottleneck : bool ,
731754 out_channels : Optional [int ] = None ,
732755 context_channels : Optional [Sequence [int ]] = None ,
756+ kernel_sizes_out : Optional [Sequence [int ]] = None ,
733757 ):
734758 super ().__init__ ()
735759
@@ -835,14 +859,20 @@ def __init__(
835859 in_channels = channels + context_channels [1 ],
836860 out_channels = channels ,
837861 num_groups = resnet_groups ,
838- time_context_features = time_context_features ,
839862 ),
840863 Conv1d (
841864 in_channels = channels ,
842865 out_channels = out_channels * patch_size ,
843866 kernel_size = 1 ,
844867 ),
845868 Rearrange ("b (c p) l -> b c (l p)" , p = patch_size ),
869+ ConvOut1d (
870+ in_channels = out_channels ,
871+ out_channels = out_channels ,
872+ kernel_sizes = kernel_sizes_out ,
873+ )
874+ if exists (kernel_sizes_out )
875+ else nn .Identity (),
846876 )
847877
848878 def get_context (
0 commit comments