
Commit e22b4e9

feat: add option to use rel pos
1 parent 7f6151e commit e22b4e9

3 files changed: +126 additions, -4 deletions

audio_diffusion_pytorch/model.py

Lines changed: 1 addition & 0 deletions
@@ -238,6 +238,7 @@ def get_default_model_kwargs():
         attention_heads=8,
         attention_features=64,
         attention_multiplier=2,
+        attention_use_rel_pos=False,
         resnet_groups=8,
         kernel_multiplier_downsample=2,
         use_nearest_upsample=False,
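The default keeps relative positions disabled, so existing configurations are unchanged. A minimal opt-in sketch through the high-level model, assuming `AudioDiffusionModel` consumes `get_default_model_kwargs` and forwards overrides down to the UNet; the bucket and distance values are illustrative, not defaults from this commit:

# Hypothetical opt-in sketch; the AudioDiffusionModel export and the chosen
# bucket/distance values are assumptions, not part of this commit.
from audio_diffusion_pytorch import AudioDiffusionModel

model = AudioDiffusionModel(
    attention_use_rel_pos=True,
    attention_rel_pos_num_buckets=32,
    attention_rel_pos_max_distance=64,
)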

audio_diffusion_pytorch/modules.py

Lines changed: 124 additions & 3 deletions
@@ -315,6 +315,55 @@ def __init__(
 """


+class RelativePositionBias(nn.Module):
+    def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
+        super().__init__()
+        self.num_buckets = num_buckets
+        self.max_distance = max_distance
+        self.num_heads = num_heads
+        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
+
+    @staticmethod
+    def _relative_position_bucket(
+        relative_position: Tensor, num_buckets: int, max_distance: int
+    ):
+        num_buckets //= 2
+        ret = (relative_position >= 0).to(torch.long) * num_buckets
+        n = torch.abs(relative_position)
+
+        max_exact = num_buckets // 2
+        is_small = n < max_exact
+
+        val_if_large = (
+            max_exact
+            + (
+                torch.log(n.float() / max_exact)
+                / math.log(max_distance / max_exact)
+                * (num_buckets - max_exact)
+            ).long()
+        )
+        val_if_large = torch.min(
+            val_if_large, torch.full_like(val_if_large, num_buckets - 1)
+        )
+
+        ret += torch.where(is_small, n, val_if_large)
+        return ret
+
+    def forward(self, num_queries: int, num_keys: int) -> Tensor:
+        i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
+        q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
+        k_pos = torch.arange(j, dtype=torch.long, device=device)
+        rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
+
+        relative_position_bucket = self._relative_position_bucket(
+            rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
+        )
+
+        bias = self.relative_attention_bias(relative_position_bucket)
+        bias = rearrange(bias, "m n h -> 1 h m n")
+        return bias
+
+
 def FeedForward(features: int, multiplier: int) -> nn.Module:
     mid_features = features * multiplier
     return nn.Sequential(
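This is the T5-style bucketing scheme: offsets smaller than max_exact each get their own bucket, larger offsets share logarithmically spaced buckets capped at max_distance, and positive and negative offsets use separate halves of the bucket range. Since q_pos = torch.arange(j - i, j), queries are aligned with the last i key positions, so the bias stays well defined when there are fewer queries than keys. A standalone sketch of the new module (bucket, distance, and head counts are illustrative):

# Standalone sketch of RelativePositionBias as added above; values illustrative.
from audio_diffusion_pytorch.modules import RelativePositionBias

rel_pos = RelativePositionBias(num_buckets=32, max_distance=64, num_heads=8)
bias = rel_pos(num_queries=16, num_keys=16)
print(bias.shape)  # torch.Size([1, 8, 16, 16]), broadcast over (b, h, n, m) logits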
@@ -331,19 +380,33 @@ def __init__(
         *,
         head_features: int,
         num_heads: int,
+        use_rel_pos: bool,
+        rel_pos_num_buckets: Optional[int] = None,
+        rel_pos_max_distance: Optional[int] = None,
     ):
         super().__init__()
         self.scale = head_features ** -0.5
         self.num_heads = num_heads
+        self.use_rel_pos = use_rel_pos
         mid_features = head_features * num_heads

+        if use_rel_pos:
+            assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
+            self.rel_pos = RelativePositionBias(
+                num_buckets=rel_pos_num_buckets,
+                max_distance=rel_pos_max_distance,
+                num_heads=num_heads,
+            )
+
         self.to_out = nn.Linear(in_features=mid_features, out_features=features)

     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
         # Split heads
         q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
         # Compute similarity matrix
-        sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
+        sim = einsum("... n d, ... m d -> ... n m", q, k)
+        sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
+        sim = sim * self.scale
         # Get attention matrix with softmax
         attn = sim.softmax(dim=-1)
         # Compute values
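Note the order of operations: the relative position bias is added to the raw dot products and the sum is then multiplied by self.scale, so the learned bias is effectively scaled by head_features ** -0.5 as well. An equivalent sketch with explicit tensors (shapes are illustrative):

# Explicit-tensor sketch of the modified similarity computation (shapes assumed).
import torch

b, h, n, d = 2, 8, 16, 64
q, k = torch.randn(b, h, n, d), torch.randn(b, h, n, d)
bias = torch.randn(1, h, n, n)  # as returned by RelativePositionBias
scale = d ** -0.5

sim = torch.einsum("b h n e, b h m e -> b h n m", q, k)
sim = (sim + bias) * scale  # bias added before scaling, as in the diff
attn = sim.softmax(dim=-1)  # rows sum to 1 over the key axis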
@@ -360,6 +423,9 @@ def __init__(
         head_features: int,
         num_heads: int,
         context_features: Optional[int] = None,
+        use_rel_pos: bool,
+        rel_pos_num_buckets: Optional[int] = None,
+        rel_pos_max_distance: Optional[int] = None,
     ):
         super().__init__()
         self.context_features = context_features
@@ -375,7 +441,12 @@ def __init__(
             in_features=context_features, out_features=mid_features * 2, bias=False
         )
         self.attention = AttentionBase(
-            features, num_heads=num_heads, head_features=head_features
+            features,
+            num_heads=num_heads,
+            head_features=head_features,
+            use_rel_pos=use_rel_pos,
+            rel_pos_num_buckets=rel_pos_num_buckets,
+            rel_pos_max_distance=rel_pos_max_distance,
         )

     def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
@@ -402,14 +473,22 @@ def __init__(
         num_heads: int,
         head_features: int,
         multiplier: int,
+        use_rel_pos: bool,
+        rel_pos_num_buckets: Optional[int] = None,
+        rel_pos_max_distance: Optional[int] = None,
         context_features: Optional[int] = None,
     ):
         super().__init__()

         self.use_cross_attention = exists(context_features) and context_features > 0

         self.attention = Attention(
-            features=features, num_heads=num_heads, head_features=head_features
+            features=features,
+            num_heads=num_heads,
+            head_features=head_features,
+            use_rel_pos=use_rel_pos,
+            rel_pos_num_buckets=rel_pos_num_buckets,
+            rel_pos_max_distance=rel_pos_max_distance,
         )

         if self.use_cross_attention:
@@ -418,6 +497,9 @@ def __init__(
                 num_heads=num_heads,
                 head_features=head_features,
                 context_features=context_features,
+                use_rel_pos=use_rel_pos,
+                rel_pos_num_buckets=rel_pos_num_buckets,
+                rel_pos_max_distance=rel_pos_max_distance,
             )

         self.feed_forward = FeedForward(features=features, multiplier=multiplier)
@@ -443,6 +525,9 @@ def __init__(
         num_heads: int,
         head_features: int,
         multiplier: int,
+        use_rel_pos: bool,
+        rel_pos_num_buckets: Optional[int] = None,
+        rel_pos_max_distance: Optional[int] = None,
         context_features: Optional[int] = None,
     ):
         super().__init__()
@@ -465,6 +550,9 @@ def __init__(
                     num_heads=num_heads,
                     multiplier=multiplier,
                     context_features=context_features,
+                    use_rel_pos=use_rel_pos,
+                    rel_pos_num_buckets=rel_pos_num_buckets,
+                    rel_pos_max_distance=rel_pos_max_distance,
                 )
                 for i in range(num_layers)
             ]
@@ -552,6 +640,9 @@ def __init__(
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
         attention_multiplier: Optional[int] = None,
+        attention_use_rel_pos: Optional[bool] = None,
+        attention_rel_pos_max_distance: Optional[int] = None,
+        attention_rel_pos_num_buckets: Optional[int] = None,
         context_mapping_features: Optional[int] = None,
         context_embedding_features: Optional[int] = None,
     ):
@@ -588,6 +679,7 @@ def __init__(
                 exists(attention_heads)
                 and exists(attention_features)
                 and exists(attention_multiplier)
+                and exists(attention_use_rel_pos)
             )
             self.transformer = Transformer1d(
                 num_layers=num_transformer_blocks,
@@ -596,6 +688,9 @@ def __init__(
                 head_features=attention_features,
                 multiplier=attention_multiplier,
                 context_features=context_embedding_features,
+                use_rel_pos=attention_use_rel_pos,
+                rel_pos_num_buckets=attention_rel_pos_num_buckets,
+                rel_pos_max_distance=attention_rel_pos_max_distance,
             )

         if self.use_extract:
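In these down/up/bottleneck blocks the new attention_* arguments default to None and are only checked by the exists(...) assertion when transformer layers are requested, so code instantiating the blocks directly must pass attention_use_rel_pos explicitly whenever attention is enabled; only the top-level UNet1d below supplies a False default.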
@@ -659,6 +754,9 @@ def __init__(
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
         attention_multiplier: Optional[int] = None,
+        attention_use_rel_pos: Optional[bool] = None,
+        attention_rel_pos_max_distance: Optional[int] = None,
+        attention_rel_pos_num_buckets: Optional[int] = None,
         context_mapping_features: Optional[int] = None,
         context_embedding_features: Optional[int] = None,
     ):
@@ -689,6 +787,7 @@ def __init__(
                 exists(attention_heads)
                 and exists(attention_features)
                 and exists(attention_multiplier)
+                and exists(attention_use_rel_pos)
             )
             self.transformer = Transformer1d(
                 num_layers=num_transformer_blocks,
@@ -697,6 +796,9 @@ def __init__(
                 head_features=attention_features,
                 multiplier=attention_multiplier,
                 context_features=context_embedding_features,
+                use_rel_pos=attention_use_rel_pos,
+                rel_pos_num_buckets=attention_rel_pos_num_buckets,
+                rel_pos_max_distance=attention_rel_pos_max_distance,
             )

         self.upsample = Upsample1d(
@@ -756,6 +858,9 @@ def __init__(
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
         attention_multiplier: Optional[int] = None,
+        attention_use_rel_pos: Optional[bool] = None,
+        attention_rel_pos_max_distance: Optional[int] = None,
+        attention_rel_pos_num_buckets: Optional[int] = None,
         context_mapping_features: Optional[int] = None,
         context_embedding_features: Optional[int] = None,
     ):
@@ -774,6 +879,7 @@ def __init__(
                 exists(attention_heads)
                 and exists(attention_features)
                 and exists(attention_multiplier)
+                and exists(attention_use_rel_pos)
             )
             self.transformer = Transformer1d(
                 num_layers=num_transformer_blocks,
@@ -782,6 +888,9 @@ def __init__(
                 head_features=attention_features,
                 multiplier=attention_multiplier,
                 context_features=context_embedding_features,
+                use_rel_pos=attention_use_rel_pos,
+                rel_pos_num_buckets=attention_rel_pos_num_buckets,
+                rel_pos_max_distance=attention_rel_pos_max_distance,
             )

         self.post_block = ResnetBlock1d(
@@ -844,6 +953,9 @@ def __init__(
         context_features: Optional[int] = None,
         context_channels: Optional[Sequence[int]] = None,
         context_embedding_features: Optional[int] = None,
+        attention_use_rel_pos: bool = False,
+        attention_rel_pos_max_distance: Optional[int] = None,
+        attention_rel_pos_num_buckets: Optional[int] = None,
     ):
         super().__init__()
         out_channels = default(out_channels, in_channels)
@@ -931,6 +1043,9 @@ def __init__(
                     attention_heads=attention_heads,
                     attention_features=attention_features,
                     attention_multiplier=attention_multiplier,
+                    attention_use_rel_pos=attention_use_rel_pos,
+                    attention_rel_pos_max_distance=attention_rel_pos_max_distance,
+                    attention_rel_pos_num_buckets=attention_rel_pos_num_buckets,
                 )
                 for i in range(num_layers)
             ]
@@ -945,6 +1060,9 @@ def __init__(
             attention_heads=attention_heads,
             attention_features=attention_features,
             attention_multiplier=attention_multiplier,
+            attention_use_rel_pos=attention_use_rel_pos,
+            attention_rel_pos_max_distance=attention_rel_pos_max_distance,
+            attention_rel_pos_num_buckets=attention_rel_pos_num_buckets,
         )

         self.upsamples = nn.ModuleList(
@@ -966,6 +1084,9 @@ def __init__(
                     attention_heads=attention_heads,
                     attention_features=attention_features,
                     attention_multiplier=attention_multiplier,
+                    attention_use_rel_pos=attention_use_rel_pos,
+                    attention_rel_pos_max_distance=attention_rel_pos_max_distance,
+                    attention_rel_pos_num_buckets=attention_rel_pos_num_buckets,
                 )
                 for i in reversed(range(num_layers))
             ]

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.69",
+    version="0.0.70",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
