Skip to content

Commit cfe358f

Browse files
feat: add tanh bottleneck, option to use multiple bottlenecks
1 parent e22b4e9 commit cfe358f

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
AutoEncoder1d,
3333
MultiEncoder1d,
3434
T5Embedder,
35+
Tanh,
3536
UNet1d,
3637
UNetConditional1d,
3738
Variational,

audio_diffusion_pytorch/modules.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from math import pi
3-
from typing import Any, List, Optional, Sequence, Tuple, Union
3+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
44

55
import torch
66
import torch.nn as nn
@@ -9,7 +9,7 @@
99
from einops_exts import rearrange_many
1010
from torch import Tensor, einsum
1111

12-
from .utils import default, exists, prod
12+
from .utils import default, exists, prod, to_list
1313

1414
"""
1515
Utils
@@ -1341,6 +1341,15 @@ def forward(
13411341
return (out, dict(loss=loss, mean=mean, logvar=logvar)) if with_info else out
13421342

13431343

1344+
class Tanh(Bottleneck):
1345+
def forward(
1346+
self, x: Tensor, with_info: bool = False
1347+
) -> Union[Tensor, Tuple[Tensor, Any]]:
1348+
x = torch.tanh(x)
1349+
info: Dict = dict()
1350+
return (x, info) if with_info else x
1351+
1352+
13441353
class AutoEncoder1d(nn.Module):
13451354
def __init__(
13461355
self,
@@ -1353,12 +1362,12 @@ def __init__(
13531362
factors: Sequence[int],
13541363
num_blocks: Sequence[int],
13551364
use_noisy: bool = False,
1356-
bottleneck: Optional[Bottleneck] = None,
1365+
bottleneck: Union[Bottleneck, List[Bottleneck]] = [],
13571366
use_magnitude_channels: bool = False,
13581367
):
13591368
super().__init__()
13601369
num_layers = len(multipliers) - 1
1361-
self.bottleneck = bottleneck
1370+
self.bottlenecks = to_list(bottleneck)
13621371
self.use_noisy = use_noisy
13631372
self.use_magnitude_channels = use_magnitude_channels
13641373

@@ -1424,8 +1433,8 @@ def encode(
14241433
xs += [x]
14251434
info = dict(xs=xs)
14261435

1427-
if exists(self.bottleneck):
1428-
x, info_bottleneck = self.bottleneck(x, with_info=True)
1436+
for bottleneck in self.bottlenecks:
1437+
x, info_bottleneck = bottleneck(x, with_info=True)
14291438
info = {**info, **info_bottleneck}
14301439

14311440
return (x, info) if with_info else x

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.70",
6+
version="0.0.71",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)