Skip to content

Commit d780a06

Browse files
committed
[Backend Tester] Add model test skeleton and torchvision tests
ghstack-source-id: ac0e6f7 ghstack-comment-id: 3091440611 Pull-Request: #12658
1 parent 1bbe556 commit d780a06

File tree

4 files changed

+282
-3
lines changed

4 files changed

+282
-3
lines changed

backends/test/suite/discovery.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def _filter_tests(
6868

6969
def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool:
7070
test_method = getattr(test_case, test_case._testMethodName)
71+
72+
if not hasattr(test_method, "_flow"):
73+
print(f"Test missing flow: {test_method}")
74+
7175
flow: TestFlow = getattr(test_method, "_flow")
7276

7377
if test_filter.backends is not None and flow.backend not in test_filter.backends:
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from executorch.backends.test.harness import Tester
10+
from executorch.backends.test.suite import get_test_flows
11+
from executorch.backends.test.suite.context import get_active_test_context, TestContext
12+
from executorch.backends.test.suite.flow import TestFlow
13+
from executorch.backends.test.suite.reporting import log_test_summary
14+
from executorch.backends.test.suite.runner import run_test
15+
from typing import Any, Callable
16+
17+
import itertools
18+
import os
19+
import torch
20+
import unittest
21+
22+
23+
DTYPES = [
24+
torch.float16,
25+
torch.float32,
26+
torch.float64,
27+
]
28+
29+
30+
def load_tests(loader, suite, pattern):
31+
package_dir = os.path.dirname(__file__)
32+
discovered_suite = loader.discover(
33+
start_dir=package_dir, pattern=pattern or "test_*.py"
34+
)
35+
suite.addTests(discovered_suite)
36+
return suite
37+
38+
39+
def _create_test(
40+
cls,
41+
test_func: Callable,
42+
flow: TestFlow,
43+
dtype: torch.dtype,
44+
use_dynamic_shapes: bool,
45+
):
46+
def wrapped_test(self):
47+
params = {
48+
"dtype": dtype,
49+
"use_dynamic_shapes": use_dynamic_shapes,
50+
}
51+
with TestContext(test_name, flow.name, params):
52+
test_func(self, dtype, use_dynamic_shapes, flow.tester_factory)
53+
54+
dtype_name = str(dtype)[6:] # strip "torch."
55+
test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}"
56+
if use_dynamic_shapes:
57+
test_name += "_dynamic_shape"
58+
59+
setattr(wrapped_test, "_name", test_func.__name__)
60+
setattr(wrapped_test, "_flow", flow)
61+
62+
setattr(cls, test_name, wrapped_test)
63+
64+
65+
# Expand a test into variants for each registered flow.
66+
def _expand_test(cls, test_name: str) -> None:
67+
test_func = getattr(cls, test_name)
68+
supports_dynamic_shapes = getattr(test_func, "supports_dynamic_shapes", True)
69+
dynamic_shape_values = [True, False] if supports_dynamic_shapes else [False]
70+
71+
for flow, dtype, use_dynamic_shapes in itertools.product(get_test_flows(), DTYPES, dynamic_shape_values):
72+
_create_test(cls, test_func, flow, dtype, use_dynamic_shapes)
73+
delattr(cls, test_name)
74+
75+
76+
def model_test_cls(cls) -> Callable | None:
77+
""" Decorator for model tests. Handles generating test variants for each test flow and configuration. """
78+
for key in dir(cls):
79+
if key.startswith("test_"):
80+
_expand_test(cls, key)
81+
return cls
82+
83+
84+
def model_test_params(supports_dynamic_shapes: bool) -> Callable:
85+
""" Optional parameter decorator for model tests. Specifies test pararameters. Only valid with a class decorated by model_test_cls. """
86+
def inner_decorator(func: Callable) -> Callable:
87+
setattr(func, "supports_dynamic_shapes", supports_dynamic_shapes)
88+
return func
89+
return inner_decorator
90+
91+
92+
def run_model_test(
93+
model: torch.nn.Module,
94+
inputs: tuple[Any],
95+
dtype: torch.dtype,
96+
dynamic_shapes: Any | None,
97+
tester_factory: Callable[[], Tester],
98+
):
99+
model = model.to(dtype)
100+
context = get_active_test_context()
101+
102+
# This should be set in the wrapped test. See _create_test above.
103+
assert context is not None, "Missing test context."
104+
105+
run_summary = run_test(
106+
model,
107+
inputs,
108+
tester_factory,
109+
context.test_name,
110+
context.flow_name,
111+
context.params,
112+
dynamic_shapes=dynamic_shapes,
113+
)
114+
115+
log_test_summary(run_summary)
116+
117+
if not run_summary.result.is_success():
118+
if run_summary.result.is_backend_failure():
119+
raise RuntimeError("Test failure.") from run_summary.error
120+
else:
121+
# Non-backend failure indicates a bad test. Mark as skipped.
122+
raise unittest.SkipTest(
123+
f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
124+
)
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import torch
10+
import torchvision
11+
import unittest
12+
13+
from executorch.backends.test.suite.models import model_test_params, model_test_cls, run_model_test
14+
from torch.export import Dim
15+
from typing import Callable
16+
17+
#
18+
# This file contains model integration tests for supported torchvision models.
19+
#
20+
21+
@model_test_cls
22+
class TorchVision(unittest.TestCase):
23+
def _test_cv_model(
24+
self,
25+
model: torch.nn.Module,
26+
dtype: torch.dtype,
27+
use_dynamic_shapes: bool,
28+
tester_factory: Callable,
29+
):
30+
# Test a CV model that follows the standard conventions.
31+
inputs = (
32+
torch.randn(1, 3, 224, 224, dtype=dtype),
33+
)
34+
35+
dynamic_shapes = (
36+
{
37+
2: Dim("height", min=1, max=16)*16,
38+
3: Dim("width", min=1, max=16)*16,
39+
},
40+
) if use_dynamic_shapes else None
41+
42+
run_model_test(model, inputs, dtype, dynamic_shapes, tester_factory)
43+
44+
45+
def test_alexnet(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
46+
model = torchvision.models.alexnet()
47+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
48+
49+
50+
def test_convnext_small(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
51+
model = torchvision.models.convnext_small()
52+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
53+
54+
55+
def test_densenet161(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
56+
model = torchvision.models.densenet161()
57+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
58+
59+
60+
def test_efficientnet_b4(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
61+
model = torchvision.models.efficientnet_b4()
62+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
63+
64+
65+
def test_efficientnet_v2_s(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
66+
model = torchvision.models.efficientnet_v2_s()
67+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
68+
69+
70+
def test_googlenet(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
71+
model = torchvision.models.googlenet()
72+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
73+
74+
75+
def test_inception_v3(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
76+
model = torchvision.models.inception_v3()
77+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
78+
79+
80+
@model_test_params(supports_dynamic_shapes=False)
81+
def test_maxvit_t(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
82+
model = torchvision.models.maxvit_t()
83+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
84+
85+
86+
def test_mnasnet1_0(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
87+
model = torchvision.models.mnasnet1_0()
88+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
89+
90+
91+
def test_mobilenet_v2(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
92+
model = torchvision.models.mobilenet_v2()
93+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
94+
95+
96+
def test_mobilenet_v3_small(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
97+
model = torchvision.models.mobilenet_v3_small()
98+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
99+
100+
101+
def test_regnet_y_1_6gf(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
102+
model = torchvision.models.regnet_y_1_6gf()
103+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
104+
105+
106+
def test_resnet50(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
107+
model = torchvision.models.resnet50()
108+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
109+
110+
111+
def test_resnext50_32x4d(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
112+
model = torchvision.models.resnext50_32x4d()
113+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
114+
115+
116+
def test_shufflenet_v2_x1_0(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
117+
model = torchvision.models.shufflenet_v2_x1_0()
118+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
119+
120+
121+
def test_squeezenet1_1(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
122+
model = torchvision.models.squeezenet1_1()
123+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
124+
125+
126+
def test_swin_v2_t(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
127+
model = torchvision.models.swin_v2_t()
128+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
129+
130+
131+
def test_vgg11(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
132+
model = torchvision.models.vgg11()
133+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
134+
135+
136+
@model_test_params(supports_dynamic_shapes=False)
137+
def test_vit_b_16(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
138+
model = torchvision.models.vit_b_16()
139+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
140+
141+
142+
def test_wide_resnet50_2(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable):
143+
model = torchvision.models.wide_resnet50_2()
144+
self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory)
145+

backends/test/suite/runner.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import re
44
import unittest
55

6-
from typing import Callable
6+
from typing import Any, Callable
77

88
import torch
99

1010
from executorch.backends.test.harness import Tester
11+
from executorch.backends.test.harness.stages import StageType
1112
from executorch.backends.test.suite.discovery import discover_tests, TestFilter
1213
from executorch.backends.test.suite.reporting import (
1314
begin_test_session,
@@ -20,17 +21,19 @@
2021

2122
# A list of all runnable test suites and the corresponding python package.
2223
NAMED_SUITES = {
24+
"models": "executorch.backends.test.suite.models",
2325
"operators": "executorch.backends.test.suite.operators",
2426
}
2527

2628

2729
def run_test( # noqa: C901
2830
model: torch.nn.Module,
29-
inputs: any,
31+
inputs: Any,
3032
tester_factory: Callable[[], Tester],
3133
test_name: str,
3234
flow_name: str,
3335
params: dict | None,
36+
dynamic_shapes: Any | None = None,
3437
) -> TestCaseSummary:
3538
"""
3639
Top-level test run function for a model, input set, and tester. Handles test execution
@@ -61,7 +64,10 @@ def build_result(
6164
return build_result(TestResult.UNKNOWN_FAIL, e)
6265

6366
try:
64-
tester.export()
67+
# TODO Use Tester dynamic_shapes parameter once input generation can properly handle derived dims.
68+
tester.export(
69+
tester._get_default_stage(StageType.EXPORT, dynamic_shapes=dynamic_shapes),
70+
)
6571
except Exception as e:
6672
return build_result(TestResult.EXPORT_FAIL, e)
6773

0 commit comments

Comments
 (0)