Skip to content

Commit b17da7a

Browse files
authored
Make TestCases backend-aware (#719)
* Make tests backend aware by introducing TorchaudioTestCase and reset backend for each TestCase. * Set backends for the test cases that require specific backend.
1 parent 03da871 commit b17da7a

21 files changed

+160
-217
lines changed

test/README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ The following test modules are defined for corresponding `torchaudio` module/fun
4444

4545
## Adding test
4646

47+
The following is the current practice of torchaudio test suite.
48+
49+
1. Unless the tests are related to I/O, use synthetic data. [`common_utils`](./common_utils.py) has some data generator functions.
50+
1. When you add a new test case, use `common_utils.TorchaudioTestCase` as base class unless you are writing tests that are common to CPU / CUDA.
51+
- Set class memeber `dtype`, `device` and `backend` for the desired behavior.
52+
- If you do not set `backend` value in your test suite, then I/O functions will be unassigned and attempt to load/save file will fail.
53+
- For `backend` value, in addition to available backends, you can also provide the value "default" and backend will be picked automatically based on availability.
54+
1. If you are writing tests that should pass on diffrent dtype/devices, write a common class inheriting `common_utils.TestBaseMixin`, then inherit `common_utils.PytorchTestCase` and define class attributes (`dtype` / `device` / `backend`) there. See [Torchscript consistency test implementation](./torchscript_consistency_impl.py) and test definitions for [CPU](./torchscript_consistency_cpu_test.py) and [CUDA](./torchscript_consistency_cuda_test.py) devices.
55+
1. For numerically comparing Tensors, use `assertEqual` method from `common_utils.PytorchTestCase` class. This method has a better support for a wide variety of Tensor types.
56+
4757
When you add a new feature(functional/transform), consider the following
4858

4959
1. When you add a new feature, please make it Torchscript-able and batch-consistent unless it degrades the performance. Please add the tests to see if the new feature meet these requirements.

test/common_utils.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import os
22
import tempfile
33
import unittest
4-
from typing import Iterable, Union
5-
from contextlib import contextmanager
4+
from typing import Union
65
from shutil import copytree
76

87
import torch
9-
from torch.testing._internal.common_utils import TestCase
8+
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
109
import torchaudio
1110

1211
_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
@@ -55,24 +54,14 @@ def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32):
5554
return torch.tensor(arr).float().view(size) / m
5655

5756

58-
@contextmanager
59-
def AudioBackendScope(new_backend):
60-
previous_backend = torchaudio.get_audio_backend()
61-
try:
62-
torchaudio.set_audio_backend(new_backend)
63-
yield
64-
finally:
65-
torchaudio.set_audio_backend(previous_backend)
66-
67-
6857
def filter_backends_with_mp3(backends):
6958
# Filter out backends that do not support mp3
7059
test_filepath = get_asset_path('steam-train-whistle-daniel_simon.mp3')
7160

7261
def supports_mp3(backend):
62+
torchaudio.set_audio_backend(backend)
7363
try:
74-
with AudioBackendScope(backend):
75-
torchaudio.load(test_filepath)
64+
torchaudio.load(test_filepath)
7665
return True
7766
except (RuntimeError, ImportError):
7867
return False
@@ -83,21 +72,38 @@ def supports_mp3(backend):
8372
BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS)
8473

8574

75+
def set_audio_backend(backend):
76+
"""Allow additional backend value, 'default'"""
77+
if backend == 'default':
78+
if 'sox' in BACKENDS:
79+
be = 'sox'
80+
elif 'soundfile' in BACKENDS:
81+
be = 'soundfile'
82+
else:
83+
raise unittest.SkipTest('No default backend available')
84+
else:
85+
be = backend
86+
87+
torchaudio.set_audio_backend(be)
88+
89+
8690
class TestBaseMixin:
91+
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
8792
dtype = None
8893
device = None
94+
backend = None
8995

96+
def setUp(self):
97+
super().setUp()
98+
set_audio_backend(self.backend)
9099

91-
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
92100

101+
class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
102+
pass
93103

