
Commit 365d0cd
feat: add audio ddsp
1 parent d51d9a9

File tree

2 files changed: +249 −1 lines changed

audio_diffusion_pytorch/ddsp.py

Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
""" Audio DDSP inspired by https://github.com/acids-ircam/RAVE """

from math import ceil, log2, pi, prod
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from scipy.optimize import fmin
from scipy.signal import firwin, kaiserord
from torch import Tensor
from torch.nn import functional as F

from .modules import Conv1d, ConvBlock1d


def reverse_half(x: Tensor) -> Tensor:
    # Negate even-indexed samples in every odd-indexed sub-band, which cancels
    # aliasing between adjacent PQMF bands
    mask = torch.ones_like(x)
    mask[..., 1::2, ::2] = -1
    return x * mask


def center_pad_next_pow_2(x: Tensor) -> Tensor:
    # Center-pad the last dimension up to the next power of two
    next_2 = 2 ** ceil(log2(x.shape[-1]))
    pad = next_2 - x.shape[-1]
    return F.pad(x, (pad // 2, pad // 2 + int(pad % 2)))


def get_qmf_bank(h: Tensor, num_bands: int) -> Tensor:
    """
    Modulates an input prototype filter into a bank of cosine modulated filters
    h: prototype filter
    num_bands: number of sub-bands
    """
    k = torch.arange(num_bands).reshape(-1, 1)
    N = h.shape[-1]
    t = torch.arange(-(N // 2), N // 2 + 1)

    p = (-1) ** k * pi / 4

    mod = torch.cos((2 * k + 1) * pi / (2 * num_bands) * t + p)
    hk = 2 * h * mod

    return hk


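# In other words, band k of the returned bank is the standard pseudo-QMF
# modulation hk[t] = 2 * h[t] * cos((2k + 1) * pi / (2 * num_bands) * t
# + (-1)^k * pi / 4), so every sub-band filter is derived from the single
# lowpass prototype h.

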
def kaiser_filter(wc: float, attenuation: float) -> np.ndarray:
    """
    wc: Angular frequency
    attenuation: Attenuation (dB, positive)
    """
    N, beta = kaiserord(attenuation, wc / np.pi)
    N = 2 * (N // 2) + 1  # Force an odd length so the filter has a center tap
    h = firwin(N, wc, window=("kaiser", beta), scale=False, nyq=np.pi)
    return h


def loss_wc(wc: float, attenuation: float, num_bands: int) -> np.ndarray:
    """
    Computes the objective described in https://ieeexplore.ieee.org/document/681427
    """
    h = kaiser_filter(wc, attenuation)
    # Autocorrelation of the prototype; for near-perfect reconstruction its
    # values at non-zero multiples of 2 * num_bands should vanish
    g = np.convolve(h, h[::-1], "full")  # type: ignore
    start_idx = g.shape[-1] // 2
    stride = 2 * num_bands
    g = abs(g[start_idx::stride][1:])
    return np.max(g)


def get_prototype(attenuation: float, num_bands: int) -> np.ndarray:
    """
    Returns the lowpass prototype filter optimized for the given attenuation
    and number of bands
    """
    wc = fmin(lambda w: loss_wc(w, attenuation, num_bands), 1.0 / num_bands, disp=0)[0]
    return kaiser_filter(wc, attenuation)


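# A minimal sketch of the design path above (the 100 dB attenuation and
# 16-band count are illustrative values, not fixed by this file):
#
#   h = get_prototype(attenuation=100, num_bands=16)  # optimized odd-length lowpass
#   hk = get_qmf_bank(torch.from_numpy(h).float(), num_bands=16)
#   # h.shape == (N,); hk.shape == (16, N): one modulated filter per sub-band

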
def polyphase_forward(x: Tensor, hk: Tensor) -> Tensor:
    """
    x: [b, 1, t]
    hk: filter bank [m, t]
    """
    # Decimate x into m polyphase components, then filter all bands in one conv
    x = rearrange(x, "b c (t m) -> b (c m) t", m=hk.shape[0])
    hk = rearrange(hk, "c (t m) -> c m t", m=hk.shape[0])
    x = F.conv1d(x, hk, padding=hk.shape[-1] // 2)[..., :-1]
    return x


def polyphase_inverse(x: Tensor, hk: Tensor) -> Tensor:
    """
    x: sub-band signal to synthesize from [b, m, t]
    hk: filter bank [m, t]
    """
    m = hk.shape[0]

    hk = hk.flip(-1)
    hk = rearrange(hk, "c (t m) -> m c t", m=m)  # polyphase

    pad = hk.shape[-1] // 2 + 1
    x = F.conv1d(x, hk, padding=int(pad))[..., :-1] * m

    x = x.flip(1)
    x = rearrange(x, "b (c m) t -> b c (t m)", m=m)
    # Drop filter warm-up samples at the start
    start_idx = 2 * hk.shape[1]
    x = x[..., start_idx:]
    return x


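# Shape sketch for the polyphase pair above, assuming a 16-band bank `hk` and
# an input length divisible by 16 (illustrative, not part of the API):
#
#   bands = polyphase_forward(x, hk)      # [b, 1, t] -> [b, 16, t // 16]
#   x_hat = polyphase_inverse(bands, hk)  # [b, 16, t // 16] -> [b, 1, t]
#
# i.e. analysis critically decimates by the band count and synthesis undoes it.

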
def amp_to_impulse_response(amp: Tensor, target_size: int) -> Tensor:
    """
    Transforms frequency amplitudes to an impulse response on the last dimension
    """
    # Set the imaginary part to zero
    amp = torch.stack([amp, torch.zeros_like(amp)], -1)
    amp = torch.view_as_complex(amp)
    # Compute irfft, i.e. Fourier domain => real-valued amplitude domain
    amp = torch.fft.irfft(amp)
    # Center the impulse response, apply a Hann window, and pad it to the
    # target size
    filter_size = amp.shape[-1]
    amp = torch.roll(amp, filter_size // 2, -1)

    win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device)
    amp = amp * win

    amp = F.pad(amp, (0, int(target_size) - int(filter_size)))
    amp = torch.roll(amp, -filter_size // 2, -1)

    return amp


def fft_convolve(signal: Tensor, kernel: Tensor) -> Tensor:
    """
    Convolves signal with kernel on the last dimension
    """
    # Zero-pad both to twice the length to avoid circular-convolution wrap-around
    signal = F.pad(signal, (0, signal.shape[-1]))
    kernel = F.pad(kernel, (kernel.shape[-1], 0))

    output = torch.fft.irfft(torch.fft.rfft(signal) * torch.fft.rfft(kernel))
    start_idx = output.shape[-1] // 2
    output = output[..., start_idx:]

    return output


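# Hedged sketch of how the two helpers above combine in AudioProcessor.decode:
# per-band amplitudes become a windowed impulse response that colors uniform
# noise (all shapes illustrative):
#
#   amps = torch.rand(1, 128, 2, 4)                     # [b, t, c, noise_bands]
#   ir = amp_to_impulse_response(amps, target_size=32)  # [b, t, c, 32]
#   noise = torch.rand_like(ir) * 2 - 1                 # uniform in [-1, 1]
#   colored = fft_convolve(noise, ir)                   # [b, t, c, 32]

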
def scaled_sigmoid(x: Tensor) -> Tensor:
    return 2 * torch.sigmoid(x) ** 2.3 + 1e-7


class PQMF(nn.Module):
    def __init__(self, attenuation: float, num_bands: int):
        super().__init__()
        self.num_bands = num_bands
        assert log2(num_bands).is_integer(), "num_bands must be a power of 2"

        h = get_prototype(attenuation, num_bands)
        hk = get_qmf_bank(torch.from_numpy(h).float(), num_bands)
        hk = center_pad_next_pow_2(hk)
        self.register_buffer("hk", hk)

    def forward(self, x):
        b, _, _ = x.shape
        x = rearrange(x, "b c t -> (b c) 1 t")
        x = polyphase_forward(x, self.hk)
        x = reverse_half(x)
        x = rearrange(x, "(b c) k t -> b (c k) t", b=b)
        return x

    def inverse(self, x):
        b, k = x.shape[0], self.num_bands
        x = rearrange(x, "b (c k) t -> (b c) k t", k=k)
        x = reverse_half(x)
        x = polyphase_inverse(x, self.hk)
        x = rearrange(x, "(b c) 1 t -> b c t", b=b)
        return x


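# Minimal round-trip sketch (100 dB attenuation and 16 bands are illustrative):
#
#   pqmf = PQMF(attenuation=100, num_bands=16)
#   x = torch.randn(1, 2, 8192)  # [batch, channels, time]
#   z = pqmf(x)                  # [1, 32, 512]: channels * bands, time / bands
#   x_hat = pqmf.inverse(z)      # [1, 2, 8192]: x_hat ~= x up to filter error

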
class AudioProcessor(nn.Module):
    def __init__(
        self,
        in_channels: int,
        channels: int,
        pqmf_bands: int,
        pqmf_attenuation: float,
        noise_bands: int,
        noise_ratios: Sequence[int],
    ):
        super().__init__()

        pqmf_channels = in_channels * pqmf_bands
        amp_channels = [channels] * len(noise_ratios) + [pqmf_channels * noise_bands]

        self.noise_bands = noise_bands
        self.noise_multiplier = prod(noise_ratios)

        self.pqmf = PQMF(num_bands=pqmf_bands, attenuation=pqmf_attenuation)

        # Input processing
        self.to_in = Conv1d(
            in_channels=pqmf_channels, out_channels=channels, kernel_size=1
        )

        # Output processing
        self.to_wave = Conv1d(
            in_channels=channels, out_channels=pqmf_channels, kernel_size=1
        )

        self.to_loudness = Conv1d(
            in_channels=channels, out_channels=pqmf_channels, kernel_size=1
        )

        # Strided blocks that downsample time by prod(noise_ratios)
        self.to_amp = nn.Sequential(
            *[
                ConvBlock1d(
                    in_channels=amp_channels[i],
                    out_channels=amp_channels[i + 1],
                    stride=noise_ratios[i],
                )
                for i in range(len(amp_channels) - 1)
            ]
        )

    def encode(self, x: Tensor) -> Tensor:
        x = self.pqmf(x)
        x = self.to_in(x)
        return x

    def decode(self, x: Tensor) -> Tensor:
        n = self.noise_bands
        wave, loudness, amp = self.to_wave(x), self.to_loudness(x), self.to_amp(x)

        # Convert the computed amplitudes to filtered uniform noise
        amp = rearrange(scaled_sigmoid(amp - 5), "b (c n) t -> b t c n", n=n)
        impulse_response = amp_to_impulse_response(amp, self.noise_multiplier)
        noise = torch.rand_like(impulse_response) * 2 - 1
        noise = fft_convolve(noise, impulse_response)
        noise = rearrange(noise, "b t c n -> b c (t n)")

        x = torch.tanh(wave) * scaled_sigmoid(loudness) + noise
        x = self.pqmf.inverse(x)
        return x
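
Taken together, encode compresses audio into a PQMF-then-projected latent, and decode resynthesizes it as a tanh-squashed wave scaled by a learned loudness, plus spectrally shaped noise. An illustrative end-to-end sketch; the hyperparameters below are example values chosen so the internal shapes line up:

    processor = AudioProcessor(
        in_channels=2, channels=64,
        pqmf_bands=16, pqmf_attenuation=100,
        noise_bands=4, noise_ratios=[4, 4],  # noise_multiplier = 16
    )
    audio = torch.randn(1, 2, 8192)       # [batch, channels, time]
    latent = processor.encode(audio)      # [1, 64, 512]: sub-bands -> channels
    audio_hat = processor.decode(latent)  # [1, 2, 8192]: wave * loudness + noise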

audio_diffusion_pytorch/modules.py

Lines changed: 5 additions & 1 deletion
@@ -80,6 +80,9 @@ def __init__(
         in_channels: int,
         out_channels: int,
         *,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding: int = 1,
         dilation: int = 1,
         num_groups: int = 8,
         use_norm: bool = True,
@@ -95,7 +98,8 @@ def __init__(
         self.project = Conv1d(
             in_channels=in_channels,
             out_channels=out_channels,
-            kernel_size=3,
+            kernel_size=kernel_size,
+            stride=stride,
             padding=dilation,
             dilation=dilation,
         )
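
The new kernel_size and stride arguments are what AudioProcessor.to_amp relies on to downsample the noise-amplitude path. Note that the added padding argument is accepted but not forwarded in this hunk: the projection still passes padding=dilation to Conv1d. A minimal sketch of the strided use (values illustrative):

    block = ConvBlock1d(in_channels=64, out_channels=64, stride=4)
    y = block(torch.randn(1, 64, 512))  # [1, 64, 128]: time reduced by the stride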
