Skip to content

Commit d58ac21

Browse files
authored
Improve MelSpectrogram librosa compatibility test (#1267)
1 parent 9222e43 commit d58ac21

File tree

1 file changed

+80
-64
lines changed

1 file changed

+80
-64
lines changed

test/torchaudio_unittest/librosa_compatibility_test.py

Lines changed: 80 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import torchaudio
88
import torchaudio.functional as F
99
from torchaudio._internal.module_utils import is_module_available
10-
from parameterized import parameterized
10+
from parameterized import parameterized, param
1111
import itertools
1212

1313
LIBROSA_AVAILABLE = is_module_available('librosa')
@@ -165,13 +165,17 @@ def _load_audio_asset(*asset_paths, **kwargs):
165165
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
166166
class TestTransforms(common_utils.TorchaudioTestCase):
167167
"""Test suite for functions in `transforms` module."""
168-
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
169-
common_utils.set_audio_backend('default')
170-
path = common_utils.get_asset_path('sinewave.wav')
171-
sound, sample_rate = common_utils.load_wav(path)
172-
sound_librosa = sound.cpu().numpy().squeeze() # (64000)
173168

174-
# test core spectrogram
169+
@parameterized.expand([
170+
param(n_fft=400, hop_length=200, power=2.0),
171+
param(n_fft=600, hop_length=100, power=2.0),
172+
param(n_fft=400, hop_length=200, power=3.0),
173+
param(n_fft=200, hop_length=50, power=2.0),
174+
])
175+
def test_spectrogram(self, n_fft, hop_length, power):
176+
sample_rate = 16000
177+
sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate)
178+
sound_librosa = sound.cpu().numpy().squeeze()
175179
spect_transform = torchaudio.transforms.Spectrogram(
176180
n_fft=n_fft, hop_length=hop_length, power=power)
177181
out_librosa, _ = librosa.core.spectrum._spectrogram(
@@ -180,19 +184,58 @@ def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sampl
180184
out_torch = spect_transform(sound).squeeze().cpu()
181185
self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)
182186

183-
# test mel spectrogram
187+
@parameterized.expand([
188+
param(norm=norm, **p.kwargs)
189+
for p in [
190+
param(n_fft=400, hop_length=200, n_mels=128),
191+
param(n_fft=600, hop_length=100, n_mels=128),
192+
param(n_fft=200, hop_length=50, n_mels=128),
193+
]
194+
for norm in [None, 'slaney']
195+
])
196+
def test_mel_spectrogram(self, n_fft, hop_length, n_mels, norm):
197+
sample_rate = 16000
198+
sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate)
199+
sound_librosa = sound.cpu().numpy().squeeze()
184200
melspect_transform = torchaudio.transforms.MelSpectrogram(
185201
sample_rate=sample_rate, window_fn=torch.hann_window,
186-
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
202+
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm)
187203
librosa_mel = librosa.feature.melspectrogram(
188204
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
189-
hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
205+
hop_length=hop_length, n_mels=n_mels, htk=True, norm=norm)
190206
librosa_mel_tensor = torch.from_numpy(librosa_mel)
191207
torch_mel = melspect_transform(sound).squeeze().cpu()
192208
self.assertEqual(
193209
torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3, rtol=1e-5)
194210

195-
# test s2db
211+
@parameterized.expand([
212+
param(norm=norm, **p.kwargs)
213+
for p in [
214+
param(n_fft=400, hop_length=200, power=2.0, n_mels=128),
215+
param(n_fft=600, hop_length=100, power=2.0, n_mels=128),
216+
param(n_fft=400, hop_length=200, power=3.0, n_mels=128),
217+
# NOTE: Test passes offline, but fails on TravisCI (and CircleCI), see #372.
218+
param(n_fft=200, hop_length=50, power=2.0, n_mels=128, skip_ci=True),
219+
]
220+
for norm in [None, 'slaney']
221+
])
222+
def test_s2db(self, n_fft, hop_length, power, n_mels, norm, skip_ci=False):
223+
if skip_ci and 'CI' in os.environ:
224+
self.skipTest('Test is known to fail on CI')
225+
sample_rate = 16000
226+
sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate)
227+
sound_librosa = sound.cpu().numpy().squeeze()
228+
spect_transform = torchaudio.transforms.Spectrogram(
229+
n_fft=n_fft, hop_length=hop_length, power=power)
230+
out_librosa, _ = librosa.core.spectrum._spectrogram(
231+
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=power)
232+
melspect_transform = torchaudio.transforms.MelSpectrogram(
233+
sample_rate=sample_rate, window_fn=torch.hann_window,
234+
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft, norm=norm)
235+
librosa_mel = librosa.feature.melspectrogram(
236+
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
237+
hop_length=hop_length, n_mels=n_mels, htk=True, norm=norm)
238+
196239
power_to_db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
197240
power_to_db_torch = power_to_db_transform(spect_transform(sound)).squeeze().cpu()
198241
power_to_db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
@@ -209,10 +252,19 @@ def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sampl
209252
self.assertEqual(
210253
power_to_db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3, rtol=1e-5)
211254

