
Commit bfa7952 (parent db6a21a)

feat: add option to provide context tokens

2 files changed: 112 additions, 14 deletions

audio_diffusion_pytorch/modules.py

Lines changed: 111 additions & 13 deletions
@@ -122,9 +122,14 @@ def __init__(
         num_groups: int,
         dilation: int = 1,
         time_context_features: Optional[int] = None,
+        context_features: Optional[int] = None,
+        context_heads: Optional[int] = None,
+        context_head_features: Optional[int] = None,
     ) -> None:
         super().__init__()
 
+        self.use_context = exists(context_features)
+
         self.to_time_embedding = (
             nn.Sequential(
                 nn.SiLU(),
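
The new `use_context` flag is computed with the repo's `exists` helper. A minimal sketch of that helper as assumed here (it is not part of this diff):

def exists(val) -> bool:
    # assumed definition, matching how Optional values are checked in modules.py
    return val is not None

# use_context becomes True only when context_features is given
assert exists(512) is True and exists(None) is False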
@@ -143,6 +148,19 @@ def __init__(
             dilation=dilation,
         )
 
+        if self.use_context:
+            assert exists(context_heads) and exists(context_head_features)
+            self.cross_attend = EinopsToAndFrom(
+                "b c l",
+                "b l c",
+                CrossAttention(
+                    features=out_channels,
+                    context_features=context_features,
+                    head_features=context_head_features,
+                    num_heads=context_heads,
+                ),
+            )
+
         self.block2 = ConvBlock1d(
             in_channels=out_channels, out_channels=out_channels, num_groups=num_groups
         )
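
When `use_context` is set, the block wraps a CrossAttention module in EinopsToAndFrom so that channels-first conv features can be attended over the length axis. A minimal sketch of the assumed rearrange round trip, with `attn_fn` standing in for the wrapped attention:

import torch
from einops import rearrange

def cross_attend_like(attn_fn, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
    # x comes from the conv blocks as (batch, channels, length)
    h = rearrange(x, "b c l -> b l c")     # each position becomes a token
    h = attn_fn(h, context=context)        # cross-attend to the context tokens
    return rearrange(h, "b l c -> b c l")  # back to channels-first for block2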
@@ -153,10 +171,20 @@ def __init__(
             else nn.Identity()
         )
 
-    def forward(self, x: Tensor, time_context: Optional[Tensor] = None) -> Tensor:
+    def forward(
+        self,
+        x: Tensor,
+        time_context: Optional[Tensor] = None,
+        context: Optional[Tensor] = None,
+    ) -> Tensor:
+        assert_message = "You must provide context tokens if context_features > 0"
+        assert not (self.use_context ^ exists(context)), assert_message
 
         h = self.block1(x)
 
+        if self.use_context and exists(context):
+            h = self.cross_attend(h, context=context) + h
+
         # Compute scale and shift from time_context
         scale_shift = None
         if exists(self.to_time_embedding) and exists(time_context):
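
The XOR assert makes the contract symmetric: tokens must be passed exactly when the block was built with `context_features`, and never otherwise. A truth-table check of the same expression:

def context_ok(use_context: bool, context_given: bool) -> bool:
    # mirrors `assert not (self.use_context ^ exists(context))`
    return not (use_context ^ context_given)

assert context_ok(True, True)       # built with context_features, tokens given
assert context_ok(False, False)     # plain block, no tokens
assert not context_ok(True, False)  # this case raises the assert_message
assert not context_ok(False, True)  # stray tokens are rejected too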
@@ -385,6 +413,45 @@ def forward(self, x: Tensor, *, mask: Optional[Tensor] = None) -> Tensor:
         return x
 
 
+class CrossAttention(nn.Module):
+    def __init__(
+        self,
+        features: int,
+        *,
+        context_features: int = None,
+        head_features: int = 64,
+        num_heads: int = 8,
+    ):
+        super().__init__()
+        mid_features = head_features * num_heads
+        context_features = default(context_features, features)
+
+        self.norm_in = LayerNorm(features=features, bias=False)
+        self.norm_context = LayerNorm(features=context_features, bias=False)
+
+        self.to_q = nn.Linear(
+            in_features=features, out_features=mid_features, bias=False
+        )
+        self.to_kv = nn.Linear(
+            in_features=context_features, out_features=mid_features * 2, bias=False
+        )
+        self.attention = AttentionBase(
+            features,
+            num_heads=num_heads,
+            head_features=head_features,
+            use_null_tokens=False,
+        )
+
+    def forward(self, x: Tensor, context: Tensor, mask: Tensor = None) -> Tensor:
+        b, n, d = x.shape
+        x = self.norm_in(x)
+        context = self.norm_context(context)
+        # Queries from x, k and v from context
+        q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
+        x = self.attention(q, k, v, mask=mask)
+        return x
+
+
 """
 Transformer Blocks
 """
@@ -468,6 +535,7 @@ def __init__(
         attention_features: Optional[int] = None,
         attention_multiplier: Optional[int] = None,
         time_context_features: Optional[int] = None,
+        context_features: Optional[int] = None,
     ):
         super().__init__()
         self.use_pre_downsample = use_pre_downsample
@@ -492,6 +560,9 @@ def __init__(
                     out_channels=channels,
                     num_groups=num_groups,
                     time_context_features=time_context_features,
+                    context_features=context_features,
+                    context_heads=attention_heads,
+                    context_head_features=attention_features,
                 )
                 for i in range(num_layers)
             ]
@@ -519,18 +590,22 @@
         )
 
     def forward(
-        self, x: Tensor, t: Optional[Tensor] = None, context: Optional[Tensor] = None
+        self,
+        x: Tensor,
+        t: Optional[Tensor] = None,
+        channels: Optional[Tensor] = None,
+        tokens: Optional[Tensor] = None,
     ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
 
         if self.use_pre_downsample:
             x = self.downsample(x)
 
-        if self.use_context and exists(context):
-            x = torch.cat([x, context], dim=1)
+        if self.use_context and exists(channels):
+            x = torch.cat([x, channels], dim=1)
 
         skips = []
         for block in self.blocks:
-            x = block(x, t)
+            x = block(x, t, context=tokens)
             skips += [x] if self.use_skip else []
 
         if self.use_attention:
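
The renamed arguments separate two conditioning paths that previously shared the name `context`: `channels` is a feature map concatenated channel-wise before the blocks, while `tokens` feeds each block's cross-attention. A sketch of the channel path with made-up sizes:

import torch

x = torch.randn(1, 64, 256)          # (batch, channels, length) feature map
channels = torch.randn(1, 8, 256)    # context channels with matching length
x = torch.cat([x, channels], dim=1)  # as in forward(): now (1, 72, 256)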
@@ -566,6 +641,7 @@ def __init__(
         attention_features: Optional[int] = None,
         attention_multiplier: Optional[int] = None,
         time_context_features: Optional[int] = None,
+        context_features: Optional[int] = None,
     ):
         super().__init__()
 
@@ -589,6 +665,9 @@ def __init__(
                     out_channels=channels,
                     num_groups=num_groups,
                     time_context_features=time_context_features,
+                    context_features=context_features,
+                    context_heads=attention_heads,
+                    context_head_features=attention_features,
                 )
                 for _ in range(num_layers)
             ]
@@ -622,14 +701,15 @@ def forward(
         x: Tensor,
         skips: Optional[List[Tensor]] = None,
         t: Optional[Tensor] = None,
+        tokens: Optional[Tensor] = None,
     ) -> Tensor:
 
         if self.use_pre_upsample:
             x = self.upsample(x)
 
         for block in self.blocks:
             x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
-            x = block(x, t)
+            x = block(x, t, context=tokens)
 
         if self.use_attention:
             x = self.transformer(x)
@@ -650,6 +730,7 @@ def __init__(
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
         time_context_features: Optional[int] = None,
+        context_features: Optional[int] = None,
     ):
         super().__init__()
 
@@ -664,6 +745,9 @@ def __init__(
             out_channels=channels,
             num_groups=num_groups,
             time_context_features=time_context_features,
+            context_features=context_features,
+            context_heads=attention_heads,
+            context_head_features=attention_features,
         )
 
         if use_attention:
@@ -683,13 +767,21 @@ def __init__(
             out_channels=channels,
             num_groups=num_groups,
             time_context_features=time_context_features,
+            context_features=context_features,
+            context_heads=attention_heads,
+            context_head_features=attention_features,
         )
 
-    def forward(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
-        x = self.pre_block(x, t)
+    def forward(
+        self,
+        x: Tensor,
+        t: Optional[Tensor] = None,
+        tokens: Optional[Tensor] = None,
+    ) -> Tensor:
+        x = self.pre_block(x, t, context=tokens)
         if self.use_attention:
             x = self.attention(x)
-        x = self.post_block(x, t)
+        x = self.post_block(x, t, context=tokens)
         return x
 
 
@@ -754,6 +846,7 @@ def __init__(
         use_attention_bottleneck: bool,
         out_channels: Optional[int] = None,
         context_channels: Optional[Sequence[int]] = None,
+        context_features: Optional[int] = None,
         kernel_sizes_out: Optional[Sequence[int]] = None,
     ):
         super().__init__()
@@ -808,6 +901,7 @@ def __init__(
                     out_channels=channels * multipliers[i + 1],
                     time_context_features=time_context_features,
                     context_channels=context_channels[i + 1],
+                    context_features=context_features,
                     num_layers=num_blocks[i],
                     factor=factors[i],
                     kernel_multiplier=kernel_multiplier_downsample,
@@ -826,6 +920,7 @@ def __init__(
         self.bottleneck = BottleneckBlock1d(
             channels=channels * multipliers[-1],
             time_context_features=time_context_features,
+            context_features=context_features,
             num_groups=resnet_groups,
             use_attention=use_attention_bottleneck,
             attention_heads=attention_heads,
@@ -838,6 +933,7 @@ def __init__(
                     in_channels=channels * multipliers[i + 1],
                     out_channels=channels * multipliers[i],
                     time_context_features=time_context_features,
+                    context_features=context_features,
                     num_layers=num_blocks[i] + (1 if attentions[i] else 0),
                     factor=factors[i],
                     use_nearest=use_nearest_upsample,
@@ -902,23 +998,25 @@ def forward(
         t: Tensor,
         *,
         context: Optional[Sequence[Tensor]] = None,
+        tokens: Optional[Tensor] = None,
     ):
         c = self.get_context(context)
         x = torch.cat([x, c], dim=1) if exists(c) else x
+
         x = self.to_in(x)
         t = self.to_time(t)
         skips_list = []
 
         for i, downsample in enumerate(self.downsamples):
-            c = self.get_context(context, layer=i + 1)
-            x, skips = downsample(x, t, c)
+            channels = self.get_context(context, layer=i + 1)
+            x, skips = downsample(x, t, channels=channels, tokens=tokens)
             skips_list += [skips]
 
-        x = self.bottleneck(x, t)
+        x = self.bottleneck(x, t, tokens=tokens)
 
         for i, upsample in enumerate(self.upsamples):
             skips = skips_list.pop()
-            x = upsample(x, skips, t)
+            x = upsample(x, skips, t, tokens=tokens)
 
         x = self.to_out(x)  # t?
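
Taken together, the commit threads one `tokens` tensor from the top-level forward down into every resnet block's cross-attention. A hypothetical call for a model built with `context_features=512` (the model construction and the `unet` name are assumptions; only the `tokens` keyword comes from this diff):

import torch

x = torch.randn(1, 2, 2**15)      # (batch, audio channels, samples)
t = torch.rand(1)                 # diffusion time
tokens = torch.randn(1, 77, 512)  # (batch, num_tokens, context_features)

# y = unet(x, t, tokens=tokens)   # `unet` is a hypothetical model instance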

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.30",
+    version="0.0.31",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
