 import torch.nn as nn
 from einops import rearrange, reduce
 from einops.layers.torch import Rearrange
-from einops_exts import rearrange_many, repeat_many
+from einops_exts import rearrange_many
 from einops_exts.torch import EinopsToAndFrom
 from torch import Tensor, einsum
-from torch.nn import functional as F

 from .utils import default, exists

@@ -25,6 +24,42 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
     return nn.ConvTranspose1d(*args, **kwargs)


+class ConvOut1d(nn.Module):
+    def __init__(
+        self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
+    ):
+        super().__init__()
+        mid_channels = in_channels * 16
+
+        self.convs_in = nn.ModuleList(
+            Conv1d(
+                in_channels=in_channels,
+                out_channels=mid_channels,
+                kernel_size=kernel_size,
+                padding=(kernel_size - 1) // 2,
+            )
+            for kernel_size in kernel_sizes
+        )
+
+        self.conv_mid = nn.Conv1d(
+            in_channels=mid_channels,
+            out_channels=mid_channels,
+            kernel_size=3,
+            padding=1,
+        )
+
+        self.conv_out = Conv1d(
+            in_channels=mid_channels, out_channels=out_channels, kernel_size=1
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        xs = torch.stack([conv(x) for conv in self.convs_in])
+        x = reduce(xs, "n b c t -> b c t", "sum") + x
+        x = self.conv_mid(x)
+        x = self.conv_out(x)
+        return x
+
+
 def Downsample1d(
     in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
 ) -> nn.Module:
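# Hypothetical usage sketch for the new ConvOut1d (not from this commit; the
# shapes and values below are assumptions). Each conv in convs_in preserves the
# time length via (kernel_size - 1) // 2 padding; the bank outputs are summed
# and the input is added as a residual. Because mid_channels = in_channels * 16,
# the `+ x` add relies on PyTorch broadcasting over the channel dim, so it only
# works when in_channels == 1 (e.g. mono audio) or equals mid_channels.
import torch

conv_out = ConvOut1d(in_channels=1, out_channels=1, kernel_sizes=[3, 5, 7])
x = torch.randn(2, 1, 2048)    # (batch, channels, time)
y = conv_out(x)                # three parallel convs summed + residual, then 3-wide and 1x1 convs
assert y.shape == (2, 1, 2048)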
@@ -273,25 +308,6 @@ def forward(self, x: Tensor) -> Tensor:
 """


-class InsertNullTokens(nn.Module):
-    def __init__(self, head_features: int, num_heads: int):
-        super().__init__()
-        self.num_heads = num_heads
-        self.tokens = nn.Parameter(torch.randn(2, head_features))
-
-    def forward(
-        self, k: Tensor, v: Tensor, *, mask: Tensor = None
-    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
-        b = k.shape[0]
-        nk, nv = repeat_many(
-            self.tokens.unbind(dim=-2), "d -> b h 1 d", h=self.num_heads, b=b
-        )
-        k = torch.cat((nk, k), dim=-2)
-        v = torch.cat((nv, v), dim=-2)
-        mask = F.pad(mask, pad=(1, 0), value=True) if exists(mask) else None
-        return k, v, mask
-
-
 def FeedForward1d(channels: int, multiplier: int = 2):
     mid_channels = int(channels * multiplier)
     return nn.Sequential(
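# For reference, a standalone reconstruction of what the deleted InsertNullTokens
# did (the function name and the plain-einops phrasing are mine): it prepended
# one learned key/value pair per head and widened the mask so every query always
# had at least one unmasked position to attend to.
import torch
import torch.nn.functional as F
from einops import repeat

def insert_null_tokens(k, v, tokens, num_heads, mask=None):
    # k, v: (b, h, n, d); tokens: learned nn.Parameter of shape (2, d)
    b = k.shape[0]
    nk, nv = (
        repeat(t, "d -> b h 1 d", b=b, h=num_heads) for t in tokens.unbind(dim=-2)
    )
    k = torch.cat((nk, k), dim=-2)  # prepend null key along the sequence dim
    v = torch.cat((nv, v), dim=-2)  # prepend null value along the sequence dim
    # keep the null position always visible to every query
    mask = F.pad(mask, pad=(1, 0), value=True) if mask is not None else None
    return k, v, mask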
@@ -324,19 +340,14 @@ def __init__(
         *,
         head_features: int = 64,
         num_heads: int = 8,
-        use_null_tokens: bool = True,
         out_features: Optional[int] = None,
     ):
         super().__init__()
         self.scale = head_features**-0.5
         self.num_heads = num_heads
-        self.use_null_tokens = use_null_tokens
         mid_features = head_features * num_heads
         out_features = out_features if exists(out_features) else features

-        self.insert_null_tokens = InsertNullTokens(
-            head_features=head_features, num_heads=num_heads
-        )
         self.to_out = nn.Sequential(
             nn.Linear(in_features=mid_features, out_features=out_features, bias=False),
             LayerNorm(features=out_features, bias=False),
@@ -356,10 +367,6 @@ def forward(
         q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
         q = q * self.scale

-        # Insert null tokens
-        if self.use_null_tokens:
-            k, v, mask = self.insert_null_tokens(k, v, mask=mask)
-
         # Compute similarity matrix with bias and mask
         sim = einsum("... n d, ... m d -> ... n m", q, k)
         sim = sim + attention_bias if exists(attention_bias) else sim
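# With the null-token branch gone, the forward pass above is plain scaled
# dot-product attention. A self-contained sketch of the remaining flow; the
# masked_fill step is an assumption about code below this hunk, which the
# diff does not show:
import torch
from einops import rearrange
from torch import einsum

def attend(q, k, v, attention_bias=None, mask=None):
    # q, k, v: (b, h, n, d), with q already scaled by head_features ** -0.5
    sim = einsum("... n d, ... m d -> ... n m", q, k)
    if attention_bias is not None:
        sim = sim + attention_bias
    if mask is not None:
        # mask: (b, m), True = attend; with no null token, a fully masked row
        # degrades to uniform attention after softmax rather than focusing on
        # a learned null position
        big_neg = torch.finfo(sim.dtype).min
        sim = sim.masked_fill(~rearrange(mask, "b m -> b 1 1 m"), big_neg)
    attn = sim.softmax(dim=-1)
    return einsum("... n m, ... m d -> ... n d", attn, v)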
@@ -402,7 +409,6 @@ def __init__(
                 features,
                 num_heads=num_heads,
                 head_features=head_features,
-                use_null_tokens=False,
                 out_features=out_features,
             )
         )
@@ -436,10 +442,7 @@ def __init__(
             in_features=context_features, out_features=mid_features * 2, bias=False
         )
         self.attention = AttentionBase(
-            features,
-            num_heads=num_heads,
-            head_features=head_features,
-            use_null_tokens=False,
+            features, num_heads=num_heads, head_features=head_features
         )

     def forward(self, x: Tensor, context: Tensor, mask: Tensor = None) -> Tensor:
@@ -556,7 +559,7 @@ def __init__(
         self.blocks = nn.ModuleList(
             [
                 ResnetBlock1d(
-                    in_channels=channels + (context_channels if i == 0 else 0),
+                    in_channels=channels + context_channels if i == 0 else channels,
                     out_channels=channels,
                     num_groups=num_groups,
                     time_context_features=time_context_features,
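# The in_channels rewrite above is behaviour-preserving: the conditional
# expression binds looser than `+`, so the new line parses as
# `(channels + context_channels) if i == 0 else channels`. A quick check
# with illustrative values (mine, not from the commit):
channels, context_channels = 128, 32
for i in range(3):
    old = channels + (context_channels if i == 0 else 0)
    new = channels + context_channels if i == 0 else channels
    assert old == new  # 160 when i == 0, otherwise 128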
@@ -790,41 +793,6 @@ def forward(
 """


-class ConvOut1d(nn.Module):
-    def __init__(
-        self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
-    ):
-        super().__init__()
-        mid_channels = in_channels * 8
-
-        self.block1 = nn.ModuleList(
-            Conv1d(
-                in_channels=in_channels,
-                out_channels=mid_channels,
-                kernel_size=kernel_size,
-                padding=(kernel_size - 1) // 2,
-            )
-            for kernel_size in kernel_sizes
-        )
-
-        self.block2 = nn.ModuleList(
-            Conv1d(
-                in_channels=mid_channels,
-                out_channels=out_channels,
-                kernel_size=kernel_size,
-                padding=(kernel_size - 1) // 2,
-            )
-            for kernel_size in kernel_sizes
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        xs = torch.stack([conv(x) for conv in self.block1])
-        x = reduce(xs, "n b c t -> b c t", "sum")
-        xs = torch.stack([conv(x) for conv in self.block2])
-        x = reduce(xs, "n b c t -> b c t", "sum")
-        return x
-
-
 class UNet1d(nn.Module):
     def __init__(
         self,
@@ -953,7 +921,7 @@ def __init__(

         self.to_out = nn.Sequential(
             ResnetBlock1d(
-                in_channels=channels + context_channels[1],
+                in_channels=channels,
                 out_channels=channels,
                 num_groups=resnet_groups,
             ),