@@ -473,6 +473,7 @@ def __init__(
         use_pre_downsample: bool = True,
         use_skip: bool = False,
         extract_channels: int = 0,
+        context_channels: int = 0,
         use_attention: bool = False,
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
@@ -484,6 +485,7 @@ def __init__(
         self.use_skip = use_skip
         self.use_attention = use_attention
         self.use_extract = extract_channels > 0
+        self.use_context = context_channels > 0

         channels = out_channels if use_pre_downsample else in_channels

@@ -497,12 +499,12 @@ def __init__(
         self.blocks = nn.ModuleList(
             [
                 ResnetBlock1d(
-                    in_channels=channels,
+                    in_channels=channels + (context_channels if i == 0 else 0),
                     out_channels=channels,
                     num_groups=num_groups,
                     time_context_features=time_context_features,
                 )
-                for _ in range(num_layers)
+                for i in range(num_layers)
             ]
         )

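Only the first ResNet block in the stack absorbs the extra context channels; later blocks keep the plain width. A minimal sketch of that width arithmetic, with toy values that are not from the repo:

```python
# Toy values (not from the repo) for the width arithmetic above.
channels, context_channels, num_layers = 64, 8, 3

widths = [channels + (context_channels if i == 0 else 0) for i in range(num_layers)]
print(widths)  # [72, 64, 64] -- only block 0 absorbs the concatenated context
```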
@@ -528,12 +530,15 @@ def __init__(
             )

     def forward(
-        self, x: Tensor, t: Optional[Tensor] = None
+        self, x: Tensor, t: Optional[Tensor] = None, context: Optional[Tensor] = None
     ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:

         if self.use_pre_downsample:
             x = self.downsample(x)

+        if self.use_context and exists(context):
+            x = torch.cat([x, context], dim=1)
+
         skips = []
         for block in self.blocks:
             x = block(x, t)
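The context now enters after the optional pre-downsample, so it must already match the downsampled length. A standalone sketch of the channel-wise concatenation, with assumed shapes:

```python
import torch

# Assumed shapes: batch 2, 64 feature channels, 8 context channels, length 1024.
x = torch.randn(2, 64, 1024)       # features after the optional pre-downsample
context = torch.randn(2, 8, 1024)  # must match x in batch size and length

x = torch.cat([x, context], dim=1)  # channel-wise concat, as in the block above
print(x.shape)                      # torch.Size([2, 72, 1024])
```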
@@ -774,9 +779,10 @@ def __init__(
         self.downsamples = nn.ModuleList(
             [
                 DownsampleBlock1d(
-                    in_channels=channels * multipliers[i] + context_channels[i + 1],
+                    in_channels=channels * multipliers[i],
                     out_channels=channels * multipliers[i + 1],
                     time_context_features=time_context_features,
+                    context_channels=context_channels[i + 1],
                     num_layers=num_blocks[i],
                     factor=factors[i],
                     kernel_multiplier=kernel_multiplier_downsample,
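Because each DownsampleBlock1d now widens its own first ResNet block, the trunk stops folding context_channels[i + 1] into in_channels and passes it along instead. A sketch of the resulting per-block bookkeeping, with invented hyperparameters:

```python
# Example hyperparameters (invented) to show the width bookkeeping.
channels = 32
multipliers = [1, 2, 4, 4]
context_channels = [0, 8, 0, 0]  # index 0 is consumed before the first block

for i in range(len(multipliers) - 1):
    print(
        f"downsample {i}: in={channels * multipliers[i]}, "
        f"out={channels * multipliers[i + 1]}, "
        f"context={context_channels[i + 1]}"
    )
# downsample 0: in=32, out=64, context=8
# downsample 1: in=64, out=128, context=0
# downsample 2: in=128, out=128, context=0
```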
@@ -839,13 +845,13 @@ def __init__(
             Rearrange("b (c p) l -> b c (l p)", p=patch_size),
         )

-    def add_context(
-        self, x: Tensor, context_list: Optional[Sequence[Tensor]] = None, layer: int = 0
-    ) -> Tensor:
-        """Concatenates context to x, if present, and checks that shape is correct"""
+    def get_context(
+        self, context_list: Optional[Sequence[Tensor]] = None, layer: int = 0
+    ) -> Optional[Tensor]:
+        """Returns the context for this layer, if present, with channels checked"""
         use_context = self.use_context and self.has_context[layer]
         if not use_context:
-            return x
+            return None
         assert exists(context_list), "Missing context"
         # Get context index (skipping zero channel contexts)
         context_id = self.context_ids[layer]
@@ -857,12 +863,7 @@ def add_context(
         channels = self.context_channels[layer]
         message = f"Expected context with {channels} channels at index {context_id}"
         assert context.shape[1] == channels, message
-        # Check length
-        length = x.shape[2]
-        message = f"Expected context length of {length} at index {context_id}"
-        assert context.shape[2] == length, message
-        # Concatenate context
-        return torch.cat([x, context], dim=1)
+        return context

     def forward(
         self,
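After the rename, the helper only fetches and validates the per-layer context; the length check and the concatenation move to the block that consumes it. A standalone approximation, with the self lookups (has_context, context_ids, context_channels) flattened into plain arguments:

```python
from typing import Optional, Sequence

import torch
from torch import Tensor


def exists(val) -> bool:
    return val is not None


def get_context(
    context_list: Optional[Sequence[Tensor]],
    context_id: int,
    channels: int,
    use_context: bool = True,
) -> Optional[Tensor]:
    # Flattened approximation of the method above.
    if not use_context:
        return None
    assert exists(context_list), "Missing context"
    context = context_list[context_id]
    message = f"Expected context with {channels} channels at index {context_id}"
    assert context.shape[1] == channels, message
    return context


# Example: an 8-channel context tensor at index 0 (shapes are made up).
ctx = get_context([torch.randn(2, 8, 1024)], context_id=0, channels=8)
print(ctx.shape)  # torch.Size([2, 8, 1024])
```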
@@ -871,14 +872,15 @@ def forward(
         *,
         context: Optional[Sequence[Tensor]] = None,
     ):
-        x = self.add_context(x, context)
+        c = self.get_context(context)
+        x = torch.cat([x, c], dim=1) if exists(c) else x
         x = self.to_in(x)
         t = self.to_time(t)
         skips_list = []

         for i, downsample in enumerate(self.downsamples):
-            x = self.add_context(x, context, layer=i + 1)
-            x, skips = downsample(x, t)
+            c = self.get_context(context, layer=i + 1)
+            x, skips = downsample(x, t, c)
             skips_list += [skips]

         x = self.bottleneck(x, t)
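Forward now separates fetching the optional context from applying it, and the guard keeps layers without context untouched. A runnable sketch of that guard in isolation; exists() and all shapes are assumptions:

```python
from typing import Optional

import torch
from torch import Tensor


def exists(val) -> bool:
    return val is not None


def apply_context(x: Tensor, c: Optional[Tensor]) -> Tensor:
    # Same guard as in forward: concatenate on channels only when present.
    return torch.cat([x, c], dim=1) if exists(c) else x


x = torch.randn(2, 64, 1024)
print(apply_context(x, None).shape)                     # torch.Size([2, 64, 1024])
print(apply_context(x, torch.randn(2, 8, 1024)).shape)  # torch.Size([2, 72, 1024])
```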