import torch
from torch import nn
from transformers import Lfm2Config, Lfm2Model


class FusedEmbedding(nn.Module):
    """Turn codes from multiple codebooks into a single embedding per frame."""

    def __init__(
        self,
        dim: int,
        codebooks: int = 8,
        vocab_size: int = 2048,
    ):
        super().__init__()
        # One shared table; each codebook owns a `vocab_size`-sized slice of it.
        self.emb = nn.Embedding(codebooks * vocab_size, dim)

        self.codebooks = codebooks
        self.vocab_size = vocab_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shift each codebook's tokens into its own slice of the shared table.
        offsets = torch.arange(self.codebooks, device=x.device) * self.vocab_size  # TODO: buffer?
        offset_x = offsets[:, None] + x
        # Embed, then average over the codebook dimension: (B, C, L, D) -> (B, L, D)
        return self.emb(offset_x).mean(1)  # B L D
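
# Usage sketch (illustrative only; the (B, codebooks, L) input layout is
# inferred from the broadcasting in `forward`, not stated in the source):
#   emb = FusedEmbedding(dim=512)
#   codes = torch.randint(0, 2048, (2, 8, 100))  # (B, codebooks, L)
#   emb(codes).shape  # torch.Size([2, 100, 512])
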

class ISTFT(nn.Module):
    """
    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
    See issue: https://github.com/pytorch/pytorch/issues/62323
    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
    The NOLA constraint is met as we trim padded samples anyway.

    Adapted from Vocos: https://github.com/gemelo-ai/vocos/blob/c859e3b7b534f3776a357983029d34170ddd6fc3/vocos/spectral_ops.py#L7

    Args:
        n_fft (int): Size of Fourier transform.
        hop_length (int): The distance between neighboring sliding window frames.
        win_length (int): The size of window frame and STFT filter.
        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
    """

    def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        window = torch.hann_window(win_length)
        self.register_buffer("window", window)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """
        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.

        Args:
            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
                N is the number of frequency bins, and T is the number of time frames.

        Returns:
            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
        """
        if self.padding == "center":
            # Fall back to the native PyTorch implementation
            return torch.istft(
                spec,
                self.n_fft,
                self.hop_length,
                self.win_length,
                self.window,  # type: ignore[arg-type]
                center=True,
            )
        elif self.padding == "same":
            pad = (self.win_length - self.hop_length) // 2
        else:
            raise ValueError("Padding must be 'center' or 'same'.")

        assert spec.dim() == 3, "Expected a 3D tensor as input"
        _B, _N, T = spec.shape

        # Inverse FFT of each frame, then apply the synthesis window
        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        ifft = ifft * self.window[None, :, None]  # type: ignore[index]

        # Overlap and add the windowed frames, then trim the "same" padding
        output_size = (T - 1) * self.hop_length + self.win_length
        y = torch.nn.functional.fold(
            ifft,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        )[:, 0, 0, pad:-pad]

        # Window envelope: the overlap-added squared window, used for normalization
        window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)  # type: ignore[operator]
        window_envelope = torch.nn.functional.fold(
            window_sq,
            output_size=(1, output_size),
            kernel_size=(1, self.win_length),
            stride=(1, self.hop_length),
        ).squeeze()[pad:-pad]

        # Normalize (NOLA check: the envelope must be nonzero everywhere we keep)
        assert (window_envelope > 1e-11).all()
        y = y / window_envelope

        return y
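
# Usage sketch (illustrative only): with "same" padding the output length works
# out to exactly T * hop_length, since
# (T - 1) * hop + win - 2 * ((win - hop) // 2) == T * hop when (win - hop) is even.
#   istft = ISTFT(n_fft=1280, hop_length=320, win_length=1280)
#   spec = torch.randn(2, 641, 50, dtype=torch.complex64)  # (B, n_fft // 2 + 1, T)
#   istft(spec).shape  # torch.Size([2, 16000])
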

class LFM2AudioDetokenizer(nn.Module):
    def __init__(self, backbone_config: Lfm2Config):
        super().__init__()
        self.emb = FusedEmbedding(512)  # dim must match the backbone hidden size
        self.lfm = Lfm2Model(backbone_config)
        self.lin = nn.Linear(512, 1282)  # 2 * (n_fft // 2 + 1): half are log-magnitude, half are angle

        self.istft = ISTFT(1280, 320, 1280, padding="same")
        self.sliding_window_size = getattr(backbone_config, "sliding_window", 30)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.emb(x)
        # Upsample 6x along time so each code frame yields several STFT frames
        upsample_size = 6 * x.shape[1]
        x = nn.functional.interpolate(x.mT, upsample_size, mode="nearest-exact").mT

        # Sliding-window causal attention mask: position i attends to keys in (i - window, i]
        idx = torch.arange(x.shape[1], device=x.device)
        d_idx = idx - idx[:, None]
        mask = torch.logical_and(d_idx <= 0, d_idx > -self.sliding_window_size)[None, None, ...]

        x = self.lfm(inputs_embeds=x, attention_mask=mask, use_cache=False).last_hidden_state
        x = self.lin(x)

        # Split into log-magnitude and phase, rebuild the complex spectrogram
        log_abs, angle = torch.chunk(x.mT.contiguous(), 2, 1)
        y = torch.polar(log_abs.exp(), angle)

        return self.istft(y)
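

# Minimal smoke test: a sketch only. The `hidden_size=512` override is an
# assumption so the backbone width matches the FusedEmbedding/Linear layers
# above; all other Lfm2Config fields are left at their defaults and may differ
# from the released checkpoint's configuration.
if __name__ == "__main__":
    cfg = Lfm2Config(hidden_size=512)
    detok = LFM2AudioDetokenizer(cfg)
    codes = torch.randint(0, 2048, (1, 8, 25))  # (B, codebooks, frames)
    wav = detok(codes)
    # 25 code frames -> 150 STFT frames -> 150 * 320 = 48000 samples
    print(wav.shape)  # expected: torch.Size([1, 48000])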