Skip to content

Commit 4fdcb35

Browse files
feat: context after downsample
1 parent f46557b commit 4fdcb35

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

audio_diffusion_pytorch/modules.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ def __init__(
473473
use_pre_downsample: bool = True,
474474
use_skip: bool = False,
475475
extract_channels: int = 0,
476+
context_channels: int = 0,
476477
use_attention: bool = False,
477478
attention_heads: Optional[int] = None,
478479
attention_features: Optional[int] = None,
@@ -484,6 +485,7 @@ def __init__(
484485
self.use_skip = use_skip
485486
self.use_attention = use_attention
486487
self.use_extract = extract_channels > 0
488+
self.use_context = context_channels > 0
487489

488490
channels = out_channels if use_pre_downsample else in_channels
489491

@@ -497,12 +499,12 @@ def __init__(
497499
self.blocks = nn.ModuleList(
498500
[
499501
ResnetBlock1d(
500-
in_channels=channels,
502+
in_channels=channels + (context_channels if i == 0 else 0),
501503
out_channels=channels,
502504
num_groups=num_groups,
503505
time_context_features=time_context_features,
504506
)
505-
for _ in range(num_layers)
507+
for i in range(num_layers)
506508
]
507509
)
508510

@@ -528,12 +530,15 @@ def __init__(
528530
)
529531

530532
def forward(
531-
self, x: Tensor, t: Optional[Tensor] = None
533+
self, x: Tensor, t: Optional[Tensor] = None, context: Optional[Tensor] = None
532534
) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
533535

534536
if self.use_pre_downsample:
535537
x = self.downsample(x)
536538

539+
if self.use_context and exists(context):
540+
x = torch.cat([x, context], dim=1)
541+
537542
skips = []
538543
for block in self.blocks:
539544
x = block(x, t)
@@ -774,9 +779,10 @@ def __init__(
774779
self.downsamples = nn.ModuleList(
775780
[
776781
DownsampleBlock1d(
777-
in_channels=channels * multipliers[i] + context_channels[i + 1],
782+
in_channels=channels * multipliers[i],
778783
out_channels=channels * multipliers[i + 1],
779784
time_context_features=time_context_features,
785+
context_channels=context_channels[i + 1],
780786
num_layers=num_blocks[i],
781787
factor=factors[i],
782788
kernel_multiplier=kernel_multiplier_downsample,
@@ -839,13 +845,13 @@ def __init__(
839845
Rearrange("b (c p) l -> b c (l p)", p=patch_size),
840846
)
841847

842-
def add_context(
843-
self, x: Tensor, context_list: Optional[Sequence[Tensor]] = None, layer: int = 0
844-
) -> Tensor:
848+
def get_context(
849+
self, context_list: Optional[Sequence[Tensor]] = None, layer: int = 0
850+
) -> Optional[Tensor]:
845851
"""Returns the layer's context, if present, after checking its channel count"""
846852
use_context = self.use_context and self.has_context[layer]
847853
if not use_context:
848-
return x
854+
return None
849855
assert exists(context_list), "Missing context"
850856
# Get context index (skipping zero channel contexts)
851857
context_id = self.context_ids[layer]
@@ -857,12 +863,7 @@ def add_context(
857863
channels = self.context_channels[layer]
858864
message = f"Expected context with {channels} channels at index {context_id}"
859865
assert context.shape[1] == channels, message
860-
# Check length
861-
length = x.shape[2]
862-
message = f"Expected context length of {length} at index {context_id}"
863-
assert context.shape[2] == length, message
864-
# Concatenate context
865-
return torch.cat([x, context], dim=1)
866+
return context
866867

867868
def forward(
868869
self,
@@ -871,14 +872,15 @@ def forward(
871872
*,
872873
context: Optional[Sequence[Tensor]] = None,
873874
):
874-
x = self.add_context(x, context)
875+
c = self.get_context(context)
876+
x = torch.cat([x, c], dim=1) if exists(c) else x
875877
x = self.to_in(x)
876878
t = self.to_time(t)
877879
skips_list = []
878880

879881
for i, downsample in enumerate(self.downsamples):
880-
x = self.add_context(x, context, layer=i + 1)
881-
x, skips = downsample(x, t)
882+
c = self.get_context(context, layer=i + 1)
883+
x, skips = downsample(x, t, c)
882884
skips_list += [skips]
883885

884886
x = self.bottleneck(x, t)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.22",
6+
version="0.0.23",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)