Skip to content

Commit 329d739

Browse files
authored
Refactor mel module (#132)
* Refactor wave-to-mel * Add docstring on mel * Refactor mel module import and variable names
1 parent a02ef40 commit 329d739

File tree

1 file changed

+38
-70
lines changed

1 file changed

+38
-70
lines changed

train/mel_processing.py

Lines changed: 38 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,8 @@
1-
import math
2-
import os
3-
import random
41
import torch
5-
from torch import nn
6-
import torch.nn.functional as F
72
import torch.utils.data
8-
import numpy as np
9-
import librosa
10-
import librosa.util as librosa_util
11-
from librosa.util import normalize, pad_center, tiny
12-
from scipy.signal import get_window
13-
from scipy.io.wavfile import read
143
from librosa.filters import mel as librosa_mel_fn
154

5+
166
MAX_WAV_VALUE = 32768.0
177

188

@@ -35,25 +25,38 @@ def dynamic_range_decompression_torch(x, C=1):
3525

3626

3727
def spectral_normalize_torch(magnitudes):
38-
output = dynamic_range_compression_torch(magnitudes)
39-
return output
28+
return dynamic_range_compression_torch(magnitudes)
4029

4130

4231
def spectral_de_normalize_torch(magnitudes):
43-
output = dynamic_range_decompression_torch(magnitudes)
44-
return output
32+
return dynamic_range_decompression_torch(magnitudes)
4533

4634

35+
# Reusable banks
4736
mel_basis = {}
4837
hann_window = {}
4938

5039

5140
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
41+
"""Convert waveform into Linear-frequency Linear-amplitude spectrogram.
42+
43+
Args:
44+
y :: (B, T) - Audio waveforms
45+
n_fft
46+
sampling_rate
47+
hop_size
48+
win_size
49+
center
50+
Returns:
51+
:: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
52+
"""
53+
# Validation
5254
if torch.min(y) < -1.0:
5355
print("min value is ", torch.min(y))
5456
if torch.max(y) > 1.0:
5557
print("max value is ", torch.max(y))
5658

59+
# Window - Cache if needed
5760
global hann_window
5861
dtype_device = str(y.dtype) + "_" + str(y.device)
5962
wnsize_dtype_device = str(win_size) + "_" + dtype_device
@@ -62,13 +65,15 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
6265
dtype=y.dtype, device=y.device
6366
)
6467

68+
# Padding
6569
y = torch.nn.functional.pad(
6670
y.unsqueeze(1),
6771
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
6872
mode="reflect",
6973
)
7074
y = y.squeeze(1)
7175

76+
# Complex Spectrogram :: (B, T) -> (B, Freq, Frame, RealComplex=2)
7277
spec = torch.stft(
7378
y,
7479
n_fft,
@@ -82,11 +87,13 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
8287
return_complex=False,
8388
)
8489

90+
# Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame)
8591
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
8692
return spec
8793

8894

8995
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
96+
# MelBasis - Cache if needed
9097
global mel_basis
9198
dtype_device = str(spec.dtype) + "_" + str(spec.device)
9299
fmax_dtype_device = str(fmax) + "_" + dtype_device
@@ -95,66 +102,27 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
95102
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
96103
dtype=spec.dtype, device=spec.device
97104
)
98-
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
99-
spec = spectral_normalize_torch(spec)
100-
return spec
105+
106+
# Mel-frequency Log-amplitude spectrogram :: (B, Freq=num_mels, Frame)
107+
melspec = torch.matmul(mel_basis[fmax_dtype_device], spec)
108+
melspec = spectral_normalize_torch(melspec)
109+
return melspec
101110

102111

103112
def mel_spectrogram_torch(
104113
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
105114
):
106-
if torch.min(y) < -1.0:
107-
print("min value is ", torch.min(y))
108-
if torch.max(y) > 1.0:
109-
print("max value is ", torch.max(y))
110-
111-
global mel_basis, hann_window
112-
dtype_device = str(y.dtype) + "_" + str(y.device)
113-
fmax_dtype_device = str(fmax) + "_" + dtype_device
114-
wnsize_dtype_device = str(win_size) + "_" + dtype_device
115-
if fmax_dtype_device not in mel_basis:
116-
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
117-
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
118-
dtype=y.dtype, device=y.device
119-
)
120-
if wnsize_dtype_device not in hann_window:
121-
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
122-
dtype=y.dtype, device=y.device
123-
)
115+
"""Convert waveform into Mel-frequency Log-amplitude spectrogram.
124116
125-
y = torch.nn.functional.pad(
126-
y.unsqueeze(1),
127-
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
128-
mode="reflect",
129-
)
130-
y = y.squeeze(1)
131-
132-
# spec = torch.stft(
133-
# y,
134-
# n_fft,
135-
# hop_length=hop_size,
136-
# win_length=win_size,
137-
# window=hann_window[wnsize_dtype_device],
138-
# center=center,
139-
# pad_mode="reflect",
140-
# normalized=False,
141-
# onesided=True,
142-
# )
143-
spec = torch.stft(
144-
y,
145-
n_fft,
146-
hop_length=hop_size,
147-
win_length=win_size,
148-
window=hann_window[wnsize_dtype_device],
149-
center=center,
150-
pad_mode="reflect",
151-
normalized=False,
152-
onesided=True,
153-
return_complex=False,
154-
)
155-
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
117+
Args:
118+
y :: (B, T) - Waveforms
119+
Returns:
120+
melspec :: (B, Freq, Frame) - Mel-frequency Log-amplitude spectrogram
121+
"""
122+
# Linear-frequency Linear-amplitude spectrogram :: (B, T) -> (B, Freq, Frame)
123+
spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)
156124

157-
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
158-
spec = spectral_normalize_torch(spec)
125+
# Mel-frequency Log-amplitude spectrogram :: (B, Freq, Frame) -> (B, Freq=num_mels, Frame)
126+
melspec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax)
159127

160-
return spec
128+
return melspec

0 commit comments

Comments
 (0)