
Commit d65fd15

hyper connect the audio attention models
1 parent 2154a74 commit d65fd15

4 files changed: +37 -9 lines changed

README.md

Lines changed: 11 additions & 0 deletions
@@ -599,3 +599,14 @@ $ accelerate launch train.py
     url = {https://api.semanticscholar.org/CorpusID:275405495}
 }
 ```
+
+```bibtex
+@article{Zhu2024HyperConnections,
+    title = {Hyper-Connections},
+    author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2409.19606},
+    url = {https://api.semanticscholar.org/CorpusID:272987528}
+}
+```

audiolm_pytorch/audiolm_pytorch.py

Lines changed: 24 additions & 8 deletions
@@ -21,6 +21,8 @@
 
 from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
 
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
 from torchaudio.functional import resample
 
 from audiolm_pytorch.soundstream import SoundStream
@@ -421,6 +423,7 @@ def __init__(
         rel_pos_bias = True,
         flash_attn = False,
         add_value_residual = True,
+        num_residual_streams = 4,
         **kwargs
     ):
         super().__init__()
@@ -438,11 +441,17 @@ def __init__(
 
         self.rel_pos_bias = RelativePositionBias(dim = dim // 2, heads = heads) if rel_pos_bias else None
 
+        # hyper connections
+
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
+        # layers
+
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, heads = heads, dropout = attn_dropout, flash = flash_attn, causal = True, **kwargs),
-                Attention(dim = dim, heads = heads, dropout = attn_dropout, dim_context = dim_context, flash = flash_attn, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None,
-                FeedForward(dim = dim, dropout = ff_dropout)
+                init_hyper_conn(dim = dim, branch = Attention(dim = dim, heads = heads, dropout = attn_dropout, flash = flash_attn, causal = True, **kwargs)),
+                init_hyper_conn(dim = dim, branch = Attention(dim = dim, heads = heads, dropout = attn_dropout, dim_context = dim_context, flash = flash_attn, num_null_kv = 1, norm_context = True, **kwargs)) if cross_attend else None,
+                init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, dropout = ff_dropout))
             ]))
 
         self.norm = LayerNorm(dim)
@@ -510,6 +519,10 @@ def forward(
         self_attn_value_residual = None
         cross_attn_value_residual = None
 
+        # expand residual streams
+
+        x = self.expand_streams(x)
+
         # transformer layers
 
         for attn, cross_attn, ff in self.layers:
@@ -523,18 +536,21 @@ def forward(
 
             new_kv_cache.append(layer_kv_cache)
 
-            x = x + residual
-
             if exists(cross_attn):
                 assert exists(context)
 
-                cross_attend_out, values = cross_attn(x, context = context, mask = context_mask, return_values = True, value_residual = cross_attn_value_residual)
-                x = cross_attend_out + x
+                x, values = cross_attn(x, context = context, mask = context_mask, return_values = True, value_residual = cross_attn_value_residual)
 
                 if self.add_value_residual:
                     cross_attn_value_residual = default(cross_attn_value_residual, values)
 
-            x = ff(x) + x
+            x = ff(x)
+
+        # reduce residual streams
+
+        x = self.reduce_streams(x)
+
+        # final norm
 
         x = self.norm(x)
 
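The diff above follows the hyper-connections API as it is used in this commit: widen the residual into several streams before the transformer layers, wrap each attention and feedforward branch with `init_hyper_conn`, and collapse the streams back before the final norm. Below is a minimal, self-contained sketch of that wiring; `ToyBlock` and `ToyTransformer` are illustrative stand-ins (not part of the commit) for the commit's `Attention` / `FeedForward` branches and its transformer class.

```python
# Minimal sketch of the hyper-connections wiring this commit introduces.
# ToyBlock is a hypothetical stand-in for the real Attention / FeedForward branches.

import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)

class ToyTransformer(nn.Module):
    def __init__(self, dim, depth, num_residual_streams = 4):
        super().__init__()

        # with a single stream this degenerates to an ordinary residual connection
        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(
            num_residual_streams,
            disable = num_residual_streams == 1
        )

        # each branch is wrapped so it reads from and writes back to the residual streams
        self.layers = nn.ModuleList([
            init_hyper_conn(dim = dim, branch = ToyBlock(dim))
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.expand_streams(x)      # widen the residual into multiple streams

        for layer in self.layers:
            x = layer(x)                # residual update handled inside the wrapper

        x = self.reduce_streams(x)      # collapse the streams back to a single residual

        return self.norm(x)

model = ToyTransformer(dim = 64, depth = 2)
out = model(torch.randn(1, 16, 64))
print(out.shape)                        # should match the input shape, (1, 16, 64)
```

Note how the explicit `x = x + residual` and `x = ff(x) + x` additions are dropped in the diff: once a branch is wrapped by `init_hyper_conn`, the residual update happens inside the wrapper, so the forward pass only expands the streams once, calls the wrapped layers, and reduces the streams before the final norm.
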
audiolm_pytorch/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '2.3.1'
+__version__ = '2.4.0'

setup.py

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@
     'fairseq',
     'wandb',
     'gateloop-transformer>=0.2.3',
+    'hyper-connections>=0.1.8',
     'joblib',
     'local-attention>=1.9.0',
     'pytorch-warmup',
