Skip to content

Commit fc1fb9b

Browse files
feat: add stftautoencoder
1 parent aa83393 commit fc1fb9b

File tree

3 files changed

+98
-1
lines changed

3 files changed

+98
-1
lines changed

audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
Encoder1d,
3535
MultiEncoder1d,
3636
Noiser,
37+
STFTAutoEncoder1d,
3738
T5Embedder,
3839
Tanh,
3940
UNet1d,

audio_diffusion_pytorch/modules.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1687,3 +1687,99 @@ def decode(self, latent: Tensor) -> List[Tensor]:
16871687
x = self.to_out(x)
16881688
channels_list += [x]
16891689
return channels_list[::-1]
1690+
1691+
1692+
class STFT(nn.Module):
    """Convert waveforms to magnitude/phase spectrograms and back.

    Thin wrapper around ``torch.stft`` / ``torch.istft`` that handles a
    separate channel axis by folding it into the batch axis.

    Args:
        length: Forwarded to ``torch.istft`` as ``length`` in :meth:`decode`,
            i.e. the number of samples of the reconstructed waveform.
        num_fft: FFT size, forwarded as ``n_fft``.
        hop_length: Hop between successive frames, forwarded as ``hop_length``.
        window_length: Size of the Hann window, forwarded as ``win_length``.
    """

    def __init__(
        self,
        length: int,
        num_fft: int = 1024,
        hop_length: int = 256,
        window_length: int = 1024,
    ):
        super().__init__()
        self.num_fft = num_fft
        self.hop_length = hop_length
        self.window_length = window_length
        self.length = length
        # Stored as a buffer (not a parameter) so it is kept with the module.
        self.register_buffer("window", torch.hann_window(window_length))

    def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
        """Compute the STFT of ``wave`` and split it into magnitude and phase.

        Args:
            wave: Waveform laid out as ``(b, c, t)``.

        Returns:
            A ``(magnitude, phase)`` tuple, each laid out as ``(b, c, f, l)``.
        """
        b = wave.shape[0]
        # torch.stft has no channel axis, so fold channels into the batch.
        wave = rearrange(wave, "b c t -> (b c) t")

        stft = torch.stft(
            wave,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            return_complex=True,
        )

        # The squared modulus is clamped at 1e-8, so the magnitude never drops
        # below 1e-4 (presumably to keep sqrt and a later log well-behaved).
        mag = torch.sqrt(torch.clamp((stft.real ** 2) + (stft.imag ** 2), min=1e-8))
        mag = rearrange(mag, "(b c) f l -> b c f l", b=b)

        phase = torch.angle(stft)
        phase = rearrange(phase, "(b c) f l -> b c f l", b=b)
        return mag, phase

    def decode(self, magnitude: Tensor, phase: Tensor) -> Tensor:
        """Reconstruct a waveform from a magnitude and a phase spectrogram.

        Args:
            magnitude: Magnitudes laid out as ``(b, c, f, l)``.
            phase: Phases in radians, same shape as ``magnitude``.

        Returns:
            Waveform laid out as ``(b, c, t)`` with ``t`` set by ``self.length``.

        Raises:
            AssertionError: If ``magnitude`` and ``phase`` differ in shape.
        """
        b = magnitude.shape[0]
        assert magnitude.shape == phase.shape, "magnitude and phase must be same shape"
        real = rearrange(magnitude * torch.cos(phase), "b c f l -> (b c) f l")
        imag = rearrange(magnitude * torch.sin(phase), "b c f l -> (b c) f l")
        # torch.istft needs a complex tensor: the real-valued (..., 2) layout
        # was deprecated and is rejected by recent PyTorch releases.
        stft = torch.complex(real, imag)

        wave = torch.istft(
            stft,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            length=self.length,
        )
        wave = rearrange(wave, "(b c) t -> b c t", b=b)
        return wave
1744+
1745+
1746+
class STFTAutoEncoder1d(AutoEncoder1d):
    """Autoencoder that operates on STFT spectrograms instead of raw audio.

    Waveforms are turned into log-magnitude spectrograms before being handed
    to the parent encoder; the parent decoder's output is interpreted as
    log-magnitude and phase and inverted back to a waveform.

    Args:
        in_channels: Number of audio channels of the waveform.
        length: Number of samples of the reconstructed waveform.
        num_fft: FFT size used by the STFT.
        hop_length: Hop between successive STFT frames.
        window_length: Size of the STFT window.
        **kwargs: Forwarded unchanged to the parent constructor.
    """

    def __init__(
        self,
        in_channels: int,
        length: int,
        num_fft: int = 1024,
        hop_length: int = 256,
        window_length: int = 1024,
        **kwargs,
    ):
        # Number of frequency bins of a one-sided spectrum.
        self.frequency_channels = num_fft // 2 + 1

        # The parent sees each (audio channel, frequency bin) pair as one
        # input channel, and emits twice as many outputs: one half is read
        # back as log-magnitude and the other as phase.
        spectrogram_channels = in_channels * self.frequency_channels
        super().__init__(
            in_channels=spectrogram_channels,
            out_channels=spectrogram_channels * 2,
            patch_blocks=1,
            patch_factor=1,
            **kwargs,
        )

        self.stft = STFT(
            length=length,
            num_fft=num_fft,
            hop_length=hop_length,
            window_length=window_length,
        )

    def encode(
        self, wave: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        """Encode a waveform through its log-magnitude spectrogram.

        The phase returned by the STFT is not used here.

        Args:
            wave: Waveform laid out as ``(b, c, t)``.
            with_info: Passed through to the parent ``encode``.

        Returns:
            Whatever the parent ``encode`` returns for the spectrogram.
        """
        magnitude, _ = self.stft.encode(wave)
        log_magnitude = torch.log(magnitude)
        # Merge audio channels and frequency bins into one channel axis.
        flattened = rearrange(log_magnitude, "b c f t -> b (c f) t")
        return super().encode(flattened, with_info)

    def decode(self, z: Tensor) -> Tensor:
        """Decode a latent into a waveform.

        Args:
            z: Latent accepted by the parent ``decode``.

        Returns:
            Waveform produced by the inverse STFT.
        """
        decoded = super().decode(z)
        spectrum = rearrange(
            decoded,
            "b (c f i) t -> b (c i) f t",
            i=2,
            f=self.frequency_channels,
        )
        # NOTE(review): with the "(c i)" grouping, splitting dim 1 in half
        # cleanly separates the two `i` slots only for a single audio
        # channel; for more channels each half mixes both slots. The decoder
        # outputs are learned, so this is presumably harmless, but changing
        # the layout would alter what existing weights mean - confirm before
        # touching.
        log_magnitude, phase = torch.chunk(spectrum, chunks=2, dim=1)
        magnitude = torch.exp(log_magnitude)
        return self.stft.decode(magnitude=magnitude, phase=phase)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.76",
6+
version="0.0.77",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)