
Commit 21014f9
feat: add diffusion vocoder, add diffusion upphaser
1 parent 22e5d75 · commit 21014f9

File tree: 5 files changed, +134 −18 lines changed

audio_diffusion_pytorch/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -23,9 +23,13 @@
     AudioDiffusionAutoencoder,
     AudioDiffusionConditional,
     AudioDiffusionModel,
+    AudioDiffusionUpphaser,
     AudioDiffusionUpsampler,
+    AudioDiffusionVocoder,
     DiffusionAutoencoder1d,
+    DiffusionUpphaser1d,
     DiffusionUpsampler1d,
+    DiffusionVocoder1d,
     Model1d,
 )
 from .modules import (
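All four new classes are re-exported at the package root, so downstream code can import them directly; a minimal import sketch using only the names added in this hunk:

    from audio_diffusion_pytorch import (
        AudioDiffusionUpphaser,
        AudioDiffusionVocoder,
        DiffusionUpphaser1d,
        DiffusionVocoder1d,
    )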

audio_diffusion_pytorch/model.py

Lines changed: 105 additions & 0 deletions
@@ -1,6 +1,8 @@
+from math import pi
 from typing import Any, Optional, Sequence, Tuple, Union

 import torch
+from einops import rearrange
 from torch import Tensor, nn

 from .diffusion import (
@@ -15,6 +17,7 @@
     VSampler,
 )
 from .modules import (
+    STFT,
     Bottleneck,
     MultiEncoder1d,
     SinusoidalEmbedding,
@@ -223,6 +226,62 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor:
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore


+class DiffusionVocoder1d(Model1d):
+    def __init__(
+        self,
+        in_channels: int,
+        vocoder_num_fft: int,
+        **kwargs,
+    ):
+        self.frequency_channels = vocoder_num_fft // 2 + 1
+        spectrogram_channels = in_channels * self.frequency_channels
+
+        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
+        default_kwargs = dict(
+            in_channels=spectrogram_channels, context_channels=[spectrogram_channels]
+        )
+
+        super().__init__(**{**default_kwargs, **kwargs})  # type: ignore
+        self.stft = STFT(num_fft=vocoder_num_fft, **vocoder_kwargs)
+
+    def forward(self, x: Tensor, **kwargs) -> Tensor:
+        # Get magnitude and phase of true wave
+        magnitude, phase = self.stft.encode(x)
+        magnitude = rearrange(magnitude, "b c f t -> b (c f) t")
+        phase = rearrange(phase, "b c f t -> b (c f) t")
+        # Get diffusion phase loss while conditioning on magnitude (/pi [-1,1] range)
+        return self.diffusion(phase / pi, channels_list=[magnitude], **kwargs)
+
+    def sample(self, spectrogram: Tensor, **kwargs):  # type: ignore
+        b, c, f, t, device = *spectrogram.shape, spectrogram.device
+        magnitude = rearrange(spectrogram, "b c f t -> b (c f) t")
+        noise = torch.randn((b, c * f, t), device=device)
+        default_kwargs = dict(channels_list=[magnitude])
+        phase = super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore # noqa
+        phase = rearrange(phase, "b (c f) t -> b c f t", c=c)
+        wave = self.stft.decode(spectrogram, phase * pi)
+        return wave
+
+
+class DiffusionUpphaser1d(DiffusionUpsampler1d):
+    def __init__(self, **kwargs):
+        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
+        super().__init__(**kwargs)
+        self.stft = STFT(**vocoder_kwargs)
+
+    def random_rephase(self, x: Tensor) -> Tensor:
+        magnitude, phase = self.stft.encode(x)
+        phase_random = (torch.rand_like(phase) - 0.5) * 2 * pi
+        wave = self.stft.decode(magnitude, phase_random)
+        return wave
+
+    def forward(self, x: Tensor, **kwargs) -> Tensor:
+        rephased = self.random_rephase(x)
+        resampled, factors = self.random_reupsample(rephased)
+        features = self.to_features(factors) if self.use_conditioning else None
+        return self.diffusion(x, channels_list=[resampled], features=features, **kwargs)
+
+
 """
 Audio Diffusion Classes (specific for 1d audio data)
 """
@@ -315,3 +374,49 @@ def sample(self, *args, **kwargs):
             embedding_scale=5.0,
         )
         return super().sample(*args, **{**default_kwargs, **kwargs})
+
+
+class AudioDiffusionVocoder(DiffusionVocoder1d):
+    def __init__(self, in_channels: int, **kwargs):
+        default_kwargs = dict(
+            in_channels=in_channels,
+            vocoder_num_fft=1023,
+            channels=32,
+            patch_blocks=1,
+            patch_factor=1,
+            multipliers=[64, 32, 16, 8, 4, 2, 1],
+            factors=[1, 1, 1, 1, 1, 1],
+            num_blocks=[1, 1, 1, 1, 1, 1],
+            attentions=[0, 0, 0, 1, 1, 1],
+            attention_heads=8,
+            attention_features=64,
+            attention_multiplier=2,
+            attention_use_rel_pos=False,
+            resnet_groups=8,
+            kernel_multiplier_downsample=2,
+            use_nearest_upsample=False,
+            use_skip_scale=True,
+            use_context_time=True,
+            use_magnitude_channels=False,
+            diffusion_type="v",
+            diffusion_sigma_distribution=UniformDistribution(),
+        )
+        super().__init__(**{**default_kwargs, **kwargs})  # type: ignore
+
+    def sample(self, *args, **kwargs):
+        default_kwargs = dict(**get_default_sampling_kwargs())
+        return super().sample(*args, **{**default_kwargs, **kwargs})
+
+
+class AudioDiffusionUpphaser(DiffusionUpphaser1d):
+    def __init__(self, in_channels: int, **kwargs):
+        default_kwargs = dict(
+            **get_default_model_kwargs(),
+            in_channels=in_channels,
+            context_channels=[in_channels],
+            factor=1,
+        )
+        super().__init__(**{**default_kwargs, **kwargs})  # type: ignore
+
+    def sample(self, *args, **kwargs):
+        return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
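A minimal training/sampling sketch for the new wrapper; the tensor sizes and num_steps value are arbitrary assumptions for illustration, and the sampler and schedule come from the library's get_default_sampling_kwargs():

    import torch
    from audio_diffusion_pytorch import AudioDiffusionVocoder

    vocoder = AudioDiffusionVocoder(in_channels=1)  # vocoder_num_fft defaults to 1023 here

    # Training: the forward pass diffuses phase conditioned on the wave's own magnitude
    wave = torch.randn(2, 1, 2 ** 15)  # (batch, channels, time), size assumed
    loss = vocoder(wave)
    loss.backward()

    # Sampling: recover a wave from a magnitude spectrogram of shape (b, c, 512, frames)
    spectrogram = torch.randn(2, 1, 512, 128).abs()
    sampled = vocoder.sample(spectrogram, num_steps=10)  # -> (b, 1, t)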

audio_diffusion_pytorch/modules.py

Lines changed: 14 additions & 14 deletions
@@ -1,5 +1,4 @@
-import math
-from math import pi
+from math import floor, log, pi
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import torch
@@ -9,7 +8,7 @@
 from einops_exts import rearrange_many
 from torch import Tensor, einsum

-from .utils import default, exists, prod, to_list
+from .utils import closest_power_2, default, exists, prod, to_list

 """
 Utils
@@ -338,7 +337,7 @@ def _relative_position_bucket(
             max_exact
             + (
                 torch.log(n.float() / max_exact)
-                / math.log(max_distance / max_exact)
+                / log(max_distance / max_exact)
                 * (num_buckets - max_exact)
             ).long()
         )
@@ -587,7 +586,7 @@ def __init__(self, dim: int):

     def forward(self, x: Tensor) -> Tensor:
         device, half_dim = x.device, self.dim // 2
-        emb = torch.tensor(math.log(10000) / (half_dim - 1), device=device)
+        emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
         emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
         emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
         return torch.cat((emb.sin(), emb.cos()), dim=-1)
@@ -1692,17 +1691,17 @@ def decode(self, latent: Tensor) -> List[Tensor]:
 class STFT(nn.Module):
     def __init__(
         self,
-        length: int,
-        num_fft: int = 1024,
-        hop_length: int = 256,
-        window_length: int = 1024,
+        num_fft: int = 1023,
+        hop_length: Optional[int] = None,
+        window_length: Optional[int] = None,
+        length: Optional[int] = None,
     ):
         super().__init__()
         self.num_fft = num_fft
-        self.hop_length = hop_length
-        self.window_length = window_length
+        self.hop_length = default(hop_length, floor(num_fft // 4))
+        self.window_length = default(window_length, num_fft)
         self.length = length
-        self.register_buffer("window", torch.hann_window(window_length))
+        self.register_buffer("window", torch.hann_window(self.window_length))

     def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
         b = wave.shape[0]
@@ -1725,19 +1724,20 @@ def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
         return mag, phase

     def decode(self, magnitude: Tensor, phase: Tensor) -> Tensor:
-        b = magnitude.shape[0]
+        b, l = magnitude.shape[0], magnitude.shape[-1]  # noqa
         assert magnitude.shape == phase.shape, "magnitude and phase must be same shape"
         real = rearrange(magnitude * torch.cos(phase), "b c f l -> (b c) f l")
         imag = rearrange(magnitude * torch.sin(phase), "b c f l -> (b c) f l")
         stft = torch.stack([real, imag], dim=-1)
+        length = closest_power_2(l * self.hop_length)

         wave = torch.istft(
             stft,
             n_fft=self.num_fft,
             hop_length=self.hop_length,
             win_length=self.window_length,
             window=self.window,  # type: ignore
-            length=self.length,
+            length=default(self.length, length),
         )
         wave = rearrange(wave, "(b c) t -> b c t", b=b)
         return wave
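A round-trip sketch of the reworked STFT defaults; the waveform size is an assumption:

    import torch
    from audio_diffusion_pytorch.modules import STFT

    stft = STFT()  # num_fft=1023 -> hop_length=255, window_length=1023 by default
    wave = torch.randn(1, 2, 2 ** 15)         # (batch, channels, time), size assumed

    magnitude, phase = stft.encode(wave)      # (1, 2, 512, frames)
    reconstructed = stft.decode(magnitude, phase)
    # With length=None, decode() now infers the output length as
    # closest_power_2(frames * hop_length) instead of requiring it up front.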

audio_diffusion_pytorch/utils.py

Lines changed: 10 additions & 3 deletions
@@ -1,6 +1,6 @@
-import math
 from functools import reduce
 from inspect import isfunction
+from math import ceil, floor, log2, pi
 from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

 import torch
@@ -42,6 +42,13 @@ def prod(vals: Sequence[int]) -> int:
     return reduce(lambda x, y: x * y, vals)


+def closest_power_2(x: float) -> int:
+    exponent = log2(x)
+    distance_fn = lambda z: abs(x - 2 ** z)  # noqa
+    exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
+    return 2 ** int(exponent_closest)
+
+
 """
 Kwargs Utils
 """
@@ -79,10 +86,10 @@ def resample(
     d = dict(device=waveforms.device, dtype=waveforms.dtype)

     base_factor = min(factor_in, factor_out) * rolloff
-    width = math.ceil(lowpass_filter_width * factor_in / base_factor)
+    width = ceil(lowpass_filter_width * factor_in / base_factor)
     idx = torch.arange(-width, width + factor_in, **d)[None, None] / factor_in  # type: ignore # noqa
     t = torch.arange(0, -factor_out, step=-1, **d)[:, None, None] / factor_out + idx  # type: ignore # noqa
-    t = (t * base_factor).clamp(-lowpass_filter_width, lowpass_filter_width) * math.pi
+    t = (t * base_factor).clamp(-lowpass_filter_width, lowpass_filter_width) * pi

     window = torch.cos(t / lowpass_filter_width / 2) ** 2
     scale = base_factor / factor_in

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.79",
+    version="0.0.80",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",

0 commit comments
