Commit 35c1ba5

feat: add autoencoder1d
1 parent 50ecc30 commit 35c1ba5

2 files changed: +175 −35 lines changed


audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -9,4 +9,4 @@
     SpanBySpanComposer,
 )
 from .model import AudioDiffusionModel, Model1d
-from .modules import UNet1d
+from .modules import AutoEncoder1d, UNet1d
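
This export makes the new class available from the package root, so downstream code can import it directly (a minimal check, assuming the package is installed):

    from audio_diffusion_pytorch import AutoEncoder1d, UNet1d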

audio_diffusion_pytorch/modules.py

Lines changed: 174 additions & 34 deletions
@@ -1,5 +1,5 @@
 from math import log, pi
-from typing import List, Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -26,10 +26,7 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
 
 
 def Downsample1d(
-    in_channels: int,
-    out_channels: int,
-    factor: int,
-    kernel_multiplier: int,
+    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
 ) -> nn.Module:
     assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
 
@@ -464,7 +461,7 @@ def TimePositionalEmbedding(
 
 
 """
-UNet Components
+Encoder/Decoder Components
 """
 
 
@@ -475,23 +472,31 @@ def __init__(
         out_channels: int,
         *,
         factor: int,
-        kernel_multiplier: int,
-        time_context_features: int,
         num_groups: int,
         num_layers: int,
-        use_pre_downsample: bool,
-        use_attention: bool,
+        kernel_multiplier: int = 2,
+        use_pre_downsample: bool = True,
+        use_skip: bool = False,
+        use_attention: bool = False,
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
         attention_multiplier: Optional[int] = None,
+        time_context_features: Optional[int] = None,
     ):
         super().__init__()
-
         self.use_pre_downsample = use_pre_downsample
+        self.use_skip = use_skip
         self.use_attention = use_attention
 
         channels = out_channels if use_pre_downsample else in_channels
 
+        self.downsample = Downsample1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            factor=factor,
+            kernel_multiplier=kernel_multiplier,
+        )
+
         self.blocks = nn.ModuleList(
             [
                 ResnetBlock1d(
@@ -517,51 +522,47 @@ def __init__(
                 multiplier=attention_multiplier,
             )
 
-        self.downsample = Downsample1d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            factor=factor,
-            kernel_multiplier=kernel_multiplier,
-        )
-
-    def forward(self, x: Tensor, t: Tensor) -> Tuple[Tensor, List[Tensor]]:
+    def forward(
+        self, x: Tensor, t: Optional[Tensor] = None
+    ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
 
         if self.use_pre_downsample:
             x = self.downsample(x)
 
         skips = []
         for block in self.blocks:
             x = block(x, t)
-            skips += [x]
+            skips += [x] if self.use_skip else []
 
         if self.use_attention:
             x = self.transformer(x)
-            skips += [x]
+            skips += [x] if self.use_skip else []
 
         if not self.use_pre_downsample:
             x = self.downsample(x)
 
-        return x, skips
+        return (x, skips) if self.use_skip else x
 
 
 class UpsampleBlock1d(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        skip_channels: int,
         out_channels: int,
         *,
         factor: int,
-        use_nearest: bool,
         num_layers: int,
-        time_context_features: int,
         num_groups: int,
-        use_pre_upsample: bool,
-        use_skip_scale: bool,
-        use_attention: bool,
+        use_nearest: bool = False,
+        use_pre_upsample: bool = False,
+        use_skip: bool = False,
+        skip_channels: int = 0,
+        use_skip_scale: bool = False,
+        use_attention: bool = False,
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
         attention_multiplier: Optional[int] = None,
+        time_context_features: Optional[int] = None,
     ):
         super().__init__()
 
@@ -573,6 +574,7 @@ def __init__(
 
         self.use_pre_upsample = use_pre_upsample
         self.use_attention = use_attention
+        self.use_skip = use_skip
         self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0
 
         channels = out_channels if use_pre_upsample else in_channels
@@ -612,13 +614,18 @@ def __init__(
     def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
         return torch.cat([x, skip * self.skip_scale], dim=1)
 
-    def forward(self, x: Tensor, skips: List[Tensor], t: Tensor) -> Tensor:
+    def forward(
+        self,
+        x: Tensor,
+        skips: Optional[List[Tensor]] = None,
+        t: Optional[Tensor] = None,
+    ) -> Tensor:
 
         if self.use_pre_upsample:
             x = self.upsample(x)
 
         for block in self.blocks:
-            x = self.add_skip(x, skip=skips.pop())
+            x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
             x = block(x, t)
 
         if self.use_attention:
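
Together these hunks make the down/upsample blocks dual-mode: with use_skip=True (which UNet1d now passes, see below) they keep the old behavior of collecting and consuming skip tensors, while with the new defaults they act as plain encoder/decoder stages that map a single Tensor to a single Tensor, which is what AutoEncoder1d at the end of this file relies on. A minimal sketch of the two call styles; the shapes are illustrative, and it assumes, as the autoencoder path in this commit implies, that the underlying resnet blocks accept t=None:

    import torch
    from audio_diffusion_pytorch.modules import DownsampleBlock1d

    x = torch.randn(1, 32, 1024)  # [batch, channels, length]

    # Encoder mode (defaults): plain Tensor in, plain Tensor out.
    encoder_stage = DownsampleBlock1d(
        in_channels=32, out_channels=64, factor=2, num_groups=8, num_layers=2
    )
    y = encoder_stage(x)  # expected shape: [1, 64, 512]

    # UNet mode: use_skip=True restores the (output, skips) return value.
    unet_stage = DownsampleBlock1d(
        in_channels=32, out_channels=64, factor=2, num_groups=8, num_layers=2,
        use_skip=True,
    )
    y, skips = unet_stage(x)  # skips holds one tensor per resnet block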
@@ -635,11 +642,11 @@ def __init__(
         self,
         channels: int,
         *,
-        time_context_features: int,
         num_groups: int,
-        use_attention: bool,
+        use_attention: bool = False,
         attention_heads: Optional[int] = None,
         attention_features: Optional[int] = None,
+        time_context_features: Optional[int] = None,
     ):
         super().__init__()
 
@@ -675,14 +682,17 @@ def __init__(
             time_context_features=time_context_features,
         )
 
-    def forward(self, x: Tensor, t: Tensor) -> Tensor:
+    def forward(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
         x = self.pre_block(x, t)
         if self.use_attention:
             x = self.attention(x)
         x = self.post_block(x, t)
         return x
 
 
+""" UNets """
+
+
 class UNet1d(nn.Module):
     def __init__(
         self,
@@ -751,6 +761,7 @@ def __init__(
                     kernel_multiplier=kernel_multiplier_downsample,
                     num_groups=resnet_groups,
                     use_pre_downsample=True,
+                    use_skip=True,
                     use_attention=attentions[i],
                     attention_heads=attention_heads,
                     attention_features=attention_features,
@@ -773,7 +784,6 @@ def __init__(
             [
                 UpsampleBlock1d(
                     in_channels=channels * multipliers[i + 1],
-                    skip_channels=channels * multipliers[i + 1],
                     out_channels=channels * multipliers[i],
                     time_context_features=time_context_features,
                     num_layers=num_blocks[i] + (1 if attentions[i] else 0),
@@ -782,6 +792,8 @@ def __init__(
                     num_groups=resnet_groups,
                     use_skip_scale=use_skip_scale,
                     use_pre_upsample=False,
+                    use_skip=True,
+                    skip_channels=channels * multipliers[i + 1],
                     use_attention=attentions[i],
                     attention_heads=attention_heads,
                     attention_features=attention_features,
@@ -825,3 +837,131 @@ def forward(self, x: Tensor, t: Tensor):
         x = self.to_out(x)  # t?
 
         return x
+
+
+""" Autoencoders """
+
+
+def gaussian_sample(mean: Tensor, logvar: Tensor) -> Tensor:
+    std = torch.exp(0.5 * logvar)
+    sample = mean + std * torch.randn_like(std)
+    return sample
+
+
+class AutoEncoder1d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        bottleneck_channels: int,
+        channels: int,
+        patch_size: int,
+        multipliers: Sequence[int],
+        factors: Sequence[int],
+        num_blocks: Sequence[int],
+        resnet_groups: int,
+        loss_kl_weight: float,
+        kernel_multiplier_downsample: int = 2,
+    ):
+        super().__init__()
+
+        num_layers = len(multipliers) - 1
+        self.num_layers = num_layers
+        self.loss_kl_weight = loss_kl_weight
+
+        assert len(factors) == num_layers and len(num_blocks) == num_layers
+
+        self.to_in = nn.Sequential(
+            Rearrange("b c (l p) -> b (c p) l", p=patch_size),
+            Conv1d(
+                in_channels=in_channels * patch_size,
+                out_channels=channels,
+                kernel_size=1,
+            ),
+        )
+
+        self.downsamples = nn.ModuleList(
+            [
+                DownsampleBlock1d(
+                    in_channels=channels * multipliers[i],
+                    out_channels=channels * multipliers[i + 1],
+                    num_layers=num_blocks[i],
+                    factor=factors[i],
+                    kernel_multiplier=kernel_multiplier_downsample,
+                    num_groups=resnet_groups,
+                )
+                for i in range(num_layers)
+            ]
+        )
+
+        self.pre_bottleneck = Conv1d(
+            in_channels=channels * multipliers[-1],
+            out_channels=bottleneck_channels * 2,
+            kernel_size=1,
+        )
+
+        self.post_bottleneck = Conv1d(
+            in_channels=bottleneck_channels,
+            out_channels=channels * multipliers[-1],
+            kernel_size=1,
+        )
+
+        self.upsamples = nn.ModuleList(
+            [
+                UpsampleBlock1d(
+                    in_channels=channels * multipliers[i + 1],
+                    out_channels=channels * multipliers[i],
+                    num_layers=num_blocks[i],
+                    factor=factors[i],
+                    num_groups=resnet_groups,
+                )
+                for i in reversed(range(num_layers))
+            ]
+        )
+
+        self.to_out = nn.Sequential(
+            Conv1d(
+                in_channels=channels,
+                out_channels=in_channels * patch_size,
+                kernel_size=1,
+            ),
+            Rearrange("b (c p) l -> b c (l p)", p=patch_size),
+        )
+
+    def encode(
+        self, x: Tensor, *, with_kl_loss: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+        x = self.to_in(x)
+
+        for downsample in self.downsamples:
+            x = downsample(x)
+
+        mean_and_var = self.pre_bottleneck(x)
+
+        # Chunk channels to mean and log variance and sample in VAE style
+        mean, logvar = torch.chunk(mean_and_var, chunks=2, dim=1)
+        logvar = torch.clamp(logvar, -30.0, 20.0)
+        bottleneck = gaussian_sample(mean, logvar)
+
+        if with_kl_loss:
+            # KL-Loss: diagonal gaussian with mean 0, variance 1, logvar 0
+            b = x.shape[0]
+            var = torch.exp(logvar)
+            loss = 0.5 * torch.sum(torch.pow(mean, 2) + (var - 1.0) - logvar) / b
+            return bottleneck, loss
+
+        return bottleneck
+
+    def decode(self, x: Tensor) -> Tensor:
+        x = self.post_bottleneck(x)
+
+        for upsample in self.upsamples:
+            x = upsample(x)
+
+        return self.to_out(x)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Returns autoencoding loss"""
+        z, kl_loss = self.encode(x, with_kl_loss=True)
+        y = self.decode(z)
+        loss = F.mse_loss(x, y) + kl_loss * self.loss_kl_weight
+        return loss
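
The bottleneck follows the standard VAE recipe: gaussian_sample implements the reparameterization trick z = mean + std * eps with eps ~ N(0, I), and the KL term in encode is the closed-form divergence of N(mean, var) from the standard normal, 0.5 * sum(mean^2 + var - 1 - logvar), averaged over the batch. A usage sketch with a hypothetical configuration (all hyperparameter values below are illustrative, not defaults from this commit):

    import torch
    from audio_diffusion_pytorch import AutoEncoder1d

    autoencoder = AutoEncoder1d(
        in_channels=1,
        bottleneck_channels=32,
        channels=64,
        patch_size=4,
        multipliers=[1, 2, 4],  # defines len(multipliers) - 1 = 2 layers
        factors=[2, 2],         # one downsampling factor per layer
        num_blocks=[2, 2],      # resnet blocks per layer
        resnet_groups=8,
        loss_kl_weight=1e-4,
    )

    # Length should be divisible by patch_size * prod(factors) = 16 here.
    x = torch.randn(2, 1, 4096)  # [batch, in_channels, length]

    loss = autoencoder(x)        # MSE reconstruction + weighted KL
    loss.backward()

    # encode/decode are also usable on their own:
    z = autoencoder.encode(x)    # sampled latent, bottleneck_channels wide
    y = autoencoder.decode(z)    # reconstruction with the shape of x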
