@@ -21,6 +21,8 @@

 from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
 from torchaudio.functional import resample

 from audiolm_pytorch.soundstream import SoundStream
@@ -421,6 +423,7 @@ def __init__(
         rel_pos_bias = True,
         flash_attn = False,
         add_value_residual = True,
+        num_residual_streams = 4,
         **kwargs
     ):
         super().__init__()
@@ -438,11 +441,17 @@ def __init__(

         self.rel_pos_bias = RelativePositionBias(dim = dim // 2, heads = heads) if rel_pos_bias else None

+        # hyper connections
+
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
+        # layers
+
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, heads = heads, dropout = attn_dropout, flash = flash_attn, causal = True, **kwargs),
-                Attention(dim = dim, heads = heads, dropout = attn_dropout, dim_context = dim_context, flash = flash_attn, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None,
-                FeedForward(dim = dim, dropout = ff_dropout)
+                init_hyper_conn(dim = dim, branch = Attention(dim = dim, heads = heads, dropout = attn_dropout, flash = flash_attn, causal = True, **kwargs)),
+                init_hyper_conn(dim = dim, branch = Attention(dim = dim, heads = heads, dropout = attn_dropout, dim_context = dim_context, flash = flash_attn, num_null_kv = 1, norm_context = True, **kwargs)) if cross_attend else None,
+                init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, dropout = ff_dropout))
             ]))

         self.norm = LayerNorm(dim)
@@ -510,6 +519,10 @@ def forward(
         self_attn_value_residual = None
         cross_attn_value_residual = None

+        # expand residual streams
+
+        x = self.expand_streams(x)
+
         # transformer layers

         for attn, cross_attn, ff in self.layers:
@@ -523,18 +536,21 @@ def forward(

             new_kv_cache.append(layer_kv_cache)

-            x = x + residual
-
             if exists(cross_attn):
                 assert exists(context)

-                cross_attend_out, values = cross_attn(x, context = context, mask = context_mask, return_values = True, value_residual = cross_attn_value_residual)
-                x = cross_attend_out + x
+                x, values = cross_attn(x, context = context, mask = context_mask, return_values = True, value_residual = cross_attn_value_residual)

                 if self.add_value_residual:
                     cross_attn_value_residual = default(cross_attn_value_residual, values)

-            x = ff(x) + x
+            x = ff(x)
+
+        # reduce residual streams
+
+        x = self.reduce_streams(x)
+
+        # final norm

         x = self.norm(x)

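For context, the sketch below (not code from this repository) illustrates the residual-stream pattern the commit adopts: get_init_and_expand_reduce_stream_functions returns a branch wrapper that takes ownership of the residual addition, plus expand/reduce functions applied once before and once after the layer stack. TinyStack and SimpleFeedForward are hypothetical stand-ins for the repo's Transformer, Attention, and FeedForward modules; only the hyper_connections calls mirror the diff above.

# Minimal sketch of the hyper-connections pattern applied in the diff above.
# Assumes the `hyper-connections` package is installed; SimpleFeedForward and
# TinyStack are made-up stand-ins, not modules from audiolm-pytorch.

import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

class SimpleFeedForward(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        return self.net(x)

class TinyStack(nn.Module):
    def __init__(self, dim = 64, depth = 2, num_residual_streams = 4):
        super().__init__()

        # with a single stream, `disable = True` falls back to a plain residual
        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

        # each branch is wrapped, so the wrapper performs the residual update
        self.layers = nn.ModuleList([
            init_hyper_conn(dim = dim, branch = SimpleFeedForward(dim))
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.expand_streams(x)      # expand the input into multiple residual streams

        for ff in self.layers:
            x = ff(x)                   # no explicit `x = ff(x) + x` - handled by the wrapper

        x = self.reduce_streams(x)      # combine the streams back to (batch, seq, dim)
        return self.norm(x)

if __name__ == '__main__':
    tokens = torch.randn(1, 16, 64)
    out = TinyStack()(tokens)
    assert out.shape == tokens.shape

This is also why the commit drops the explicit `x = x + residual` and `x = ff(x) + x` lines: once a branch is wrapped by init_hyper_conn, the wrapper applies the residual (depth-connection) update itself.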