Skip to content

Commit b63df7d

Browse files
feat: add use_complex stft diffusion option
1 parent 145609e commit b63df7d

File tree

2 files changed

+37
-22
lines changed

2 files changed

+37
-22
lines changed

audio_diffusion_pytorch/modules.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,19 +1277,23 @@ def forward(self, texts: List[str]) -> Tensor:
12771277

12781278

12791279
class STFT(nn.Module):
1280+
"""Helper for torch stft and istft"""
1281+
12801282
def __init__(
12811283
self,
12821284
num_fft: int = 1023,
1283-
hop_length: Optional[int] = None,
1285+
hop_length: int = 256,
12841286
window_length: Optional[int] = None,
12851287
length: Optional[int] = None,
1288+
use_complex: bool = False,
12861289
):
12871290
super().__init__()
12881291
self.num_fft = num_fft
12891292
self.hop_length = default(hop_length, floor(num_fft // 4))
12901293
self.window_length = default(window_length, num_fft)
12911294
self.length = length
12921295
self.register_buffer("window", torch.hann_window(self.window_length))
1296+
self.use_complex = use_complex
12931297

12941298
def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
12951299
b = wave.shape[0]
@@ -1302,43 +1306,54 @@ def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
13021306
win_length=self.window_length,
13031307
window=self.window, # type: ignore
13041308
return_complex=True,
1309+
normalized=True,
13051310
)
13061311

1307-
mag = torch.sqrt(torch.clamp((stft.real ** 2) + (stft.imag ** 2), min=1e-8))
1308-
mag = rearrange(mag, "(b c) f l -> b c f l", b=b)
1312+
if self.use_complex:
1313+
# Returns real and imaginary
1314+
stft_a, stft_b = stft.real, stft.imag
1315+
else:
1316+
# Returns magnitude and phase matrices
1317+
magnitude, phase = torch.abs(stft), torch.angle(stft)
1318+
stft_a, stft_b = magnitude, phase
13091319

1310-
phase = torch.angle(stft)
1311-
phase = rearrange(phase, "(b c) f l -> b c f l", b=b)
1312-
return mag, phase
1320+
return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)
13131321

1314-
def decode(self, magnitude: Tensor, phase: Tensor) -> Tensor:
1315-
b, l = magnitude.shape[0], magnitude.shape[-1] # noqa
1316-
assert magnitude.shape == phase.shape, "magnitude and phase must be same shape"
1317-
real = rearrange(magnitude * torch.cos(phase), "b c f l -> (b c) f l")
1318-
imag = rearrange(magnitude * torch.sin(phase), "b c f l -> (b c) f l")
1319-
stft = torch.stack([real, imag], dim=-1)
1322+
def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
1323+
b, l = stft_a.shape[0], stft_a.shape[-1] # noqa
13201324
length = closest_power_2(l * self.hop_length)
13211325

1326+
stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")
1327+
1328+
if self.use_complex:
1329+
real, imag = stft_a, stft_b
1330+
else:
1331+
magnitude, phase = stft_a, stft_b
1332+
real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)
1333+
1334+
stft = torch.stack([real, imag], dim=-1)
1335+
13221336
wave = torch.istft(
13231337
stft,
13241338
n_fft=self.num_fft,
13251339
hop_length=self.hop_length,
13261340
win_length=self.window_length,
13271341
window=self.window, # type: ignore
13281342
length=default(self.length, length),
1343+
normalized=True,
13291344
)
1330-
wave = rearrange(wave, "(b c) t -> b c t", b=b)
1331-
return wave
1345+
1346+
return rearrange(wave, "(b c) t -> b c t", b=b)
13321347

13331348
def encode1d(
13341349
self, wave: Tensor, stacked: bool = True
13351350
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
1336-
magnitude, phase = self.encode(wave)
1337-
magnitude, phase = rearrange_many((magnitude, phase), "b c f l -> b (c f) l")
1338-
return torch.cat((magnitude, phase), dim=1) if stacked else (magnitude, phase)
1351+
stft_a, stft_b = self.encode(wave)
1352+
stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
1353+
return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)
13391354

1340-
def decode1d(self, magnitude_and_phase: Tensor) -> Tensor:
1355+
def decode1d(self, stft_pair: Tensor) -> Tensor:
13411356
f = self.num_fft // 2 + 1
1342-
magnitude, phase = magnitude_and_phase.chunk(chunks=2, dim=1)
1343-
mag, phase = rearrange_many((magnitude, phase), "b (c f) l -> b c f l", f=f)
1344-
return self.decode(mag, phase)
1357+
stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
1358+
stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
1359+
return self.decode(stft_a, stft_b)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.84",
6+
version="0.0.85",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)