Skip to content

Commit 8df7d67

Browse files
feat: option to add context channels at every layer
1 parent 672876b commit 8df7d67

File tree

1 file changed

+56
-8
lines changed

1 file changed

+56
-8
lines changed

audio_diffusion_pytorch/modules.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def Downsample1d(
3636
kernel_size=factor * kernel_multiplier + 1,
3737
stride=factor,
3838
padding=factor * (kernel_multiplier // 2),
39-
groups=in_channels // 4,
4039
)
4140

4241

@@ -718,14 +717,29 @@ def __init__(
718717
use_skip_scale: bool,
719718
use_attention_bottleneck: bool,
720719
out_channels: Optional[int] = None,
720+
context_channels: Optional[Sequence[int]] = None,
721721
):
722722
super().__init__()
723723

724724
out_channels = default(out_channels, in_channels)
725+
context_channels = list(default(context_channels, []))
725726
time_context_features = channels * 4
727+
726728
num_layers = len(multipliers) - 1
727729
self.num_layers = num_layers
728730

731+
use_context = len(context_channels) > 0
732+
self.use_context = use_context
733+
734+
context_pad_length = num_layers + 1 - len(context_channels)
735+
context_channels = context_channels + [0] * context_pad_length
736+
self.context_channels = context_channels
737+
738+
if use_context:
739+
has_context = [c > 0 for c in context_channels]
740+
self.has_context = has_context
741+
self.context_id = [sum(has_context[:i]) for i in range(len(has_context))]
742+
729743
assert (
730744
len(factors) == num_layers
731745
and len(attentions) == num_layers
@@ -735,7 +749,7 @@ def __init__(
735749
self.to_in = nn.Sequential(
736750
Rearrange("b c (l p) -> b (c p) l", p=patch_size),
737751
CrossEmbed1d(
738-
in_channels=in_channels * patch_size,
752+
in_channels=(in_channels + context_channels[0]) * patch_size,
739753
out_channels=channels,
740754
kernel_sizes=kernel_sizes_init,
741755
stride=1,
@@ -757,7 +771,7 @@ def __init__(
757771
self.downsamples = nn.ModuleList(
758772
[
759773
DownsampleBlock1d(
760-
in_channels=channels * multipliers[i],
774+
in_channels=channels * multipliers[i] + context_channels[i + 1],
761775
out_channels=channels * multipliers[i + 1],
762776
time_context_features=time_context_features,
763777
num_layers=num_blocks[i],
@@ -784,10 +798,11 @@ def __init__(
784798
attention_features=attention_features,
785799
)
786800

801+
context_channels = context_channels + [0] # Upsample skips first context
787802
self.upsamples = nn.ModuleList(
788803
[
789804
UpsampleBlock1d(
790-
in_channels=channels * multipliers[i + 1],
805+
in_channels=channels * multipliers[i + 1] + context_channels[i + 2],
791806
out_channels=channels * multipliers[i],
792807
time_context_features=time_context_features,
793808
num_layers=num_blocks[i] + (1 if attentions[i] else 0),
@@ -809,7 +824,7 @@ def __init__(
809824

810825
self.to_out = nn.Sequential(
811826
ResnetBlock1d(
812-
in_channels=channels,
827+
in_channels=channels + context_channels[1],
813828
out_channels=channels,
814829
num_groups=resnet_groups,
815830
time_context_features=time_context_features,
@@ -822,21 +837,54 @@ def __init__(
822837
Rearrange("b (c p) l -> b c (l p)", p=patch_size),
823838
)
824839

825-
def forward(self, x: Tensor, t: Tensor):
840+
def add_context(
    self, x: Tensor, context_list: Optional[Sequence[Tensor]] = None, layer: int = 0
) -> Tensor:
    """Concatenate the configured context tensor for `layer` onto `x`.

    Returns `x` unchanged when no context is configured for this layer;
    otherwise validates the context's channel count and length against the
    model configuration and `x`, then stacks it along the channel dimension.
    """
    if not (self.use_context and self.has_context[layer]):
        # Nothing configured for this layer: pass the input through untouched.
        return x
    assert exists(context_list), "Missing context"
    # Contexts with zero channels are skipped in context_list, so translate
    # the layer number into the compacted list index first.
    context_id = self.context_id[layer]
    context = context_list[context_id]
    assert exists(context), f"Missing context for layer {layer} at index {context_id}"
    # Channel count must match what the convolutions were built for.
    channels = self.context_channels[layer]
    assert (
        context.shape[1] == channels
    ), f"Expected context with {channels} channels at index {context_id}"
    # Temporal length must match x so the tensors can be concatenated.
    length = x.shape[2]
    assert (
        context.shape[2] == length
    ), f"Expected context length of {length} at index {context_id}"
    # Stack context onto x along the channel dimension.
    return torch.cat([x, context], dim=1)
826864

865+
def forward(
866+
self,
867+
x: Tensor,
868+
t: Tensor,
869+
*,
870+
context: Optional[Sequence[Tensor]] = None,
871+
):
872+
x = self.add_context(x, context)
827873
x = self.to_in(x)
828874
t = self.to_time(t)
829875
skips_list = []
830876

831-
for downsample in self.downsamples:
877+
for i, downsample in enumerate(self.downsamples):
878+
x = self.add_context(x, context, layer=i + 1)
832879
x, skips = downsample(x, t)
833880
skips_list += [skips]
834881

835882
x = self.bottleneck(x, t)
836883

837-
for upsample in self.upsamples:
884+
for i, upsample in enumerate(self.upsamples):
838885
skips = skips_list.pop()
839886
x = upsample(x, skips, t)
887+
x = self.add_context(x, context, layer=len(self.upsamples) - i)
840888

841889
x = self.to_out(x) # t?
842890

0 commit comments

Comments (0)