Commit 015b152

feat: back to previous vocoder
1 parent d0b206a commit 015b152

3 files changed: +40 additions, -29 deletions

audio_diffusion_pytorch/model.py

Lines changed: 38 additions & 27 deletions
@@ -19,7 +19,6 @@
 )
 from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d
 from .utils import (
-    closest_power_2,
     default,
     downsample,
     exists,
@@ -204,7 +203,7 @@ def encode(
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
         latent, info = self.encoder(x, with_info=True)
         for bottleneck in self.bottlenecks:
-            x, info_bottleneck = bottleneck(x, with_info=True)
+            latent, info_bottleneck = bottleneck(latent, with_info=True)
             info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
         return (latent, info) if with_info else latent

@@ -226,31 +225,43 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor:


 class DiffusionVocoder1d(Model1d):
-    def __init__(self, in_channels: int, stft_num_fft: int, **kwargs):
-        self.stft_num_fft = stft_num_fft
-        spectrogram_channels = stft_num_fft // 2 + 1
+    def __init__(
+        self,
+        in_channels: int,
+        stft_num_fft: int,
+        **kwargs,
+    ):
+        self.frequency_channels = stft_num_fft // 2 + 1
+        spectrogram_channels = in_channels * self.frequency_channels
+
+        stft_kwargs, kwargs = groupby("stft_", kwargs)
         default_kwargs = dict(
-            in_channels=in_channels,
-            use_stft=True,
-            stft_num_fft=stft_num_fft,
-            context_channels=[in_channels * spectrogram_channels],
+            in_channels=spectrogram_channels, context_channels=[spectrogram_channels]
         )
+
         super().__init__(**{**default_kwargs, **kwargs})  # type: ignore
+        self.stft = STFT(num_fft=stft_num_fft, **stft_kwargs)

-    def forward(self, x: Tensor, **kwargs) -> Tensor:
-        # Get magnitude spectrogram from true wave
-        magnitude, _ = self.unet.stft.encode(x)
-        magnitude = rearrange(magnitude, "b c f t -> b (c f) t")
-        # Get diffusion loss while conditioning on magnitude
-        return self.diffusion(x, channels_list=[magnitude], **kwargs)
+    def forward_wave(self, x: Tensor, **kwargs) -> Tensor:
+        # Get magnitude and phase of true wave
+        magnitude, phase = self.stft.encode(x)
+        return self(magnitude, phase, **kwargs)

-    def sample(self, spectrogram: Tensor, **kwargs):  # type: ignore
-        b, c, _, t, device = *spectrogram.shape, spectrogram.device
-        magnitude = rearrange(spectrogram, "b c f t -> b (c f) t")
-        timesteps = closest_power_2(self.unet.stft.hop_length * t)
-        noise = torch.randn((b, c, timesteps), device=device)
-        default_kwargs = dict(channels_list=[magnitude])
-        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore # noqa
+    def forward(self, magnitude: Tensor, phase: Tensor, **kwargs) -> Tensor:  # type: ignore # noqa
+        magnitude = rearrange(magnitude, "b c f t -> b (c f) t")
+        phase = rearrange(phase, "b c f t -> b (c f) t")
+        # Get diffusion phase loss while conditioning on magnitude (/pi [-1,1] range)
+        return self.diffusion(phase / pi, channels_list=[magnitude], **kwargs)
+
+    def sample(self, magnitude: Tensor, **kwargs):  # type: ignore
+        b, c, f, t, device = *magnitude.shape, magnitude.device
+        magnitude_flat = rearrange(magnitude, "b c f t -> b (c f) t")
+        noise = torch.randn((b, c * f, t), device=device)
+        default_kwargs = dict(channels_list=[magnitude_flat])
+        phase_flat = super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore # noqa
+        phase = rearrange(phase_flat, "b (c f) t -> b c f t", c=c)
+        wave = self.stft.decode(magnitude, phase * pi)
+        return wave


 class DiffusionUpphaser1d(DiffusionUpsampler1d):
@@ -371,13 +382,13 @@ def __init__(self, in_channels: int, **kwargs):
             in_channels=in_channels,
             stft_num_fft=1023,
             stft_hop_length=256,
-            channels=64,
+            channels=512,
             patch_blocks=1,
             patch_factor=1,
-            multipliers=[48, 32, 16, 8, 8, 8, 8],
-            factors=[2, 2, 2, 1, 1, 1],
-            num_blocks=[1, 1, 1, 1, 1, 1],
-            attentions=[0, 0, 0, 1, 1, 1],
+            multipliers=[3, 2, 1, 1, 1, 1, 1, 1],
+            factors=[1, 2, 2, 2, 2, 2, 2],
+            num_blocks=[1, 1, 1, 1, 1, 1, 1],
+            attentions=[0, 0, 0, 0, 1, 1, 1],
             attention_heads=8,
             attention_features=64,
             attention_multiplier=2,
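
A note on the updated interface, grounded in the hunks above: DiffusionVocoder1d now owns its own STFT module, trains by diffusing the flattened STFT phase (scaled by 1/pi into [-1, 1]) conditioned on the flattened magnitude, and reconstructs the waveform at sampling time via stft.decode(magnitude, phase * pi). The shape bookkeeping behind the rearrange calls can be checked in isolation with plain torch and einops; this is a minimal sketch in which the batch size and frame count are arbitrary placeholders, while num_fft mirrors the stft_num_fft=1023 default in this diff.

import torch
from einops import rearrange

b, c, t = 2, 1, 128        # placeholder batch size, audio channels, STFT frames
num_fft = 1023             # mirrors stft_num_fft in the diff
f = num_fft // 2 + 1       # frequency_channels = 512

# Magnitude as produced by STFT.encode: shape [b, c, f, t]
magnitude = torch.rand(b, c, f, t)

# forward(): magnitude and phase are flattened to [b, c * f, t] so the phase
# can be diffused as a 1d signal conditioned on the magnitude via channels_list.
magnitude_flat = rearrange(magnitude, "b c f t -> b (c f) t")
assert magnitude_flat.shape == (b, c * f, t)

# sample(): noise is drawn directly in the flattened phase shape ...
noise = torch.randn(b, c * f, t)

# ... and the sampled flat phase is unflattened back to [b, c, f, t] before
# STFT.decode(magnitude, phase * pi) rebuilds the waveform.
phase = rearrange(noise, "b (c f) t -> b c f t", c=c)
assert phase.shape == magnitude.shape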

audio_diffusion_pytorch/modules.py

Lines changed: 1 addition & 1 deletion
@@ -1096,7 +1096,7 @@ def get_channels(
         assert exists(channels), message
         # Check channels
         num_channels = self.context_channels[layer]
-        message = f"Expected context with {channels} channels at index {channels_id}"
+        message = f"Expected context with {num_channels} channels at idx {channels_id}"
         assert channels.shape[1] == num_channels, message
         # STFT channels if requested
         channels = self.stft.encode1d(channels) if self.use_stft_context else channels  # type: ignore # noqa
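
The modules.py change only corrects the assertion message in get_channels: it now reports the expected channel count (num_channels) instead of interpolating the context tensor itself. Below is a toy reproduction of the check, not the repository's code; context_channels = [512] is a hypothetical configuration chosen to match the c * f conditioning channels the vocoder now passes.

import torch

context_channels = [512]  # hypothetical: one conditioning context expected at layer 0

def check_context(channels_list, layer=0, channels_id=0):
    channels = channels_list[channels_id]
    num_channels = context_channels[layer]
    # Corrected message: reports the expected count, not the tensor's repr
    message = f"Expected context with {num_channels} channels at idx {channels_id}"
    assert channels.shape[1] == num_channels, message

check_context([torch.randn(2, 512, 64)])   # passes
# check_context([torch.randn(2, 64, 64)])  # would raise AssertionError with the message above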

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.81",
+    version="0.0.82",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
