11from math import log , pi
2- from typing import List , Optional , Sequence , Tuple
2+ from typing import List , Optional , Sequence , Tuple , Union
33
44import torch
55import torch .nn as nn
@@ -26,10 +26,7 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
2626
2727
2828def Downsample1d (
29- in_channels : int ,
30- out_channels : int ,
31- factor : int ,
32- kernel_multiplier : int ,
29+ in_channels : int , out_channels : int , factor : int , kernel_multiplier : int = 2
3330) -> nn .Module :
3431 assert kernel_multiplier % 2 == 0 , "Kernel multiplier must be even"
3532
@@ -464,7 +461,7 @@ def TimePositionalEmbedding(
464461
465462
466463"""
467- UNet Components
464+ Encoder/Decoder Components
468465"""
469466
470467
@@ -475,23 +472,31 @@ def __init__(
475472 out_channels : int ,
476473 * ,
477474 factor : int ,
478- kernel_multiplier : int ,
479- time_context_features : int ,
480475 num_groups : int ,
481476 num_layers : int ,
482- use_pre_downsample : bool ,
483- use_attention : bool ,
477+ kernel_multiplier : int = 2 ,
478+ use_pre_downsample : bool = True ,
479+ use_skip : bool = False ,
480+ use_attention : bool = False ,
484481 attention_heads : Optional [int ] = None ,
485482 attention_features : Optional [int ] = None ,
486483 attention_multiplier : Optional [int ] = None ,
484+ time_context_features : Optional [int ] = None ,
487485 ):
488486 super ().__init__ ()
489-
490487 self .use_pre_downsample = use_pre_downsample
488+ self .use_skip = use_skip
491489 self .use_attention = use_attention
492490
493491 channels = out_channels if use_pre_downsample else in_channels
494492
493+ self .downsample = Downsample1d (
494+ in_channels = in_channels ,
495+ out_channels = out_channels ,
496+ factor = factor ,
497+ kernel_multiplier = kernel_multiplier ,
498+ )
499+
495500 self .blocks = nn .ModuleList (
496501 [
497502 ResnetBlock1d (
@@ -517,51 +522,47 @@ def __init__(
517522 multiplier = attention_multiplier ,
518523 )
519524
520- self .downsample = Downsample1d (
521- in_channels = in_channels ,
522- out_channels = out_channels ,
523- factor = factor ,
524- kernel_multiplier = kernel_multiplier ,
525- )
526-
527- def forward (self , x : Tensor , t : Tensor ) -> Tuple [Tensor , List [Tensor ]]:
525+ def forward (
526+ self , x : Tensor , t : Optional [Tensor ] = None
527+ ) -> Union [Tuple [Tensor , List [Tensor ]], Tensor ]:
528528
529529 if self .use_pre_downsample :
530530 x = self .downsample (x )
531531
532532 skips = []
533533 for block in self .blocks :
534534 x = block (x , t )
535- skips += [x ]
535+ skips += [x ] if self . use_skip else []
536536
537537 if self .use_attention :
538538 x = self .transformer (x )
539- skips += [x ]
539+ skips += [x ] if self . use_skip else []
540540
541541 if not self .use_pre_downsample :
542542 x = self .downsample (x )
543543
544- return x , skips
544+ return ( x , skips ) if self . use_skip else x
545545
546546
547547class UpsampleBlock1d (nn .Module ):
548548 def __init__ (
549549 self ,
550550 in_channels : int ,
551- skip_channels : int ,
552551 out_channels : int ,
553552 * ,
554553 factor : int ,
555- use_nearest : bool ,
556554 num_layers : int ,
557- time_context_features : int ,
558555 num_groups : int ,
559- use_pre_upsample : bool ,
560- use_skip_scale : bool ,
561- use_attention : bool ,
556+ use_nearest : bool = False ,
557+ use_pre_upsample : bool = False ,
558+ use_skip : bool = False ,
559+ skip_channels : int = 0 ,
560+ use_skip_scale : bool = False ,
561+ use_attention : bool = False ,
562562 attention_heads : Optional [int ] = None ,
563563 attention_features : Optional [int ] = None ,
564564 attention_multiplier : Optional [int ] = None ,
565+ time_context_features : Optional [int ] = None ,
565566 ):
566567 super ().__init__ ()
567568
@@ -573,6 +574,7 @@ def __init__(
573574
574575 self .use_pre_upsample = use_pre_upsample
575576 self .use_attention = use_attention
577+ self .use_skip = use_skip
576578 self .skip_scale = 2 ** - 0.5 if use_skip_scale else 1.0
577579
578580 channels = out_channels if use_pre_upsample else in_channels
@@ -612,13 +614,18 @@ def __init__(
612614 def add_skip (self , x : Tensor , skip : Tensor ) -> Tensor :
613615 return torch .cat ([x , skip * self .skip_scale ], dim = 1 )
614616
615- def forward (self , x : Tensor , skips : List [Tensor ], t : Tensor ) -> Tensor :
617+ def forward (
618+ self ,
619+ x : Tensor ,
620+ skips : Optional [List [Tensor ]] = None ,
621+ t : Optional [Tensor ] = None ,
622+ ) -> Tensor :
616623
617624 if self .use_pre_upsample :
618625 x = self .upsample (x )
619626
620627 for block in self .blocks :
621- x = self .add_skip (x , skip = skips .pop ())
628+ x = self .add_skip (x , skip = skips .pop ()) if exists ( skips ) else x
622629 x = block (x , t )
623630
624631 if self .use_attention :
@@ -635,11 +642,11 @@ def __init__(
635642 self ,
636643 channels : int ,
637644 * ,
638- time_context_features : int ,
639645 num_groups : int ,
640- use_attention : bool ,
646+ use_attention : bool = False ,
641647 attention_heads : Optional [int ] = None ,
642648 attention_features : Optional [int ] = None ,
649+ time_context_features : Optional [int ] = None ,
643650 ):
644651 super ().__init__ ()
645652
@@ -675,14 +682,17 @@ def __init__(
675682 time_context_features = time_context_features ,
676683 )
677684
678- def forward (self , x : Tensor , t : Tensor ) -> Tensor :
685+ def forward (self , x : Tensor , t : Optional [ Tensor ] = None ) -> Tensor :
679686 x = self .pre_block (x , t )
680687 if self .use_attention :
681688 x = self .attention (x )
682689 x = self .post_block (x , t )
683690 return x
684691
685692
693+ """ UNets """
694+
695+
686696class UNet1d (nn .Module ):
687697 def __init__ (
688698 self ,
@@ -751,6 +761,7 @@ def __init__(
751761 kernel_multiplier = kernel_multiplier_downsample ,
752762 num_groups = resnet_groups ,
753763 use_pre_downsample = True ,
764+ use_skip = True ,
754765 use_attention = attentions [i ],
755766 attention_heads = attention_heads ,
756767 attention_features = attention_features ,
@@ -773,7 +784,6 @@ def __init__(
773784 [
774785 UpsampleBlock1d (
775786 in_channels = channels * multipliers [i + 1 ],
776- skip_channels = channels * multipliers [i + 1 ],
777787 out_channels = channels * multipliers [i ],
778788 time_context_features = time_context_features ,
779789 num_layers = num_blocks [i ] + (1 if attentions [i ] else 0 ),
@@ -782,6 +792,8 @@ def __init__(
782792 num_groups = resnet_groups ,
783793 use_skip_scale = use_skip_scale ,
784794 use_pre_upsample = False ,
795+ use_skip = True ,
796+ skip_channels = channels * multipliers [i + 1 ],
785797 use_attention = attentions [i ],
786798 attention_heads = attention_heads ,
787799 attention_features = attention_features ,
@@ -825,3 +837,131 @@ def forward(self, x: Tensor, t: Tensor):
825837 x = self .to_out (x ) # t?
826838
827839 return x
840+
841+
842+ """ Autoencoders """
843+
844+
845+ def gaussian_sample (mean : Tensor , logvar : Tensor ) -> Tensor :
846+ std = torch .exp (0.5 * logvar )
847+ sample = mean + std * torch .randn_like (std )
848+ return sample
849+
850+
851+ class AutoEncoder1d (nn .Module ):
852+ def __init__ (
853+ self ,
854+ in_channels : int ,
855+ bottleneck_channels : int ,
856+ channels : int ,
857+ patch_size : int ,
858+ multipliers : Sequence [int ],
859+ factors : Sequence [int ],
860+ num_blocks : Sequence [int ],
861+ resnet_groups : int ,
862+ loss_kl_weight : float ,
863+ kernel_multiplier_downsample : int = 2 ,
864+ ):
865+ super ().__init__ ()
866+
867+ num_layers = len (multipliers ) - 1
868+ self .num_layers = num_layers
869+ self .loss_kl_weight = loss_kl_weight
870+
871+ assert len (factors ) == num_layers and len (num_blocks ) == num_layers
872+
873+ self .to_in = nn .Sequential (
874+ Rearrange ("b c (l p) -> b (c p) l" , p = patch_size ),
875+ Conv1d (
876+ in_channels = in_channels * patch_size ,
877+ out_channels = channels ,
878+ kernel_size = 1 ,
879+ ),
880+ )
881+
882+ self .downsamples = nn .ModuleList (
883+ [
884+ DownsampleBlock1d (
885+ in_channels = channels * multipliers [i ],
886+ out_channels = channels * multipliers [i + 1 ],
887+ num_layers = num_blocks [i ],
888+ factor = factors [i ],
889+ kernel_multiplier = kernel_multiplier_downsample ,
890+ num_groups = resnet_groups ,
891+ )
892+ for i in range (num_layers )
893+ ]
894+ )
895+
896+ self .pre_bottleneck = Conv1d (
897+ in_channels = channels * multipliers [- 1 ],
898+ out_channels = bottleneck_channels * 2 ,
899+ kernel_size = 1 ,
900+ )
901+
902+ self .post_bottleneck = Conv1d (
903+ in_channels = bottleneck_channels ,
904+ out_channels = channels * multipliers [- 1 ],
905+ kernel_size = 1 ,
906+ )
907+
908+ self .upsamples = nn .ModuleList (
909+ [
910+ UpsampleBlock1d (
911+ in_channels = channels * multipliers [i + 1 ],
912+ out_channels = channels * multipliers [i ],
913+ num_layers = num_blocks [i ],
914+ factor = factors [i ],
915+ num_groups = resnet_groups ,
916+ )
917+ for i in reversed (range (num_layers ))
918+ ]
919+ )
920+
921+ self .to_out = nn .Sequential (
922+ Conv1d (
923+ in_channels = channels ,
924+ out_channels = in_channels * patch_size ,
925+ kernel_size = 1 ,
926+ ),
927+ Rearrange ("b (c p) l -> b c (l p)" , p = patch_size ),
928+ )
929+
930+ def encode (
931+ self , x : Tensor , * , with_kl_loss : bool = False
932+ ) -> Union [Tensor , Tuple [Tensor , Tensor ]]:
933+ x = self .to_in (x )
934+
935+ for downsample in self .downsamples :
936+ x = downsample (x )
937+
938+ mean_and_var = self .pre_bottleneck (x )
939+
940+ # Chunk channels to mean and log variance and sample in VAE style
941+ mean , logvar = torch .chunk (mean_and_var , chunks = 2 , dim = 1 )
942+ logvar = torch .clamp (logvar , - 30.0 , 20.0 )
943+ bottleneck = gaussian_sample (mean , logvar )
944+
945+ if with_kl_loss :
946+ # KL-Loss: diagonal gaussian with mean 0, variance 1, logvar 0
947+ b = x .shape [0 ]
948+ var = torch .exp (logvar )
949+ loss = 0.5 * torch .sum (torch .pow (mean , 2 ) + (var - 1.0 ) - logvar ) / b
950+ return bottleneck , loss
951+
952+ return bottleneck
953+
954+ def decode (self , x : Tensor ) -> Tensor :
955+ x = self .post_bottleneck (x )
956+
957+ for upsample in self .upsamples :
958+ x = upsample (x )
959+
960+ return self .to_out (x )
961+
962+ def forward (self , x : Tensor ) -> Tensor :
963+ """Returns autoencoding loss"""
964+ z , kl_loss = self .encode (x , with_kl_loss = True )
965+ y = self .decode (z )
966+ loss = F .mse_loss (x , y ) + kl_loss * self .loss_kl_weight
967+ return loss
0 commit comments