Skip to content

Commit ab733e7

Browse files
jimchen90Ji Chen
andauthored
update wav2letter test (#722)
Co-authored-by: Ji Chen <[email protected]>
1 parent e9f19c3 commit ab733e7

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

test/test_models.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,29 @@
1-
import pytest
2-
31
import torch
42
from torchaudio.models import Wav2Letter
53

64

75
class TestWav2Letter:
8-
@pytest.mark.parametrize('batch_size', [2])
9-
@pytest.mark.parametrize('num_features', [1])
10-
@pytest.mark.parametrize('num_classes', [40])
11-
@pytest.mark.parametrize('input_length', [320])
12-
def test_waveform(self, batch_size, num_features, num_classes, input_length):
13-
model = Wav2Letter()
6+
7+
def test_waveform(self):
8+
batch_size = 2
9+
num_features = 1
10+
num_classes = 40
11+
input_length = 320
12+
13+
model = Wav2Letter(num_classes=num_classes, num_features=num_features)
1414

1515
x = torch.rand(batch_size, num_features, input_length)
1616
out = model(x)
1717

1818
assert out.size() == (batch_size, num_classes, 2)
1919

20-
@pytest.mark.parametrize('batch_size', [2])
21-
@pytest.mark.parametrize('num_features', [13])
22-
@pytest.mark.parametrize('num_classes', [40])
23-
@pytest.mark.parametrize('input_length', [2])
24-
def test_mfcc(self, batch_size, num_features, num_classes, input_length):
25-
model = Wav2Letter(input_type="mfcc", num_features=13)
20+
def test_mfcc(self):
21+
batch_size = 2
22+
num_features = 13
23+
num_classes = 40
24+
input_length = 2
25+
26+
model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features)
2627

2728
x = torch.rand(batch_size, num_features, input_length)
2829
out = model(x)

0 commit comments

Comments
 (0)