@@ -480,6 +480,7 @@ def __init__(
480480 kernel_multiplier : int = 2 ,
481481 use_pre_downsample : bool = True ,
482482 use_skip : bool = False ,
483+ extract_channels : int = 0 ,
483484 use_attention : bool = False ,
484485 attention_heads : Optional [int ] = None ,
485486 attention_features : Optional [int ] = None ,
@@ -490,6 +491,7 @@ def __init__(
490491 self .use_pre_downsample = use_pre_downsample
491492 self .use_skip = use_skip
492493 self .use_attention = use_attention
494+ self .use_extract = extract_channels > 0
493495
494496 channels = out_channels if use_pre_downsample else in_channels
495497
@@ -525,6 +527,14 @@ def __init__(
525527 multiplier = attention_multiplier ,
526528 )
527529
530+ if self .use_extract :
531+ num_extract_groups = min (num_groups , extract_channels )
532+ self .to_extracted = ResnetBlock1d (
533+ in_channels = out_channels ,
534+ out_channels = extract_channels ,
535+ num_groups = num_extract_groups ,
536+ )
537+
528538 def forward (
529539 self , x : Tensor , t : Optional [Tensor ] = None
530540 ) -> Union [Tuple [Tensor , List [Tensor ]], Tensor ]:
@@ -544,6 +554,10 @@ def forward(
544554 if not self .use_pre_downsample :
545555 x = self .downsample (x )
546556
557+ if self .use_extract :
558+ extracted = self .to_extracted (x )
559+ return x , extracted
560+
547561 return (x , skips ) if self .use_skip else x
548562
549563
@@ -693,7 +707,9 @@ def forward(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
693707 return x
694708
695709
696- """ UNets """
710+ """
711+ UNet
712+ """
697713
698714
699715class UNet1d (nn .Module ):
@@ -738,7 +754,7 @@ def __init__(
738754 if use_context :
739755 has_context = [c > 0 for c in context_channels ]
740756 self .has_context = has_context
741- self .context_id = [sum (has_context [:i ]) for i in range (len (has_context ))]
757+ self .context_ids = [sum (has_context [:i ]) for i in range (len (has_context ))]
742758
743759 assert (
744760 len (factors ) == num_layers
@@ -846,7 +862,7 @@ def add_context(
846862 return x
847863 assert exists (context_list ), "Missing context"
848864 # Get context index (skipping zero channel contexts)
849- context_id = self .context_id [layer ]
865+ context_id = self .context_ids [layer ]
850866 # Get context
851867 context = context_list [context_id ]
852868 message = f"Missing context for layer { layer } at index { context_id } "
@@ -891,7 +907,81 @@ def forward(
891907 return x
892908
893909
894- """ Autoencoders """
910+ """
911+ Encoder
912+ """
913+
914+
class Encoder1d(nn.Module):
    """Downsampling encoder that returns the features extracted at each layer.

    The input is patched (``patch_size`` consecutive positions are folded into
    the channel axis), embedded with a ``CrossEmbed1d``, and then passed through
    ``len(extract_channels)`` ``DownsampleBlock1d`` layers. Each layer whose
    ``extract_channels`` entry is positive also emits an extracted tensor;
    ``forward`` returns only those extracted tensors.

    Args:
        in_channels: Number of channels of the (un-patched) input.
        channels: Base channel count; layer ``i`` maps
            ``channels * multipliers[i]`` to ``channels * multipliers[i + 1]``.
        patch_size: Number of positions folded into the channel axis by the
            initial rearrange (pattern ``"b c (l p) -> b (c p) l"``).
        resnet_groups: Forwarded as ``num_groups`` to every downsample block.
        kernel_multiplier_downsample: Forwarded as ``kernel_multiplier`` to
            every downsample block.
        kernel_sizes_init: Forwarded as ``kernel_sizes`` to ``CrossEmbed1d``.
        multipliers: Channel multipliers; needs at least ``num_layers + 1``
            entries.
        factors: Per-layer ``factor`` of each downsample block; needs at least
            ``num_layers`` entries.
        num_blocks: Per-layer ``num_layers`` of each downsample block; needs at
            least ``num_layers`` entries.
        extract_channels: Per-layer number of channels to extract. Its length
            defines the number of layers; a value of ``0`` disables extraction
            for that layer.

    Raises:
        AssertionError: If ``multipliers``, ``factors`` or ``num_blocks`` is
            too short for the number of layers.
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        patch_size: int,
        resnet_groups: int,
        kernel_multiplier_downsample: int,
        kernel_sizes_init: Sequence[int],
        multipliers: Sequence[int],
        factors: Sequence[int],
        num_blocks: Sequence[int],
        extract_channels: Sequence[int],
    ):
        super().__init__()

        # The number of layers is driven by `extract_channels`, not `factors`.
        num_layers = len(extract_channels)
        self.num_layers = num_layers

        # Loop variable deliberately not named `channels`, to avoid shadowing
        # the `channels` parameter used below.
        self.use_extract = [num_extract > 0 for num_extract in extract_channels]

        # Validate before building any sub-module so that a bad configuration
        # reports which sequence is too short.
        assert len(multipliers) >= num_layers + 1, (
            f"multipliers must have at least {num_layers + 1} items, "
            f"got {len(multipliers)}"
        )
        assert len(factors) >= num_layers, (
            f"factors must have at least {num_layers} items, got {len(factors)}"
        )
        assert len(num_blocks) >= num_layers, (
            f"num_blocks must have at least {num_layers} items, "
            f"got {len(num_blocks)}"
        )

        self.to_in = nn.Sequential(
            # Fold each patch of `patch_size` positions into the channel axis.
            Rearrange("b c (l p) -> b (c p) l", p=patch_size),
            CrossEmbed1d(
                in_channels=in_channels * patch_size,
                out_channels=channels,
                kernel_sizes=kernel_sizes_init,
                stride=1,
            ),
        )

        self.downsamples = nn.ModuleList(
            [
                DownsampleBlock1d(
                    in_channels=channels * multipliers[i],
                    out_channels=channels * multipliers[i + 1],
                    factor=factors[i],
                    kernel_multiplier=kernel_multiplier_downsample,
                    num_groups=resnet_groups,
                    num_layers=num_blocks[i],
                    extract_channels=extract_channels[i],
                )
                for i in range(num_layers)
            ]
        )

    def forward(self, x: Tensor) -> List[Tensor]:
        """Encode ``x`` and return the tensors extracted by the layers.

        Args:
            x: Input tensor; the initial rearrange treats it as
                ``b c (l p)`` with ``p == patch_size``.

        Returns:
            One extracted tensor per layer that has extraction enabled, in
            layer order. Layers without extraction contribute nothing, and the
            final downsampled tensor itself is not returned.
        """
        x = self.to_in(x)
        extracted_list: List[Tensor] = []

        for downsample, use_extract in zip(self.downsamples, self.use_extract):
            if use_extract:
                # With extraction enabled the block returns (x, extracted).
                x, extracted = downsample(x)
                extracted_list.append(extracted)
            else:
                x = downsample(x)

        return extracted_list
980+
981+
982+ """
983+ Autoencoder
984+ """
895985
896986
897987def gaussian_sample (mean : Tensor , logvar : Tensor ) -> Tensor :
0 commit comments