94-
def common_test_class_parameters(
95-
dtypes: Iterable[str] = ("float32", "float64"),
96-
devices: Iterable[str] = ("cpu", "cuda"),
97-
):
98-
for device in devices:
99-
for dtype in dtypes:
100-
yield {"device": torch.device(device), "dtype": getattr(torch, dtype)}
104+
105+
skipIfNoSoxBackend = unittest.skipIf('sox' not in BACKENDS, 'Sox backend not available')
106+
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
101107

102108

103109
def get_whitenoise(

test/functional_cpu_test.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@
1010
from .functional_impl import Lfilter
1111

1212

13-
class TestLFilterFloat32(Lfilter, common_utils.TestCase):
13+
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
1414
dtype = torch.float32
1515
device = torch.device('cpu')
1616

1717

18-
class TestLFilterFloat64(Lfilter, common_utils.TestCase):
18+
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
1919
dtype = torch.float64
2020
device = torch.device('cpu')
2121

2222

23-
class TestComputeDeltas(unittest.TestCase):
23+
class TestComputeDeltas(common_utils.TorchaudioTestCase):
2424
"""Test suite for correctness of compute_deltas"""
2525
def test_one_channel(self):
2626
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
@@ -57,7 +57,7 @@ def _test_istft_is_inverse_of_stft(kwargs):
5757
_compare_estimate(sound, estimate)
5858

5959

60-
class TestIstft(unittest.TestCase):
60+
class TestIstft(common_utils.TorchaudioTestCase):
6161
"""Test suite for correctness of istft with various input"""
6262
number_of_trials = 100
6363

@@ -273,7 +273,9 @@ def test_linearity_of_istft4(self):
273273
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)
274274

275275

276-
class TestDetectPitchFrequency(unittest.TestCase):
276+
class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
277+
backend = 'default'
278+
277279
def test_pitch(self):
278280
test_filepath_100 = common_utils.get_asset_path("100Hz_44100Hz_16bit_05sec.wav")
279281
test_filepath_440 = common_utils.get_asset_path("440Hz_44100Hz_16bit_05sec.wav")
@@ -294,7 +296,7 @@ def test_pitch(self):
294296
self.assertFalse(s)
295297

296298

297-
class TestDB_to_amplitude(unittest.TestCase):
299+
class TestDB_to_amplitude(common_utils.TorchaudioTestCase):
298300
def test_DB_to_amplitude(self):
299301
# Make some noise
300302
x = torch.rand(1000)

test/functional_cuda_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66

77
@common_utils.skipIfNoCuda
8-
class TestLFilterFloat32(Lfilter, common_utils.TestCase):
8+
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
99
dtype = torch.float32
1010
device = torch.device('cuda')
1111

1212

1313
@common_utils.skipIfNoCuda
14-
class TestLFilterFloat64(Lfilter, common_utils.TestCase):
14+
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
1515
dtype = torch.float64
1616
device = torch.device('cuda')

test/kaldi_compatibility_cpu_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from .kaldi_compatibility_impl import Kaldi
55

66

7-
class TestKaldiFloat32(Kaldi, common_utils.TestCase):
7+
class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
88
dtype = torch.float32
99
device = torch.device('cpu')
1010

1111

12-
class TestKaldiFloat64(Kaldi, common_utils.TestCase):
12+
class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase):
1313
dtype = torch.float64
1414
device = torch.device('cpu')

test/kaldi_compatibility_cuda_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66

77
@common_utils.skipIfNoCuda
8-
class TestKaldiFloat32(Kaldi, common_utils.TestCase):
8+
class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
99
dtype = torch.float32
1010
device = torch.device('cuda')
1111

1212

1313
@common_utils.skipIfNoCuda
14-
class TestKaldiFloat64(Kaldi, common_utils.TestCase):
14+
class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase):
1515
dtype = torch.float64
1616
device = torch.device('cuda')

test/kaldi_compatibility_impl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def _load_params(path):
5555

5656

