@@ -36,7 +36,6 @@ def Downsample1d(
3636 kernel_size = factor * kernel_multiplier + 1 ,
3737 stride = factor ,
3838 padding = factor * (kernel_multiplier // 2 ),
39- groups = in_channels // 4 ,
4039 )
4140
4241
@@ -718,14 +717,29 @@ def __init__(
718717 use_skip_scale : bool ,
719718 use_attention_bottleneck : bool ,
720719 out_channels : Optional [int ] = None ,
720+ context_channels : Optional [Sequence [int ]] = None ,
721721 ):
722722 super ().__init__ ()
723723
724724 out_channels = default (out_channels , in_channels )
725+ context_channels = list (default (context_channels , []))
725726 time_context_features = channels * 4
727+
726728 num_layers = len (multipliers ) - 1
727729 self .num_layers = num_layers
728730
731+ use_context = len (context_channels ) > 0
732+ self .use_context = use_context
733+
734+ context_pad_length = num_layers + 1 - len (context_channels )
735+ context_channels = context_channels + [0 ] * context_pad_length
736+ self .context_channels = context_channels
737+
738+ if use_context :
739+ has_context = [c > 0 for c in context_channels ]
740+ self .has_context = has_context
741+ self .context_id = [sum (has_context [:i ]) for i in range (len (has_context ))]
742+
729743 assert (
730744 len (factors ) == num_layers
731745 and len (attentions ) == num_layers
@@ -735,7 +749,7 @@ def __init__(
735749 self .to_in = nn .Sequential (
736750 Rearrange ("b c (l p) -> b (c p) l" , p = patch_size ),
737751 CrossEmbed1d (
738- in_channels = in_channels * patch_size ,
752+ in_channels = ( in_channels + context_channels [ 0 ]) * patch_size ,
739753 out_channels = channels ,
740754 kernel_sizes = kernel_sizes_init ,
741755 stride = 1 ,
@@ -757,7 +771,7 @@ def __init__(
757771 self .downsamples = nn .ModuleList (
758772 [
759773 DownsampleBlock1d (
760- in_channels = channels * multipliers [i ],
774+ in_channels = channels * multipliers [i ] + context_channels [ i + 1 ] ,
761775 out_channels = channels * multipliers [i + 1 ],
762776 time_context_features = time_context_features ,
763777 num_layers = num_blocks [i ],
@@ -784,10 +798,11 @@ def __init__(
784798 attention_features = attention_features ,
785799 )
786800
801+ context_channels = context_channels + [0 ] # Upsample skips first context
787802 self .upsamples = nn .ModuleList (
788803 [
789804 UpsampleBlock1d (
790- in_channels = channels * multipliers [i + 1 ],
805+ in_channels = channels * multipliers [i + 1 ] + context_channels [ i + 2 ] ,
791806 out_channels = channels * multipliers [i ],
792807 time_context_features = time_context_features ,
793808 num_layers = num_blocks [i ] + (1 if attentions [i ] else 0 ),
@@ -809,7 +824,7 @@ def __init__(
809824
810825 self .to_out = nn .Sequential (
811826 ResnetBlock1d (
812- in_channels = channels ,
827+ in_channels = channels + context_channels [ 1 ] ,
813828 out_channels = channels ,
814829 num_groups = resnet_groups ,
815830 time_context_features = time_context_features ,
@@ -822,21 +837,54 @@ def __init__(
822837 Rearrange ("b (c p) l -> b c (l p)" , p = patch_size ),
823838 )
824839
825- def forward (self , x : Tensor , t : Tensor ):
def add_context(
    self, x: Tensor, context_list: Optional[Sequence[Tensor]] = None, layer: int = 0
) -> Tensor:
    """Concatenates context to x, if present, and checks that shape is correct"""
    # Fast path: no context configured globally, or none for this layer.
    if not (self.use_context and self.has_context[layer]):
        return x
    assert exists(context_list), "Missing context"
    # Map the layer onto the caller-supplied list, skipping zero-channel layers.
    idx = self.context_id[layer]
    ctx = context_list[idx]
    assert exists(ctx), f"Missing context for layer {layer} at index {idx}"
    # Validate channel count (dim 1) against the configured width for this layer.
    expected_channels = self.context_channels[layer]
    assert (
        ctx.shape[1] == expected_channels
    ), f"Expected context with {expected_channels} channels at index {idx}"
    # Validate temporal length (dim 2) against x so concatenation is well-formed.
    expected_length = x.shape[2]
    assert (
        ctx.shape[2] == expected_length
    ), f"Expected context length of {expected_length} at index {idx}"
    # Stack context onto x along the channel dimension.
    return torch.cat([x, ctx], dim=1)
826864
865+ def forward (
866+ self ,
867+ x : Tensor ,
868+ t : Tensor ,
869+ * ,
870+ context : Optional [Sequence [Tensor ]] = None ,
871+ ):
872+ x = self .add_context (x , context )
827873 x = self .to_in (x )
828874 t = self .to_time (t )
829875 skips_list = []
830876
831- for downsample in self .downsamples :
877+ for i , downsample in enumerate (self .downsamples ):
878+ x = self .add_context (x , context , layer = i + 1 )
832879 x , skips = downsample (x , t )
833880 skips_list += [skips ]
834881
835882 x = self .bottleneck (x , t )
836883
837- for upsample in self .upsamples :
884+ for i , upsample in enumerate ( self .upsamples ) :
838885 skips = skips_list .pop ()
839886 x = upsample (x , skips , t )
887+ x = self .add_context (x , context , layer = len (self .upsamples ) - i )
840888
841889 x = self .to_out (x ) # t?
842890
0 commit comments