@@ -122,9 +122,14 @@ def __init__(
         num_groups: int,
         dilation: int = 1,
         time_context_features: Optional[int] = None,
+        context_features: Optional[int] = None,
+        context_heads: Optional[int] = None,
+        context_head_features: Optional[int] = None,
     ) -> None:
         super().__init__()

+        self.use_context = exists(context_features)
+
         self.to_time_embedding = (
             nn.Sequential(
                 nn.SiLU(),
@@ -143,6 +148,19 @@ def __init__(
             dilation=dilation,
         )

+        if self.use_context:
+            assert exists(context_heads) and exists(context_head_features)
+            self.cross_attend = EinopsToAndFrom(
+                "b c l",
+                "b l c",
+                CrossAttention(
+                    features=out_channels,
+                    context_features=context_features,
+                    head_features=context_head_features,
+                    num_heads=context_heads,
+                ),
+            )
+
         self.block2 = ConvBlock1d(
             in_channels=out_channels, out_channels=out_channels, num_groups=num_groups
         )
@@ -153,10 +171,20 @@ def __init__(
             else nn.Identity()
         )

-    def forward(self, x: Tensor, time_context: Optional[Tensor] = None) -> Tensor:
+    def forward(
+        self,
+        x: Tensor,
+        time_context: Optional[Tensor] = None,
+        context: Optional[Tensor] = None,
+    ) -> Tensor:
+        assert_message = "You must provide context tokens if context_features > 0"
+        assert not (self.use_context ^ exists(context)), assert_message

         h = self.block1(x)

+        if self.use_context and exists(context):
+            h = self.cross_attend(h, context=context) + h
+
         # Compute scale and shift from time_context
         scale_shift = None
         if exists(self.to_time_embedding) and exists(time_context):
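Note: the new assert in the resnet block's forward enforces that context tokens are passed exactly when the block was built with context_features. A tiny standalone illustration of the XOR guard (values are illustrative only):

# The assert passes only when the two flags agree.
for use_context, has_tokens in [(False, False), (True, True), (True, False), (False, True)]:
    ok = not (use_context ^ has_tokens)
    print(f"use_context={use_context}, tokens provided={has_tokens}: {'ok' if ok else 'assert would fire'}")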
@@ -385,6 +413,45 @@ def forward(self, x: Tensor, *, mask: Optional[Tensor] = None) -> Tensor:
         return x


+class CrossAttention(nn.Module):
+    def __init__(
+        self,
+        features: int,
+        *,
+        context_features: int = None,
+        head_features: int = 64,
+        num_heads: int = 8,
+    ):
+        super().__init__()
+        mid_features = head_features * num_heads
+        context_features = default(context_features, features)
+
+        self.norm_in = LayerNorm(features=features, bias=False)
+        self.norm_context = LayerNorm(features=context_features, bias=False)
+
+        self.to_q = nn.Linear(
+            in_features=features, out_features=mid_features, bias=False
+        )
+        self.to_kv = nn.Linear(
+            in_features=context_features, out_features=mid_features * 2, bias=False
+        )
+        self.attention = AttentionBase(
+            features,
+            num_heads=num_heads,
+            head_features=head_features,
+            use_null_tokens=False,
+        )
+
+    def forward(self, x: Tensor, context: Tensor, mask: Tensor = None) -> Tensor:
+        b, n, d = x.shape
+        x = self.norm_in(x)
+        context = self.norm_context(context)
+        # Queries from x, k and v from context
+        q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
+        x = self.attention(q, k, v, mask=mask)
+        return x
+
+
 """
 Transformer Blocks
 """
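For reference, a minimal standalone sketch of how the new CrossAttention module could be exercised; the dimensions, batch size, and token count below are illustrative assumptions, not values from this commit:

import torch

attn = CrossAttention(features=128, context_features=512, head_features=64, num_heads=8)
x = torch.randn(2, 1024, 128)   # latent sequence: [batch, length, features]
ctx = torch.randn(2, 77, 512)   # conditioning tokens: [batch, num_tokens, context_features]
out = attn(x, context=ctx)      # cross-attended output, expected shape [2, 1024, 128]

Inside the resnet block the module is wrapped in EinopsToAndFrom("b c l", "b l c"), so channel-first feature maps are rearranged to token-major form before attending and back afterwards.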
@@ -468,6 +535,7 @@ def __init__(
         attention_features: Optional[int] = None,
         attention_multiplier: Optional[int] = None,
         time_context_features: Optional[int] = None,
+        context_features: Optional[int] = None,
     ):
         super().__init__()
         self.use_pre_downsample = use_pre_downsample
@@ -492,6 +560,9 @@ def __init__(
                     out_channels=channels,
                     num_groups=num_groups,
                     time_context_features=time_context_features,
+                    context_features=context_features,
+                    context_heads=attention_heads,
+                    context_head_features=attention_features,
                 )
                 for i in range(num_layers)
             ]
@@ -519,18 +590,22 @@ def __init__(
         )

     def forward(
-        self, x: Tensor, t: Optional[Tensor] = None, context: Optional[Tensor] = None
+        self,
+        x: Tensor,
+        t: Optional[Tensor] = None,
+        channels: Optional[Tensor] = None,
+        tokens: Optional[Tensor] = None,
     ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:

         if self.use_pre_downsample:
             x = self.downsample(x)

-        if self.use_context and exists(context):
-            x = torch.cat([x, context], dim=1)
+        if self.use_context and exists(channels):
+            x = torch.cat([x, channels], dim=1)

         skips = []
         for block in self.blocks:
-            x = block(x, t)
+            x = block(x, t, context=tokens)
             skips += [x] if self.use_skip else []

         if self.use_attention:
@@ -566,6 +641,7 @@ def __init__(
         attention_features: Optional[int] = None,
         attention_multiplier: Optional[int] = None,
         time_context_features: Optional[int] = None,
+        context_features: Optional[int] = None,
     ):
         super().__init__()

@@ -589,6 +665,9 @@ def __init__(
                     out_channels=channels,
                     num_groups=num_groups,
                     time_context_features=time_context_features,
+                    context_features=context_features,
+                    context_heads=attention_heads,
+                    context_head_features=attention_features,
                 )
                 for _ in range(num_layers)
             ]
@@ -622,14 +701,15 @@ def forward(
         x: Tensor,
         skips: Optional[List[Tensor]] = None,
         t: Optional[Tensor] = None,
+        tokens: Optional[Tensor] = None,
     ) -> Tensor:

         if self.use_pre_upsample:
             x = self.upsample(x)

         for block in self.blocks:
             x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
-            x = block(x, t)
+            x = block(x, t, context=tokens)

         if self.use_attention:
             x = self.transformer(x)
@@ -650,6 +730,7 @@ def __init__(
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
         time_context_features: Optional[int] = None,
+        context_features: Optional[int] = None,
     ):
         super().__init__()

@@ -664,6 +745,9 @@ def __init__(
             out_channels=channels,
             num_groups=num_groups,
             time_context_features=time_context_features,
+            context_features=context_features,
+            context_heads=attention_heads,
+            context_head_features=attention_features,
         )

         if use_attention:
@@ -683,13 +767,21 @@ def __init__(
             out_channels=channels,
             num_groups=num_groups,
             time_context_features=time_context_features,
+            context_features=context_features,
+            context_heads=attention_heads,
+            context_head_features=attention_features,
         )

-    def forward(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
-        x = self.pre_block(x, t)
+    def forward(
+        self,
+        x: Tensor,
+        t: Optional[Tensor] = None,
+        tokens: Optional[Tensor] = None,
+    ) -> Tensor:
+        x = self.pre_block(x, t, context=tokens)
         if self.use_attention:
             x = self.attention(x)
-        x = self.post_block(x, t)
+        x = self.post_block(x, t, context=tokens)
         return x


@@ -754,6 +846,7 @@ def __init__(
         use_attention_bottleneck: bool,
         out_channels: Optional[int] = None,
         context_channels: Optional[Sequence[int]] = None,
+        context_features: Optional[int] = None,
         kernel_sizes_out: Optional[Sequence[int]] = None,
     ):
         super().__init__()
@@ -808,6 +901,7 @@ def __init__(
                     out_channels=channels * multipliers[i + 1],
                     time_context_features=time_context_features,
                     context_channels=context_channels[i + 1],
+                    context_features=context_features,
                     num_layers=num_blocks[i],
                     factor=factors[i],
                     kernel_multiplier=kernel_multiplier_downsample,
@@ -826,6 +920,7 @@ def __init__(
         self.bottleneck = BottleneckBlock1d(
             channels=channels * multipliers[-1],
             time_context_features=time_context_features,
+            context_features=context_features,
             num_groups=resnet_groups,
             use_attention=use_attention_bottleneck,
             attention_heads=attention_heads,
@@ -838,6 +933,7 @@ def __init__(
                     in_channels=channels * multipliers[i + 1],
                     out_channels=channels * multipliers[i],
                     time_context_features=time_context_features,
+                    context_features=context_features,
                     num_layers=num_blocks[i] + (1 if attentions[i] else 0),
                     factor=factors[i],
                     use_nearest=use_nearest_upsample,
@@ -902,23 +998,25 @@ def forward(
         t: Tensor,
         *,
         context: Optional[Sequence[Tensor]] = None,
+        tokens: Optional[Tensor] = None,
     ):
         c = self.get_context(context)
         x = torch.cat([x, c], dim=1) if exists(c) else x
+
         x = self.to_in(x)
         t = self.to_time(t)
         skips_list = []

         for i, downsample in enumerate(self.downsamples):
-            c = self.get_context(context, layer=i + 1)
-            x, skips = downsample(x, t, c)
+            channels = self.get_context(context, layer=i + 1)
+            x, skips = downsample(x, t, channels=channels, tokens=tokens)
             skips_list += [skips]

-        x = self.bottleneck(x, t)
+        x = self.bottleneck(x, t, tokens=tokens)

         for i, upsample in enumerate(self.upsamples):
             skips = skips_list.pop()
-            x = upsample(x, skips, t)
+            x = upsample(x, skips, t, tokens=tokens)

         x = self.to_out(x)  # t?

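With these changes the top-level UNet can be conditioned on a token sequence in addition to the per-layer channel context. A hedged usage sketch: `unet` stands for an already-constructed instance of the modified model built with context_features=768, and all shapes below are illustrative assumptions, not values from this commit:

import torch

x = torch.randn(2, 2, 2**15)      # input signal: [batch, in_channels, length]
t = torch.rand(2)                 # diffusion time / noise level, one value per batch item
tokens = torch.randn(2, 77, 768)  # conditioning tokens: [batch, num_tokens, context_features]
y = unet(x, t, tokens=tokens)     # tokens are passed down and cross-attended in every resnet block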