5757
class Kaldi(common_utils.TestBaseMixin):
58+
backend = 'sox'
59+
5860
def assert_equal(self, output, *, expected, rtol=None, atol=None):
5961
expected = expected.to(dtype=self.dtype, device=self.device)
6062
self.assertEqual(output, expected, rtol=rtol, atol=atol)

test/test_backend.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import torchaudio
44
from torchaudio._internal.module_utils import is_module_available
55

6+
from . import common_utils
67

7-
class BackendSwitch:
8+
9+
class BackendSwitchMixin:
810
"""Test set/get_audio_backend works"""
911
backend = None
1012
backend_module = None
@@ -21,20 +23,20 @@ def test_switch(self):
2123
assert torchaudio.info == self.backend_module.info
2224

2325

24-
class TestBackendSwitch_NoBackend(BackendSwitch, unittest.TestCase):
26+
class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTestCase):
2527
backend = None
2628
backend_module = torchaudio.backend.no_backend
2729

2830

2931
@unittest.skipIf(
3032
not is_module_available('torchaudio._torchaudio'),
3133
'torchaudio C++ extension not available')
32-
class TestBackendSwitch_SoX(BackendSwitch, unittest.TestCase):
34+
class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase):
3335
backend = 'sox'
3436
backend_module = torchaudio.backend.sox_backend
3537

3638

3739
@unittest.skipIf(not is_module_available('soundfile'), '"soundfile" not available')
38-
class TestBackendSwitch_soundfile(BackendSwitch, unittest.TestCase):
40+
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
3941
backend = 'soundfile'
4042
backend_module = torchaudio.backend.soundfile_backend

test/test_batch_consistency.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
import unittest
33

44
import torch
5-
from torch.testing._internal.common_utils import TestCase
65
import torchaudio
76
import torchaudio.functional as F
87

98
from . import common_utils
109

1110

12-
class TestFunctional(TestCase):
11+
class TestFunctional(common_utils.TorchaudioTestCase):
12+
backend = 'default'
1313
"""Test functions defined in `functional` module"""
1414
def assert_batch_consistency(
1515
self, functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
@@ -98,12 +98,15 @@ def test_sliding_window_cmn(self):
9898
self.assert_batch_consistencies(F.sliding_window_cmn, waveform, center=False, norm_vars=False)
9999

100100
def test_vad(self):
101+
common_utils.set_audio_backend('default')
101102
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
102103
waveform, sample_rate = torchaudio.load(filepath)
103104
self.assert_batch_consistencies(F.vad, waveform, sample_rate=sample_rate)
104105

105106

106-
class TestTransforms(TestCase):
107+
class TestTransforms(common_utils.TorchaudioTestCase):
108+
backend = 'default'
109+
107110
"""Test suite for classes defined in `transforms` module"""
108111
def test_batch_AmplitudeToDB(self):
109112
spec = torch.rand((6, 201))

test/test_compliance_kaldi.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
import math
21
import os
2+
import math
3+
import unittest
4+
35
import torch
46
import torchaudio
57
import torchaudio.compliance.kaldi as kaldi
6-
import unittest
78

89
from . import common_utils
910
from .compliance import utils as compliance_utils
10-
from .common_utils import AudioBackendScope, BACKENDS
1111

1212

1313
def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
@@ -46,7 +46,10 @@ def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
4646
window[f, s] = wave[s_in_wave]
4747

4848

49-
class Test_Kaldi(unittest.TestCase):
49+
@common_utils.skipIfNoSoxBackend
50+
class Test_Kaldi(common_utils.TorchaudioTestCase):
51+
backend = 'sox'
52+
5053
test_filepath = common_utils.get_asset_path('kaldi_file.wav')
5154
test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
5255
kaldi_output_dir = common_utils.get_asset_path('kaldi')
@@ -162,8 +165,6 @@ def test_mfcc_empty(self):
162165
# Passing in an empty tensor should result in an error
163166
self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))
164167

165-
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
166-
@AudioBackendScope("sox")
167168
def test_resample_waveform(self):
168169
def get_output_fn(sound, args):
169170
output = kaldi.resample_waveform(sound, args[1], args[2])

0 commit comments

Comments
 (0)