@@ -1,6 +1,8 @@
+from math import pi
 from typing import Any, Optional, Sequence, Tuple, Union
 
 import torch
+from einops import rearrange
 from torch import Tensor, nn
 
 from .diffusion import (
@@ -15,6 +17,7 @@
     VSampler,
 )
 from .modules import (
+    STFT,
     Bottleneck,
     MultiEncoder1d,
     SinusoidalEmbedding,
@@ -223,6 +226,62 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor:
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
 
 
+class DiffusionVocoder1d(Model1d):
+    def __init__(
+        self,
+        in_channels: int,
+        vocoder_num_fft: int,
+        **kwargs,
+    ):
+        self.frequency_channels = vocoder_num_fft // 2 + 1
+        spectrogram_channels = in_channels * self.frequency_channels
+
+        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
+        default_kwargs = dict(
+            in_channels=spectrogram_channels, context_channels=[spectrogram_channels]
+        )
+
+        super().__init__(**{**default_kwargs, **kwargs})  # type: ignore
+        self.stft = STFT(num_fft=vocoder_num_fft, **vocoder_kwargs)
+
+    def forward(self, x: Tensor, **kwargs) -> Tensor:
+        # Magnitude and phase of the reference waveform
+        magnitude, phase = self.stft.encode(x)
+        magnitude = rearrange(magnitude, "b c f t -> b (c f) t")
+        phase = rearrange(phase, "b c f t -> b (c f) t")
+        # Diffusion loss on phase, conditioned on magnitude; phase / pi is in [-1, 1]
+        return self.diffusion(phase / pi, channels_list=[magnitude], **kwargs)
+
+    def sample(self, spectrogram: Tensor, **kwargs):  # type: ignore
+        b, c, f, t, device = *spectrogram.shape, spectrogram.device
+        magnitude = rearrange(spectrogram, "b c f t -> b (c f) t")
+        noise = torch.randn((b, c * f, t), device=device)
+        default_kwargs = dict(channels_list=[magnitude])
+        phase = super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore # noqa
+        phase = rearrange(phase, "b (c f) t -> b c f t", c=c)
+        wave = self.stft.decode(spectrogram, phase * pi)
+        return wave
+
+
+class DiffusionUpphaser1d(DiffusionUpsampler1d):
+    def __init__(self, **kwargs):
+        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
+        super().__init__(**kwargs)
+        self.stft = STFT(**vocoder_kwargs)
+
+    def random_rephase(self, x: Tensor) -> Tensor:
+        magnitude, phase = self.stft.encode(x)
+        phase_random = (torch.rand_like(phase) - 0.5) * 2 * pi
+        wave = self.stft.decode(magnitude, phase_random)
+        return wave
+
+    def forward(self, x: Tensor, **kwargs) -> Tensor:
+        rephased = self.random_rephase(x)
+        resampled, factors = self.random_reupsample(rephased)
+        features = self.to_features(factors) if self.use_conditioning else None
+        return self.diffusion(x, channels_list=[resampled], features=features, **kwargs)
+
+
 """
 Audio Diffusion Classes (specific for 1d audio data)
 """
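
As a reading aid for the hunk above (not part of the commit): `DiffusionVocoder1d` folds the frequency axis into the channel axis so the 2D spectrogram can flow through the 1D UNet, and rescales phase by `1/pi` so it lands in the `[-1, 1]` range the diffusion process works on. A minimal, self-contained sketch of that shape bookkeeping, with hypothetical sizes (512 bins = 1023 // 2 + 1):

from math import pi

import torch
from einops import rearrange

b, c, f, t = 2, 1, 512, 65                       # batch, channels, freq bins, frames
phase = torch.rand(b, c, f, t) * 2 * pi - pi     # phase in [-pi, pi]

# Fold frequency into channels, as in forward(): [2, 1, 512, 65] -> [2, 512, 65]
flat = rearrange(phase, "b c f t -> b (c f) t")
assert flat.shape == (b, c * f, t)

# Rescale to [-1, 1] for diffusion, then undo both steps as sample() does
target = flat / pi
restored = rearrange(target * pi, "b (c f) t -> b c f t", c=c)
assert torch.allclose(restored, phase)
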
@@ -315,3 +374,49 @@ def sample(self, *args, **kwargs):
             embedding_scale=5.0,
         )
         return super().sample(*args, **{**default_kwargs, **kwargs})
+
+
+class AudioDiffusionVocoder(DiffusionVocoder1d):
+    def __init__(self, in_channels: int, **kwargs):
+        default_kwargs = dict(
+            in_channels=in_channels,
+            vocoder_num_fft=1023,
+            channels=32,
+            patch_blocks=1,
+            patch_factor=1,
+            multipliers=[64, 32, 16, 8, 4, 2, 1],
+            factors=[1, 1, 1, 1, 1, 1],
+            num_blocks=[1, 1, 1, 1, 1, 1],
+            attentions=[0, 0, 0, 1, 1, 1],
+            attention_heads=8,
+            attention_features=64,
+            attention_multiplier=2,
+            attention_use_rel_pos=False,
+            resnet_groups=8,
+            kernel_multiplier_downsample=2,
+            use_nearest_upsample=False,
+            use_skip_scale=True,
+            use_context_time=True,
+            use_magnitude_channels=False,
+            diffusion_type="v",
+            diffusion_sigma_distribution=UniformDistribution(),
+        )
+        super().__init__(**{**default_kwargs, **kwargs})  # type: ignore
+
+    def sample(self, *args, **kwargs):
+        default_kwargs = dict(**get_default_sampling_kwargs())
+        return super().sample(*args, **{**default_kwargs, **kwargs})
+
+
+class AudioDiffusionUpphaser(DiffusionUpphaser1d):
+    def __init__(self, in_channels: int, **kwargs):
+        default_kwargs = dict(
+            **get_default_model_kwargs(),
+            in_channels=in_channels,
+            context_channels=[in_channels],
+            factor=1,
+        )
+        super().__init__(**{**default_kwargs, **kwargs})  # type: ignore
+
+    def sample(self, *args, **kwargs):
+        return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
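
End to end, the new wrappers are meant to be driven like the other `Audio*` models in this file. A hedged usage sketch, not part of the commit: it assumes the classes are exported from the `audio_diffusion_pytorch` package and that `num_steps` is forwarded to the sampler, as with the other wrappers in this repo:

import torch
from audio_diffusion_pytorch import AudioDiffusionVocoder

# The defaults above fill in the whole UNet config, so only in_channels is needed
vocoder = AudioDiffusionVocoder(in_channels=1)

# Training: diffuse the phase of real waveforms, conditioned on their magnitude
x = torch.randn(1, 1, 2 ** 15)                    # [batch, channels, samples]
loss = vocoder(x)
loss.backward()

# Inference: sample a phase for a given magnitude spectrogram, invert the STFT
spectrogram = torch.randn(1, 1, 512, 64).abs()    # [batch, channels, 1023 // 2 + 1, frames]
wave = vocoder.sample(spectrogram, num_steps=50)  # [1, 1, samples]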