Commit 514b8f7

feat: add adapter option to DiffusionAE
1 parent a34014f commit 514b8f7

2 files changed (+36, -5 lines)


audio_diffusion_pytorch/models.py

Lines changed: 35 additions & 4 deletions
@@ -8,7 +8,15 @@
 
 from .components import AppendChannelsPlugin, MelSpectrogram
 from .diffusion import ARVDiffusion, ARVSampler, VDiffusion, VSampler
-from .utils import closest_power_2, default, downsample, groupby, randn_like, upsample
+from .utils import (
+    closest_power_2,
+    default,
+    downsample,
+    exists,
+    groupby,
+    randn_like,
+    upsample,
+)
 
 
 class DiffusionModel(nn.Module):
@@ -46,6 +54,18 @@ def __init__(self):
         self.downsample_factor = None
 
 
+class AdapterBase(nn.Module, ABC):
+    """Abstract class for DiffusionAE encoder"""
+
+    @abstractmethod
+    def encode(self, x: Tensor) -> Tensor:
+        pass
+
+    @abstractmethod
+    def decode(self, x: Tensor) -> Tensor:
+        pass
+
+
 class DiffusionAE(DiffusionModel):
     """Diffusion Auto Encoder"""
 
@@ -55,6 +75,8 @@ def __init__(
         channels: Sequence[int],
         encoder: EncoderBase,
         inject_depth: int,
+        latent_factor: Optional[int] = None,
+        adapter: Optional[AdapterBase] = None,
         **kwargs,
     ):
         context_channels = [0] * len(channels)
@@ -68,12 +90,19 @@ def __init__(
         self.in_channels = in_channels
         self.encoder = encoder
         self.inject_depth = inject_depth
+        # Optional custom latent factor and adapter
+        self.latent_factor = default(latent_factor, self.encoder.downsample_factor)
+        self.adapter = adapter.requires_grad_(False) if exists(adapter) else None
 
     def forward(  # type: ignore
         self, x: Tensor, with_info: bool = False, **kwargs
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        # Encode input to latent channels
         latent, info = self.encode(x, with_info=True)
         channels = [None] * self.inject_depth + [latent]
+        # Adapt input to diffusion if adapter provided
+        x = self.adapter.encode(x) if exists(self.adapter) else x
+        # Compute diffusion loss
         loss = super().forward(x, channels=channels, **kwargs)
         return (loss, info) if with_info else loss
 
@@ -85,18 +114,20 @@ def decode(
         self, latent: Tensor, generator: Optional[Generator] = None, **kwargs
     ) -> Tensor:
         b = latent.shape[0]
-        length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor)
+        noise_length = closest_power_2(latent.shape[2] * self.latent_factor)
         # Compute noise by inferring shape from latent length
         noise = torch.randn(
-            (b, self.in_channels, length),
+            (b, self.in_channels, noise_length),
             device=latent.device,
             dtype=latent.dtype,
            generator=generator,
        )
         # Compute context from latent
         channels = [None] * self.inject_depth + [latent]  # type: ignore
         # Decode by sampling while conditioning on latent channels
-        return super().sample(noise, channels=channels, **kwargs)
+        out = super().sample(noise, channels=channels, **kwargs)
+        # Decode output with adapter if provided
+        return self.adapter.decode(out) if exists(self.adapter) else out
 
 
 class DiffusionUpsampler(DiffusionModel):
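
To make the new hook concrete, here is a minimal sketch of a custom adapter, assuming only the AdapterBase interface and the adapter/latent_factor arguments introduced in this commit. FrozenAEAdapter, the wrapped pretrained_ae object, and the latent_factor arithmetic in the trailing comment are hypothetical illustrations, not part of the library.

from torch import Tensor, nn

from audio_diffusion_pytorch.models import AdapterBase


class FrozenAEAdapter(AdapterBase):
    """Hypothetical adapter wrapping a pretrained autoencoder.

    DiffusionAE calls requires_grad_(False) on the adapter, so the
    wrapped weights stay frozen during training.
    """

    def __init__(self, autoencoder: nn.Module):
        super().__init__()
        # Assumes the wrapped object exposes encode/decode methods
        self.autoencoder = autoencoder

    def encode(self, x: Tensor) -> Tensor:
        # Map raw audio into the space the diffusion model trains on
        return self.autoencoder.encode(x)

    def decode(self, x: Tensor) -> Tensor:
        # Map sampled output back to raw audio
        return self.autoencoder.decode(x)


# Hypothetical wiring, inferred from how decode() uses self.latent_factor:
# since the latent is computed from the un-adapted input, an adapter that
# shortens the signal by a factor A would need latent_factor reduced by A
# so the sampled noise matches the adapted length, e.g.
# model = DiffusionAE(
#     ...,
#     encoder=encoder,
#     adapter=FrozenAEAdapter(pretrained_ae),
#     latent_factor=encoder.downsample_factor // 2,  # if the adapter halves length
# )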

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.1.1",
+    version="0.1.2",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
