Commit 4c5ec08

feat: unconstrain diffae encoder, add x-diffusion
1 parent 015b152 commit 4c5ec08

File tree

4 files changed: +101 -87 lines

README.md
audio_diffusion_pytorch/__init__.py
audio_diffusion_pytorch/diffusion.py
audio_diffusion_pytorch/model.py

README.md

Lines changed: 1 addition & 4 deletions
@@ -64,10 +64,7 @@ upsampled = upsampler.sample(
 ```py
 from audio_diffusion_pytorch import AudioDiffusionAutoencoder
 
-autoencoder = AudioDiffusionAutoencoder(
-    in_channels=1,
-    encoder_depth=4
-)
+autoencoder = AudioDiffusionAutoencoder(in_channels=1)
 
 # Train on audio samples
 x = torch.randn(2, 1, 2 ** 18)
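
The README change tracks the new defaults in model.py below: the encoder is no longer constrained by the UNet's hyperparameters, so only `in_channels` is required. A hedged sketch of the surrounding train/encode/decode flow, assuming `decode` fills in default sampler settings the way other models in this library do; `num_steps=50` is illustrative:

```py
import torch
from audio_diffusion_pytorch import AudioDiffusionAutoencoder

autoencoder = AudioDiffusionAutoencoder(in_channels=1)

# Train on audio samples: the forward pass returns the diffusion loss
x = torch.randn(2, 1, 2 ** 18)
loss = autoencoder(x)
loss.backward()

# Encode into a compressed latent (64 bottleneck channels by default),
# then decode back to audio by sampling conditioned on that latent
latent = autoencoder.encode(x)
sample = autoencoder.decode(latent, num_steps=50)
```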

audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@
     VKDiffusion,
     VKDistribution,
     VSampler,
+    XDiffusion,
 )
 from .model import (
     AudioDiffusionAutoencoder,
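
The new dispatcher is re-exported at the package root, so it can be used without reaching into the submodule:

```py
# XDiffusion is now public API, alongside the concrete diffusion
# classes (VDiffusion, KDiffusion, VKDiffusion) it wraps
from audio_diffusion_pytorch import XDiffusion
```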

audio_diffusion_pytorch/diffusion.py

Lines changed: 36 additions & 0 deletions
@@ -654,3 +654,39 @@ def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
         spans.append(second_half)
 
         return torch.cat(spans, dim=2)
+
+
+class XDiffusion(nn.Module):
+    def __init__(self, type: str, net: nn.Module, **kwargs):
+        super().__init__()
+
+        diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
+        aliases = [t.alias for t in diffusion_classes]  # type: ignore
+        message = f"type='{type}' must be one of {*aliases,}"
+        assert type in aliases, message
+        self.net = net
+
+        for XDiffusion in diffusion_classes:
+            if XDiffusion.alias == type:  # type: ignore
+                self.diffusion = XDiffusion(net=net, **kwargs)
+
+    def forward(self, *args, **kwargs) -> Tensor:
+        return self.diffusion(*args, **kwargs)
+
+    def sample(
+        self,
+        noise: Tensor,
+        num_steps: int,
+        sigma_schedule: Schedule,
+        sampler: Sampler,
+        clamp: bool,
+        **kwargs,
+    ) -> Tensor:
+        diffusion_sampler = DiffusionSampler(
+            diffusion=self.diffusion,
+            sampler=sampler,
+            sigma_schedule=sigma_schedule,
+            num_steps=num_steps,
+            clamp=clamp,
+        )
+        return diffusion_sampler(noise, **kwargs)
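
XDiffusion centralizes the alias-based dispatch that Model1d previously did inline: extra kwargs are forwarded to whichever diffusion class matches `type`. A hedged usage sketch; the `"v"` alias and the `sigma_distribution` kwarg are inferred from the `diffusion_type="v"` and `diffusion_sigma_distribution=...` defaults later in this commit, and `net` is a placeholder for a configured UNet1d:

```py
import torch
from audio_diffusion_pytorch.diffusion import (
    LinearSchedule,
    UniformDistribution,
    VSampler,
    XDiffusion,
)

net = ...  # placeholder: a UNet1d configured for your data

diffusion = XDiffusion(
    type="v",  # selects VDiffusion; other aliases pick KDiffusion/VKDiffusion
    net=net,
    sigma_distribution=UniformDistribution(),  # forwarded to VDiffusion
)

# Training: the forward pass returns the diffusion loss
loss = diffusion(torch.randn(2, 1, 2 ** 18))

# Inference: sample wraps DiffusionSampler with explicit settings
y = diffusion.sample(
    noise=torch.randn(2, 1, 2 ** 18),
    num_steps=50,  # illustrative
    sigma_schedule=LinearSchedule(),
    sampler=VSampler(),
    clamp=True,
)
```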

audio_diffusion_pytorch/model.py

Lines changed: 63 additions & 83 deletions
@@ -6,18 +6,8 @@
 from einops import rearrange
 from torch import Tensor, nn
 
-from .diffusion import (
-    DiffusionSampler,
-    KDiffusion,
-    LinearSchedule,
-    Sampler,
-    Schedule,
-    UniformDistribution,
-    VDiffusion,
-    VKDiffusion,
-    VSampler,
-)
-from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d
+from .diffusion import LinearSchedule, UniformDistribution, VSampler, XDiffusion
+from .modules import STFT, Conv1d, SinusoidalEmbedding, UNet1d, UNetConditional1d
 from .utils import (
     default,
     downsample,
@@ -44,36 +34,15 @@ def __init__(
         UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
         self.unet = UNet(**kwargs)
 
-        # Check valid diffusion type
-        diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
-        aliases = [t.alias for t in diffusion_classes]  # type: ignore
-        message = f"diffusion_type='{diffusion_type}' must be one of {*aliases,}"
-        assert diffusion_type in aliases, message
-
-        for XDiffusion in diffusion_classes:
-            if XDiffusion.alias == diffusion_type:  # type: ignore
-                self.diffusion = XDiffusion(net=self.unet, **diffusion_kwargs)
+        self.diffusion = XDiffusion(
+            type=diffusion_type, net=self.unet, **diffusion_kwargs
+        )
 
     def forward(self, x: Tensor, **kwargs) -> Tensor:
         return self.diffusion(x, **kwargs)
 
-    def sample(
-        self,
-        noise: Tensor,
-        num_steps: int,
-        sigma_schedule: Schedule,
-        sampler: Sampler,
-        clamp: bool,
-        **kwargs,
-    ) -> Tensor:
-        diffusion_sampler = DiffusionSampler(
-            diffusion=self.diffusion,
-            sampler=sampler,
-            sigma_schedule=sigma_schedule,
-            num_steps=num_steps,
-            clamp=clamp,
-        )
-        return diffusion_sampler(noise, **kwargs)
+    def sample(self, *args, **kwargs) -> Tensor:
+        return self.diffusion.sample(*args, **kwargs)
 
 
 class DiffusionUpsampler1d(Model1d):
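
Model1d keeps its public surface but becomes a thin wrapper: construction dispatches through XDiffusion, and `sample` forwards verbatim. A hedged sketch of the equivalence; the UNet1d kwargs are omitted for brevity, so this is illustrative rather than runnable as-is:

```py
import torch
from audio_diffusion_pytorch.diffusion import (
    LinearSchedule,
    UniformDistribution,
    VSampler,
)
from audio_diffusion_pytorch.model import Model1d

model = Model1d(
    # ... UNet1d kwargs omitted ...
    diffusion_type="v",
    diffusion_sigma_distribution=UniformDistribution(),
)

noise = torch.randn(1, 1, 2 ** 18)
# Model1d.sample is now a pure passthrough, so these are equivalent:
y1 = model.sample(noise, num_steps=50, sigma_schedule=LinearSchedule(),
                  sampler=VSampler(), clamp=True)
y2 = model.diffusion.sample(noise, num_steps=50, sigma_schedule=LinearSchedule(),
                            sampler=VSampler(), clamp=True)
```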
@@ -139,69 +108,70 @@ def sample(  # type: ignore
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
 
 
-class DiffusionAutoencoder1d(Model1d):
+class DiffusionAutoencoder1d(nn.Module):
     def __init__(
         self,
         in_channels: int,
-        channels: int,
-        patch_blocks: int,
-        patch_factor: int,
-        multipliers: Sequence[int],
-        factors: Sequence[int],
-        num_blocks: Sequence[int],
-        resnet_groups: int,
-        kernel_multiplier_downsample: int,
-        encoder_depth: int,
-        encoder_num_blocks: Optional[Sequence[int]] = None,
+        encoder_inject_depth: int,
+        encoder_channels: int,
+        encoder_factors: Sequence[int],
+        encoder_multipliers: Sequence[int],
+        diffusion_type: str,
+        encoder_patch_size: int = 1,
         bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
         bottleneck_channels: Optional[int] = None,
-        use_stft: bool = False,
         **kwargs,
     ):
+        super().__init__()
         self.in_channels = in_channels
-        encoder_num_blocks = default(encoder_num_blocks, num_blocks)
-        assert_message = "The number of encoder_num_blocks must match encoder_depth"
-        assert len(encoder_num_blocks) >= encoder_depth, assert_message
-        assert patch_blocks == 1, "patch_blocks != 1 not supported"
-        assert not use_stft, "use_stft not supported"
-        self.factor = patch_factor * prod(factors[0:encoder_depth])
-
-        context_channels = [0] * encoder_depth
+
+        encoder_kwargs, kwargs = groupby("encoder_", kwargs)
+        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
+
+        # Compute context channels
+        context_channels = [0] * encoder_inject_depth
         if exists(bottleneck_channels):
             context_channels += [bottleneck_channels]
         else:
-            context_channels += [channels * multipliers[encoder_depth]]
+            context_channels += [encoder_channels * encoder_multipliers[-1]]
 
-        super().__init__(
-            in_channels=in_channels,
-            channels=channels,
-            patch_blocks=patch_blocks,
-            patch_factor=patch_factor,
-            multipliers=multipliers,
-            factors=factors,
-            num_blocks=num_blocks,
-            resnet_groups=resnet_groups,
-            kernel_multiplier_downsample=kernel_multiplier_downsample,
-            context_channels=context_channels,
-            **kwargs,
+        self.unet = UNet1d(
+            in_channels=in_channels, context_channels=context_channels, **kwargs
+        )
+
+        self.diffusion = XDiffusion(
+            type=diffusion_type, net=self.unet, **diffusion_kwargs
         )
 
-        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
         self.encoder = Encoder1d(
             in_channels=in_channels,
-            channels=channels,
-            patch_size=patch_factor,
-            multipliers=multipliers[0 : encoder_depth + 1],
-            factors=factors[0:encoder_depth],
-            num_blocks=encoder_num_blocks[0:encoder_depth],
-            resnet_groups=resnet_groups,
+            channels=encoder_channels,
+            patch_size=encoder_patch_size,
+            factors=encoder_factors,
+            multipliers=encoder_multipliers,
             out_channels=bottleneck_channels,
+            **encoder_kwargs,
         )
 
+        if exists(bottleneck_channels):
+            self.to_bottleneck = Conv1d(
+                in_channels=encoder_channels * encoder_multipliers[-1],
+                out_channels=bottleneck_channels,
+                kernel_size=1,
+            )
+
+        self.encoder_downsample_factor = encoder_patch_size * prod(encoder_factors)
+        self.bottleneck_channels = bottleneck_channels
+        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
+
     def encode(
         self, x: Tensor, with_info: bool = False
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
         latent, info = self.encoder(x, with_info=True)
+        # Convert latent channels
+        if exists(self.bottleneck_channels):
+            latent = self.to_bottleneck(latent)
+        # Apply bottlenecks if present
         for bottleneck in self.bottlenecks:
             latent, info_bottleneck = bottleneck(latent, with_info=True)
             info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
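
Two details of the rewritten constructor are worth unpacking. First, `groupby` (from `.utils`) is assumed to split a kwargs dict by prefix and strip it, which is what lets `encoder_num_blocks=[...]` in the defaults below reach Encoder1d as `num_blocks`. Second, `context_channels` positions the latent at a chosen UNet depth. A sketch with the new default values; names not in the diff are hypothetical:

```py
# Assumed behavior of groupby("encoder_", kwargs):
kwargs = {"encoder_num_blocks": [2, 2, 2, 2, 2, 2], "channels": 128}
encoder_kwargs = {k[len("encoder_"):]: v
                  for k, v in kwargs.items() if k.startswith("encoder_")}
rest = {k: v for k, v in kwargs.items() if not k.startswith("encoder_")}
# encoder_kwargs == {"num_blocks": [2, 2, 2, 2, 2, 2]}  -> Encoder1d
# rest == {"channels": 128}                             -> UNet1d

# Context channels with encoder_inject_depth=6, bottleneck_channels=64:
context_channels = [0] * 6 + [64]
# [0, 0, 0, 0, 0, 0, 64]: no extra context at depths 0-5, a 64-channel
# latent injected at depth 6. Without a bottleneck the raw encoder width
# is used instead: encoder_channels * encoder_multipliers[-1] = 16 * 4 = 64.
```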
@@ -215,13 +185,17 @@ def forward(  # type: ignore
         return (loss, info) if with_info else loss
 
     def decode(self, latent: Tensor, **kwargs) -> Tensor:
-        b, length = latent.shape[0], latent.shape[2] * self.factor
+        b = latent.shape[0]
+        length = latent.shape[2] * self.encoder_downsample_factor
         # Compute noise by inferring shape from latent length
-        noise = torch.randn(b, self.in_channels, length).to(latent)
+        noise = torch.randn(b, self.in_channels, length, device=latent.device)
         # Compute context from latent
         default_kwargs = dict(channels_list=[latent])
         # Decode by sampling while conditioning on latent channels
-        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
+        return self.sample(noise, **{**default_kwargs, **kwargs})
+
+    def sample(self, *args, **kwargs) -> Tensor:
+        return self.diffusion.sample(*args, **kwargs)
 
 
 class DiffusionVocoder1d(Model1d):
@@ -339,7 +313,14 @@ def sample(self, *args, **kwargs):
 class AudioDiffusionAutoencoder(DiffusionAutoencoder1d):
     def __init__(self, *args, **kwargs):
         default_kwargs = dict(
-            **get_default_model_kwargs(), encoder_depth=4, encoder_channels=64
+            **get_default_model_kwargs(),
+            encoder_inject_depth=6,
+            encoder_channels=16,
+            encoder_patch_size=16,
+            encoder_multipliers=[1, 2, 4, 4, 4, 4, 4],
+            encoder_factors=[4, 4, 4, 2, 2, 2],
+            encoder_num_blocks=[2, 2, 2, 2, 2, 2],
+            bottleneck_channels=64,
         )
         super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
 
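These defaults pin down the autoencoder's compression ratio through the factor computed in DiffusionAutoencoder1d.__init__; a quick worked check:

```py
from math import prod

encoder_patch_size = 16
encoder_factors = [4, 4, 4, 2, 2, 2]
factor = encoder_patch_size * prod(encoder_factors)
assert factor == 16 * 512 == 8192

# The README's x = torch.randn(2, 1, 2 ** 18) therefore encodes to a
# latent of shape (2, 64, 32): bottleneck_channels == 64 and
# 2 ** 18 // 8192 == 32. decode() inverts this, inferring
# length = 32 * 8192 == 2 ** 18 from the latent.
assert 2 ** 18 // factor == 32
```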
@@ -398,7 +379,6 @@ def __init__(self, in_channels: int, **kwargs):
             use_nearest_upsample=False,
             use_skip_scale=True,
             use_context_time=True,
-            use_magnitude_channels=False,
             diffusion_type="v",
             diffusion_sigma_distribution=UniformDistribution(),
         )
