Skip to content

Commit 508de5b

Browse files
feat: update diffusion autoencoder to multiencoder conditioning
1 parent 3c710ed commit 508de5b

File tree

5 files changed

+121
-65
lines changed

5 files changed

+121
-65
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ repos:
3030
args: [
3131
'--per-file-ignores=__init__.py:F401',
3232
'--max-line-length=88',
33+
'--ignore=E203,W503'
3334
]
3435

3536
# Checks types

audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,4 +21,4 @@
2121
DiffusionUpsampler1d,
2222
Model1d,
2323
)
24-
from .modules import Encoder1d, UNet1d, UNetConditional1d
24+
from .modules import MultiEncoder1d, UNet1d, UNetConditional1d

audio_diffusion_pytorch/model.py

Lines changed: 27 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,8 @@
1414
Sampler,
1515
Schedule,
1616
)
17-
from .modules import Encoder1d, ResnetBlock1d, UNet1d, UNetConditional1d
18-
from .utils import default, exists, prod, to_list
17+
from .modules import MultiEncoder1d, UNet1d, UNetConditional1d
18+
from .utils import default, exists, to_list
1919

2020
"""
2121
Diffusion Classes (generic for 1d data)
@@ -117,51 +117,45 @@ def __init__(
117117
kernel_multiplier_downsample: int,
118118
encoder_depth: int,
119119
encoder_channels: int,
120-
context_channels: int,
121120
bottleneck: Optional[Bottleneck] = None,
122121
encoder_num_blocks: Optional[Sequence[int]] = None,
123122
**kwargs
124123
):
125-
super().__init__(
124+
self.in_channels = in_channels
125+
encoder_num_blocks = default(encoder_num_blocks, num_blocks)
126+
assert_message = "The number of encoder_num_blocks must match encoder_depth"
127+
assert len(encoder_num_blocks) >= encoder_depth, assert_message
128+
129+
multiencoder = MultiEncoder1d(
126130
in_channels=in_channels,
127131
channels=channels,
128132
patch_size=patch_size,
129-
kernel_sizes_init=kernel_sizes_init,
133+
num_layers=encoder_depth,
134+
latent_channels=encoder_channels,
130135
multipliers=multipliers,
131136
factors=factors,
132-
num_blocks=num_blocks,
133-
resnet_groups=resnet_groups,
137+
num_blocks=encoder_num_blocks,
138+
kernel_sizes_init=kernel_sizes_init,
134139
kernel_multiplier_downsample=kernel_multiplier_downsample,
135-
context_channels=[0] * encoder_depth + [context_channels],
136-
**kwargs,
140+
resnet_groups=resnet_groups,
137141
)
138142

139-
self.in_channels = in_channels
140-
self.encoder_factor = patch_size * prod(factors[0:encoder_depth])
141-
self.bottleneck = bottleneck
142-
143-
encoder_num_blocks = default(encoder_num_blocks, num_blocks)
144-
assert_message = "The number of encoder_num_blocks must match encoder_depth"
145-
assert len(encoder_num_blocks) >= encoder_depth, assert_message
146-
147-
self.encoder = Encoder1d(
143+
super().__init__(
148144
in_channels=in_channels,
149145
channels=channels,
150146
patch_size=patch_size,
151147
kernel_sizes_init=kernel_sizes_init,
152148
multipliers=multipliers,
153149
factors=factors,
154-
num_blocks=encoder_num_blocks,
150+
num_blocks=num_blocks,
155151
resnet_groups=resnet_groups,
156152
kernel_multiplier_downsample=kernel_multiplier_downsample,
157-
extract_channels=[0] * (encoder_depth - 1) + [encoder_channels],
153+
context_channels=multiencoder.channels_list,
154+
**kwargs,
158155
)
159156

160-
self.to_context_channels = ResnetBlock1d(
161-
in_channels=encoder_channels,
162-
out_channels=context_channels,
163-
num_groups=resnet_groups,
164-
)
157+
self.bottleneck = bottleneck
158+
self.multiencoder = multiencoder
165159

166160
def forward( # type: ignore
167161
self, x: Tensor, with_info: bool = False, **kwargs
@@ -171,28 +165,28 @@ def forward( # type: ignore
171165
else:
172166
latent = self.encode(x)
173167

174-
channels = self.to_context_channels(latent)
175-
loss = self.diffusion(x, channels_list=[channels], **kwargs)
168+
channels_list = self.multiencoder.decode(latent)
169+
loss = self.diffusion(x, channels_list=channels_list, **kwargs)
176170
return (loss, info) if with_info else loss
177171

178172
def encode(
179173
self, x: Tensor, with_info: bool = False
180174
) -> Union[Tensor, Tuple[Tensor, Any]]:
181-
x = self.encoder(x)[-1]
182-
latent = torch.tanh(x)
175+
latent = self.multiencoder.encode(x)
176+
latent = torch.tanh(latent)
183177
# Apply bottleneck if provided (e.g. quantization module)
184178
if exists(self.bottleneck):
185179
latent, info = self.bottleneck(latent)
186180
return (latent, info) if with_info else latent
187181
return latent
188182

189183
def decode(self, latent: Tensor, **kwargs) -> Tensor:
190-
b, length = latent.shape[0], latent.shape[2] * self.encoder_factor
184+
b, length = latent.shape[0], latent.shape[2] * self.multiencoder.factor
191185
# Compute noise by inferring shape from latent length
192186
noise = torch.randn(b, self.in_channels, length).to(latent)
193187
# Compute context from latent
194-
channels = self.to_context_channels(latent)
195-
default_kwargs = dict(channels_list=[channels])
188+
channels_list = self.multiencoder.decode(latent)
189+
default_kwargs = dict(channels_list=channels_list)
196190
# Decode by sampling while conditioning on latent channels
197191
return super().sample(noise, **{**default_kwargs, **kwargs}) # type: ignore
198192

@@ -257,10 +251,7 @@ def sample(self, *args, **kwargs):
257251
class AudioDiffusionAutoencoder(DiffusionAutoencoder1d):
258252
def __init__(self, *args, **kwargs):
259253
default_kwargs = dict(
260-
**get_default_model_kwargs(),
261-
encoder_depth=4,
262-
encoder_channels=32,
263-
context_channels=512,
254+
**get_default_model_kwargs(), encoder_depth=4, encoder_channels=64
264255
)
265256
super().__init__(*args, **{**default_kwargs, **kwargs}) # type: ignore
266257

audio_diffusion_pytorch/modules.py

Lines changed: 91 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from einops_exts.torch import EinopsToAndFrom
1010
from torch import Tensor, einsum
1111

12-
from .utils import default, exists
12+
from .utils import default, exists, prod
1313

1414
"""
1515
Convolutional Blocks
@@ -666,6 +666,7 @@ def __init__(
666666
use_skip: bool = False,
667667
skip_channels: int = 0,
668668
use_skip_scale: bool = False,
669+
extract_channels: int = 0,
669670
use_attention: bool = False,
670671
attention_heads: Optional[int] = None,
671672
attention_features: Optional[int] = None,
@@ -675,12 +676,7 @@ def __init__(
675676
):
676677
super().__init__()
677678

678-
assert (not use_attention) or (
679-
exists(attention_heads)
680-
and exists(attention_features)
681-
and exists(attention_multiplier)
682-
)
683-
679+
self.use_extract = extract_channels > 0
684680
self.use_pre_upsample = use_pre_upsample
685681
self.use_attention = use_attention
686682
self.use_skip = use_skip
@@ -723,6 +719,14 @@ def __init__(
723719
use_nearest=use_nearest,
724720
)
725721

722+
if self.use_extract:
723+
num_extract_groups = min(num_groups, extract_channels)
724+
self.to_extracted = ResnetBlock1d(
725+
in_channels=out_channels,
726+
out_channels=extract_channels,
727+
num_groups=num_extract_groups,
728+
)
729+
726730
def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
727731
return torch.cat([x, skip * self.skip_scale], dim=1)
728732

@@ -732,7 +736,7 @@ def forward(
732736
skips: Optional[List[Tensor]] = None,
733737
mapping: Optional[Tensor] = None,
734738
embedding: Optional[Tensor] = None,
735-
) -> Tensor:
739+
) -> Union[Tuple[Tensor, Tensor], Tensor]:
736740

