Skip to content

Commit 531dcee

Browse files
fix: forward context_features, big bug that prevented conditioning from working
1 parent 57d199d commit 531dcee

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

audio_diffusion_pytorch/modules.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,10 @@ def __init__(
414414

415415
if self.use_cross_attention:
416416
self.cross_attention = Attention(
417-
features=features, num_heads=num_heads, head_features=head_features
417+
features=features,
418+
num_heads=num_heads,
419+
head_features=head_features,
420+
context_features=context_features,
418421
)
419422

420423
self.feed_forward = FeedForward(features=features, multiplier=multiplier)
@@ -592,6 +595,7 @@ def __init__(
592595
num_heads=attention_heads,
593596
head_features=attention_features,
594597
multiplier=attention_multiplier,
598+
context_features=context_embedding_features,
595599
)
596600

597601
if self.use_extract:
@@ -692,6 +696,7 @@ def __init__(
692696
num_heads=attention_heads,
693697
head_features=attention_features,
694698
multiplier=attention_multiplier,
699+
context_features=context_embedding_features,
695700
)
696701

697702
self.upsample = Upsample1d(
@@ -776,6 +781,7 @@ def __init__(
776781
num_heads=attention_heads,
777782
head_features=attention_features,
778783
multiplier=attention_multiplier,
784+
context_features=context_embedding_features,
779785
)
780786

781787
self.post_block = ResnetBlock1d(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.65",
6+
version="0.0.66",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)