Skip to content

Commit 870c3ce

Browse files
committed
[Backend Tester] Add TorchAudio tests
ghstack-source-id: 01e7a75 ghstack-comment-id: 3095005640 Pull-Request: #12666
1 parent fb28761 commit 870c3ce

File tree

3 files changed

+93
-2
lines changed

3 files changed

+93
-2
lines changed

backends/test/suite/models/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,9 @@ def _expand_test(cls, test_name: str) -> None:
6767
test_func = getattr(cls, test_name)
6868
supports_dynamic_shapes = getattr(test_func, "supports_dynamic_shapes", True)
6969
dynamic_shape_values = [True, False] if supports_dynamic_shapes else [False]
70+
dtypes = getattr(test_func, "dtypes", DTYPES)
7071

71-
for flow, dtype, use_dynamic_shapes in itertools.product(get_test_flows(), DTYPES, dynamic_shape_values):
72+
for flow, dtype, use_dynamic_shapes in itertools.product(get_test_flows(), dtypes, dynamic_shape_values):
7273
_create_test(cls, test_func, flow, dtype, use_dynamic_shapes)
7374
delattr(cls, test_name)
7475

@@ -81,10 +82,17 @@ def model_test_cls(cls) -> Callable | None:
8182
return cls
8283

8384

84-
def model_test_params(supports_dynamic_shapes: bool) -> Callable:
85+
def model_test_params(
86+
supports_dynamic_shapes: bool = True,
87+
dtypes: list[torch.dtype] | None = None,
88+
) -> Callable:
8589
""" Optional parameter decorator for model tests. Specifies test pararameters. Only valid with a class decorated by model_test_cls. """
8690
def inner_decorator(func: Callable) -> Callable:
8791
setattr(func, "supports_dynamic_shapes", supports_dynamic_shapes)
92+
93+
if dtypes is not None:
94+
setattr(func, "dtypes", dtypes)
95+
8896
return func
8997
return inner_decorator
9098

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import torch
10+
import torchaudio
11+
import unittest
12+
13+
from executorch.backends.test.suite.models import model_test_params, model_test_cls, run_model_test
14+
from torch.export import Dim
15+
from typing import Callable, Tuple
16+
17+
#
18+
# This file contains model integration tests for supported torchaudio models.
19+
#
20+
21+
class PatchedConformer(torch.nn.Module):
22+
"""
23+
A lightly modified version of the top-level Conformer module, such that it can be exported.
24+
Instead of taking lengths and computing the padding mask, it takes the padding mask directly.
25+
See https://github.com/pytorch/audio/blob/main/src/torchaudio/models/conformer.py#L215
26+
"""
27+
28+
def __init__(self, conformer):
29+
super().__init__()
30+
self.conformer = conformer
31+
32+
def forward(self, input: torch.Tensor, encoder_padding_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
33+
x = input.transpose(0, 1)
34+
for layer in self.conformer.conformer_layers:
35+
x = layer(x, encoder_padding_mask)
36+
return x.transpose(0, 1)
37+
38+
@model_test_cls
39+
class TorchAudio(unittest.TestCase):
40+
@model_test_params(dtypes=[torch.float32], supports_dynamic_shapes=False)
41+
def test_conformer(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
42+
inner_model = torchaudio.models.Conformer(
43+
input_dim=80,
44+
num_heads=4,
45+
ffn_dim=128,
46+
num_layers=4,
47+
depthwise_conv_kernel_size=31,
48+
)
49+
model = PatchedConformer(inner_model)
50+
lengths = torch.randint(1, 400, (10,))
51+
52+
encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask(lengths)
53+
inputs = (
54+
torch.rand(10, int(lengths.max()), 80),
55+
encoder_padding_mask,
56+
)
57+
58+
run_model_test(model, inputs, dtype, None, tester_factory)
59+
60+
@model_test_params(dtypes=[torch.float32])
61+
def test_wav2letter(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
62+
model = torchaudio.models.Wav2Letter()
63+
inputs = (torch.randn(1, 1, 1024, dtype=dtype),)
64+
dynamic_shapes = {
65+
"x": {
66+
2: Dim("d", min=900, max=1024),
67+
}
68+
} if use_dynamic_shapes else None
69+
run_model_test(model, inputs, dtype, dynamic_shapes, tester_factory)
70+
71+
@unittest.skip("This model times out on all backends.")
72+
def test_wavernn(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
73+
model = torchaudio.models.WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200).eval()
74+
75+
# See https://docs.pytorch.org/audio/stable/generated/torchaudio.models.WaveRNN.html#forward
76+
inputs = (
77+
torch.randn(1, 1, (64 - 5 + 1) * 200), # waveform
78+
torch.randn(1, 1, 128, 64), # specgram
79+
)
80+
81+
run_model_test(model, inputs, dtype, None, tester_factory)

backends/test/suite/runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def build_result(
5151
result=result,
5252
error=error,
5353
)
54+
55+
model.eval()
5456

5557
# Ensure the model can run in eager mode.
5658
try:

0 commit comments

Comments
 (0)