
Commit ed524bd

feat: update downsample block with extract, add 1d encoder
1 parent 8df7d67

File tree

2 files changed: +95 −5 lines


audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,4 +12,4 @@
     SpanBySpanComposer,
 )
 from .model import AudioAutoEncoderModel, AudioDiffusionModel, Model1d
-from .modules import AutoEncoder1d, UNet1d
+from .modules import AutoEncoder1d, Encoder1d, UNet1d
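
The package root now re-exports the new encoder alongside the existing modules. As a quick sanity check (assuming the package is installed at this commit), the following import should work:

    from audio_diffusion_pytorch import AutoEncoder1d, Encoder1d, UNet1d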

audio_diffusion_pytorch/modules.py

Lines changed: 94 additions & 4 deletions
@@ -480,6 +480,7 @@ def __init__(
         kernel_multiplier: int = 2,
         use_pre_downsample: bool = True,
         use_skip: bool = False,
+        extract_channels: int = 0,
         use_attention: bool = False,
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
@@ -490,6 +491,7 @@ def __init__(
         self.use_pre_downsample = use_pre_downsample
         self.use_skip = use_skip
         self.use_attention = use_attention
+        self.use_extract = extract_channels > 0

         channels = out_channels if use_pre_downsample else in_channels

@@ -525,6 +527,14 @@ def __init__(
                 multiplier=attention_multiplier,
             )

+        if self.use_extract:
+            num_extract_groups = min(num_groups, extract_channels)
+            self.to_extracted = ResnetBlock1d(
+                in_channels=out_channels,
+                out_channels=extract_channels,
+                num_groups=num_extract_groups,
+            )
+
     def forward(
         self, x: Tensor, t: Optional[Tensor] = None
     ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
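
The min(num_groups, extract_channels) clamp presumably keeps the extract head's group normalization valid when extract_channels is smaller than the block's usual group count, since GroupNorm requires num_channels to be at least (and divisible by) num_groups. A standalone illustration of that constraint, not taken from this diff:

    import torch
    import torch.nn as nn

    num_groups, extract_channels = 8, 2                      # fewer channels than groups
    num_extract_groups = min(num_groups, extract_channels)   # -> 2

    # nn.GroupNorm(8, 2) would raise; the clamped group count is valid:
    norm = nn.GroupNorm(num_groups=num_extract_groups, num_channels=extract_channels)
    print(norm(torch.randn(1, extract_channels, 16)).shape)  # torch.Size([1, 2, 16])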
@@ -544,6 +554,10 @@ def forward(
         if not self.use_pre_downsample:
             x = self.downsample(x)

+        if self.use_extract:
+            extracted = self.to_extracted(x)
+            return x, extracted
+
         return (x, skips) if self.use_skip else x

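With extract_channels > 0, the block's forward now returns a (downsampled, extracted) pair instead of a single tensor (or the (x, skips) pair). A rough usage sketch mirroring how Encoder1d calls the block further down; all numeric values here are illustrative, not library defaults:

    import torch
    from audio_diffusion_pytorch.modules import DownsampleBlock1d

    block = DownsampleBlock1d(
        in_channels=32,
        out_channels=64,
        factor=2,
        kernel_multiplier=2,
        num_groups=8,
        num_layers=2,
        extract_channels=16,  # > 0 enables the extract head
    )

    x = torch.randn(1, 32, 1024)
    x_down, extracted = block(x)
    # expected: x_down roughly (1, 64, 512), extracted roughly (1, 16, 512)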

@@ -693,7 +707,9 @@ def forward(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
         return x


-""" UNets """
+"""
+UNet
+"""


 class UNet1d(nn.Module):
@@ -738,7 +754,7 @@ def __init__(
         if use_context:
             has_context = [c > 0 for c in context_channels]
             self.has_context = has_context
-            self.context_id = [sum(has_context[:i]) for i in range(len(has_context))]
+            self.context_ids = [sum(has_context[:i]) for i in range(len(has_context))]

         assert (
             len(factors) == num_layers
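
The rename to context_ids is cosmetic; the list itself maps each layer to its position among the layers that actually carry context, skipping zero-channel entries. A small worked example with hypothetical channel counts:

    context_channels = [0, 4, 0, 8]                   # hypothetical per-layer context channels
    has_context = [c > 0 for c in context_channels]   # [False, True, False, True]
    context_ids = [sum(has_context[:i]) for i in range(len(has_context))]
    # -> [0, 0, 1, 1]: layer 1 reads context_list[0], layer 3 reads context_list[1]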
@@ -846,7 +862,7 @@ def add_context(
             return x
         assert exists(context_list), "Missing context"
         # Get context index (skipping zero channel contexts)
-        context_id = self.context_id[layer]
+        context_id = self.context_ids[layer]
         # Get context
         context = context_list[context_id]
         message = f"Missing context for layer {layer} at index {context_id}"
@@ -891,7 +907,81 @@ def forward(
         return x


-""" Autoencoders """
+"""
+Encoder
+"""
+
+
+class Encoder1d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        channels: int,
+        patch_size: int,
+        resnet_groups: int,
+        kernel_multiplier_downsample: int,
+        kernel_sizes_init: Sequence[int],
+        multipliers: Sequence[int],
+        factors: Sequence[int],
+        num_blocks: Sequence[int],
+        extract_channels: Sequence[int],
+    ):
+        super().__init__()
+
+        num_layers = len(extract_channels)
+        self.num_layers = num_layers
+
+        use_extract = [channels > 0 for channels in extract_channels]
+        self.use_extract = use_extract
+
+        assert (
+            len(multipliers) >= num_layers + 1
+            and len(factors) >= num_layers
+            and len(num_blocks) >= num_layers
+        )
+
+        self.to_in = nn.Sequential(
+            Rearrange("b c (l p) -> b (c p) l", p=patch_size),
+            CrossEmbed1d(
+                in_channels=in_channels * patch_size,
+                out_channels=channels,
+                kernel_sizes=kernel_sizes_init,
+                stride=1,
+            ),
+        )
+
+        self.downsamples = nn.ModuleList(
+            [
+                DownsampleBlock1d(
+                    in_channels=channels * multipliers[i],
+                    out_channels=channels * multipliers[i + 1],
+                    factor=factors[i],
+                    kernel_multiplier=kernel_multiplier_downsample,
+                    num_groups=resnet_groups,
+                    num_layers=num_blocks[i],
+                    extract_channels=extract_channels[i],
+                )
+                for i in range(num_layers)
+            ]
+        )
+
+    def forward(self, x: Tensor) -> List[Tensor]:
+        x = self.to_in(x)
+        channels_list = []
+
+        for downsample, use_extract in zip(self.downsamples, self.use_extract):
+            if use_extract:
+                x, channels = downsample(x)
+                channels_list += [channels]
+            else:
+                x = downsample(x)
+
+        return channels_list
+
+
+"""
+Autoencoder
+"""


 def gaussian_sample(mean: Tensor, logvar: Tensor) -> Tensor:
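
Encoder1d chains DownsampleBlock1d layers with per-layer extract_channels and returns only the extracted feature maps, one per extracting layer. A rough instantiation sketch; every hyperparameter value below is made up for illustration and only needs to satisfy the length asserts in the diff above:

    import torch
    from audio_diffusion_pytorch import Encoder1d

    encoder = Encoder1d(
        in_channels=1,
        channels=32,
        patch_size=4,
        resnet_groups=8,
        kernel_multiplier_downsample=2,
        kernel_sizes_init=[1, 3, 7],
        multipliers=[1, 2, 4, 4],       # needs num_layers + 1 entries
        factors=[2, 2, 2],
        num_blocks=[2, 2, 2],
        extract_channels=[16, 32, 64],  # num_layers = 3, all > 0
    )

    x = torch.randn(1, 1, 2**15)   # (batch, in_channels, length)
    features = encoder(x)          # list of extracted tensors, coarser at each layer
    for f in features:
        print(f.shape)             # roughly (1, 16, 4096), (1, 32, 2048), (1, 64, 1024)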