737741
if self.use_pre_upsample:
738742
x = self.upsample(x)
@@ -747,6 +751,10 @@ def forward(
747751
if not self.use_pre_upsample:
748752
x = self.upsample(x)
749753

754+
if self.use_extract:
755+
extracted = self.to_extracted(x)
756+
return x, extracted
757+
750758
return x
751759

752760

@@ -1144,11 +1152,11 @@ def forward( # type: ignore
11441152

11451153

11461154
"""
1147-
Encoder
1155+
Encoders / Decoders
11481156
"""
11491157

11501158

1151-
class Encoder1d(nn.Module):
1159+
class MultiEncoder1d(nn.Module):
11521160
def __init__(
11531161
self,
11541162
in_channels: int,
@@ -1157,18 +1165,17 @@ def __init__(
11571165
resnet_groups: int,
11581166
kernel_multiplier_downsample: int,
11591167
kernel_sizes_init: Sequence[int],
1168+
num_layers: int,
1169+
latent_channels: int,
11601170
multipliers: Sequence[int],
11611171
factors: Sequence[int],
11621172
num_blocks: Sequence[int],
1163-
extract_channels: Sequence[int],
11641173
):
11651174
super().__init__()
1166-
1167-
num_layers = len(extract_channels)
1168-
self.num_layers = num_layers
1169-
1170-
use_extract = [channels > 0 for channels in extract_channels]
1171-
self.use_extract = use_extract
1175+
self.factor = patch_size * prod(factors[0:num_layers])
1176+
self.channels_list = self.get_channels_list(
1177+
in_channels, channels, multipliers, num_layers
1178+
)
11721179

11731180
assert (
11741181
len(multipliers) >= num_layers + 1
@@ -1195,21 +1202,78 @@ def __init__(
11951202
kernel_multiplier=kernel_multiplier_downsample,
11961203
num_groups=resnet_groups,
11971204
num_layers=num_blocks[i],
1198-
extract_channels=extract_channels[i],
11991205
)
12001206
for i in range(num_layers)
12011207
]
12021208
)
12031209

1204-
def forward(self, x: Tensor) -> List[Tensor]:
1205-
x = self.to_in(x)
1206-
channels_list = []
1210+
pre_latent_channels = channels * multipliers[num_layers]
1211+
1212+
self.to_latent = ResnetBlock1d(
1213+
in_channels=pre_latent_channels,
1214+
out_channels=latent_channels,
1215+
num_groups=resnet_groups,
1216+
)
1217+
1218+
self.from_latent = ResnetBlock1d(
1219+
in_channels=latent_channels,
1220+
out_channels=pre_latent_channels,
1221+
num_groups=resnet_groups,
1222+
)
12071223

1208-
for downsample, use_extract in zip(self.downsamples, self.use_extract):
1209-
if use_extract:
1210-
x, channels = downsample(x)
1211-
channels_list += [channels]
1212-
else:
1213-
x = downsample(x)
1224+
self.upsamples = nn.ModuleList(
1225+
[
1226+
UpsampleBlock1d(
1227+
in_channels=channels * multipliers[i + 1],
1228+
out_channels=channels * multipliers[i],
1229+
factor=factors[i],
1230+
num_groups=resnet_groups,
1231+
num_layers=num_blocks[i],
1232+
use_nearest=False,
1233+
use_skip=False,
1234+
extract_channels=channels * multipliers[i],
1235+
)
1236+
for i in reversed(range(num_layers))
1237+
]
1238+
)
12141239

1240+
self.to_out = nn.Sequential(
1241+
ResnetBlock1d(
1242+
in_channels=channels, out_channels=channels, num_groups=resnet_groups
1243+
),
1244+
Conv1d(
1245+
in_channels=channels,
1246+
out_channels=in_channels * patch_size,
1247+
kernel_size=1,
1248+
),
1249+
Rearrange("b (c p) l -> b c (l p)", p=patch_size),
1250+
)
1251+
1252+
def get_channels_list(
1253+
self,
1254+
in_channels: int,
1255+
channels: int,
1256+
multipliers: Sequence[int],
1257+
num_layers: int,
1258+
) -> List[int]:
1259+
channels_list = [in_channels]
1260+
channels_list += [channels * m for m in multipliers[1 : num_layers + 1]]
12151261
return channels_list
1262+
1263+
def encode(self, x: Tensor) -> Tensor:
1264+
x = self.to_in(x)
1265+
for downsample in self.downsamples:
1266+
x = downsample(x)
1267+
latent = self.to_latent(x)
1268+
return latent
1269+
1270+
def decode(self, latent: Tensor) -> List[Tensor]:
1271+
x = self.from_latent(latent)
1272+
channels_list = []
1273+
channels = x
1274+
for upsample in self.upsamples:
1275+
channels_list += [channels]
1276+
x, channels = upsample(x)
1277+
x = self.to_out(x)
1278+
channels_list += [x]
1279+
return channels_list[::-1]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.35",
6+
version="0.0.36",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)