|
1 | | -import pytest |
2 | | - |
3 | 1 | import torch |
4 | 2 | from torchaudio.models import Wav2Letter |
5 | 3 |
|
6 | 4 |
|
7 | 5 | class TestWav2Letter: |
8 | | - @pytest.mark.parametrize('batch_size', [2]) |
9 | | - @pytest.mark.parametrize('num_features', [1]) |
10 | | - @pytest.mark.parametrize('num_classes', [40]) |
11 | | - @pytest.mark.parametrize('input_length', [320]) |
12 | | - def test_waveform(self, batch_size, num_features, num_classes, input_length): |
13 | | - model = Wav2Letter() |
| 6 | + |
| 7 | + def test_waveform(self): |
| 8 | + batch_size = 2 |
| 9 | + num_features = 1 |
| 10 | + num_classes = 40 |
| 11 | + input_length = 320 |
| 12 | + |
| 13 | + model = Wav2Letter(num_classes=num_classes, num_features=num_features) |
14 | 14 |
|
15 | 15 | x = torch.rand(batch_size, num_features, input_length) |
16 | 16 | out = model(x) |
17 | 17 |
|
18 | 18 | assert out.size() == (batch_size, num_classes, 2) |
19 | 19 |
|
20 | | - @pytest.mark.parametrize('batch_size', [2]) |
21 | | - @pytest.mark.parametrize('num_features', [13]) |
22 | | - @pytest.mark.parametrize('num_classes', [40]) |
23 | | - @pytest.mark.parametrize('input_length', [2]) |
24 | | - def test_mfcc(self, batch_size, num_features, num_classes, input_length): |
25 | | - model = Wav2Letter(input_type="mfcc", num_features=13) |
| 20 | + def test_mfcc(self): |
| 21 | + batch_size = 2 |
| 22 | + num_features = 13 |
| 23 | + num_classes = 40 |
| 24 | + input_length = 2 |
| 25 | + |
| 26 | + model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features) |
26 | 27 |
|
27 | 28 | x = torch.rand(batch_size, num_features, input_length) |
28 | 29 | out = model(x) |
|
0 commit comments