212-
# test MFCC
213-
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
214-
mfcc_transform = torchaudio.transforms.MFCC(
215-
sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)
255+
@parameterized.expand([
256+
param(n_fft=400, hop_length=200, n_mels=128, n_mfcc=40),
257+
param(n_fft=600, hop_length=100, n_mels=128, n_mfcc=20),
258+
param(n_fft=200, hop_length=50, n_mels=128, n_mfcc=50),
259+
])
260+
def test_mfcc(self, n_fft, hop_length, n_mels, n_mfcc):
261+
sample_rate = 16000
262+
sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate)
263+
sound_librosa = sound.cpu().numpy().squeeze()
264+
librosa_mel = librosa.feature.melspectrogram(
265+
y=sound_librosa, sr=sample_rate, n_fft=n_fft,
266+
hop_length=hop_length, n_mels=n_mels, htk=True, norm=None)
267+
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
216268

217269
# librosa.feature.mfcc doesn't pass kwargs properly since some of the
218270
# kwargs for melspectrogram and mfcc are the same. We just follow the
@@ -226,14 +278,24 @@ def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sampl
226278

227279
librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
228280
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
281+
282+
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
283+
mfcc_transform = torchaudio.transforms.MFCC(
284+
sample_rate=sample_rate, n_mfcc=n_mfcc, norm='ortho', melkwargs=melkwargs)
229285
torch_mfcc = mfcc_transform(sound).squeeze().cpu()
230286

231287
self.assertEqual(
232288
torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3, rtol=1e-5)
233289

234-
self.assert_compatibilities_spectral_centroid(sample_rate, n_fft, hop_length, sound, sound_librosa)
235-
236-
def assert_compatibilities_spectral_centroid(self, sample_rate, n_fft, hop_length, sound, sound_librosa):
290+
@parameterized.expand([
291+
param(n_fft=400, hop_length=200),
292+
param(n_fft=600, hop_length=100),
293+
param(n_fft=200, hop_length=50),
294+
])
295+
def test_spectral_centroid(self, n_fft, hop_length):
296+
sample_rate = 16000
297+
sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate)
298+
sound_librosa = sound.cpu().numpy().squeeze()
237299
spect_centroid = torchaudio.transforms.SpectralCentroid(
238300
sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length)
239301
out_torch = spect_centroid(sound).squeeze().cpu()
@@ -244,52 +306,6 @@ def assert_compatibilities_spectral_centroid(self, sample_rate, n_fft, hop_lengt
244306

245307
self.assertEqual(out_torch.type(out_librosa.dtype), out_librosa, atol=1e-5, rtol=1e-5)
246308

247-
def test_basics1(self):
248-
kwargs = {
249-
'n_fft': 400,
250-
'hop_length': 200,
251-
'power': 2.0,
252-
'n_mels': 128,
253-
'n_mfcc': 40,
254-
'sample_rate': 16000
255-
}
256-
self.assert_compatibilities(**kwargs)
257-
258-
def test_basics2(self):
259-
kwargs = {
260-
'n_fft': 600,
261-
'hop_length': 100,
262-
'power': 2.0,
263-
'n_mels': 128,
264-
'n_mfcc': 20,
265-
'sample_rate': 16000
266-
}
267-
self.assert_compatibilities(**kwargs)
268-
269-
# NOTE: Test passes offline, but fails on TravisCI (and CircleCI), see #372.
270-
@unittest.skipIf('CI' in os.environ, 'Test is known to fail on CI')
271-
def test_basics3(self):
272-
kwargs = {
273-
'n_fft': 200,
274-
'hop_length': 50,
275-
'power': 2.0,
276-
'n_mels': 128,
277-
'n_mfcc': 50,
278-
'sample_rate': 24000
279-
}
280-
self.assert_compatibilities(**kwargs)
281-
282-
def test_basics4(self):
283-
kwargs = {
284-
'n_fft': 400,
285-
'hop_length': 200,
286-
'power': 3.0,
287-
'n_mels': 128,
288-
'n_mfcc': 40,
289-
'sample_rate': 16000
290-
}
291-
self.assert_compatibilities(**kwargs)
292-
293309
def test_MelScale(self):
294310
"""MelScale transform is comparable to that of librosa"""
295311
n_fft = 2048

0 commit comments

Comments
 (0)