99
1010import logging
1111import os
12- import unittest
13-
14- from enum import Enum
15- from typing import Callable
1612
1713import executorch .backends .test .suite .flow
1814
19- import torch
20- from executorch .backends .test .suite .context import get_active_test_context , TestContext
2115from executorch .backends .test .suite .flow import TestFlow
22- from executorch .backends .test .suite .reporting import log_test_summary
23- from executorch .backends .test .suite .runner import run_test , runner_main
16+ from executorch .backends .test .suite .runner import runner_main
2417
2518logger = logging .getLogger (__name__ )
2619logger .setLevel (logging .INFO )
@@ -62,109 +55,6 @@ def get_test_flows() -> dict[str, TestFlow]:
6255 return _ALL_TEST_FLOWS
6356
6457
65- DTYPES = [
66- # torch.int8,
67- # torch.uint8,
68- # torch.int16,
69- # torch.uint16,
70- # torch.int32,
71- # torch.uint32,
72- # torch.int64,
73- # torch.uint64,
74- # torch.float16,
75- torch .float32 ,
76- # torch.float64,
77- ]
78-
79- FLOAT_DTYPES = [
80- torch .float16 ,
81- torch .float32 ,
82- torch .float64 ,
83- ]
84-
85-
86- # The type of test function. This controls the test generation and expected signature.
87- # Standard tests are run, as is. Dtype tests get a variant generated for each dtype and
88- # take an additional dtype parameter.
89- class TestType (Enum ):
90- STANDARD = 1
91- DTYPE = 2
92-
93-
94- # Function annotation for dtype tests. This instructs the test framework to run the test
95- # for each supported dtype and to pass dtype as a test parameter.
96- def dtype_test (func ):
97- func .test_type = TestType .DTYPE
98- return func
99-
100-
101- # Class annotation for operator tests. This triggers the test framework to register
102- # the tests.
103- def operator_test (cls ):
104- _create_tests (cls )
105- return cls
106-
107-
108- # Generate test cases for each backend flow.
109- def _create_tests (cls ):
110- for key in dir (cls ):
111- if key .startswith ("test_" ):
112- _expand_test (cls , key )
113-
114-
115- # Expand a test into variants for each registered flow.
116- def _expand_test (cls , test_name : str ):
117- test_func = getattr (cls , test_name )
118- for flow in get_test_flows ().values ():
119- _create_test_for_backend (cls , test_func , flow )
120- delattr (cls , test_name )
121-
122-
123- def _make_wrapped_test (
124- test_func : Callable ,
125- test_name : str ,
126- flow : TestFlow ,
127- params : dict | None = None ,
128- ):
129- def wrapped_test (self ):
130- with TestContext (test_name , flow .name , params ):
131- test_kwargs = params or {}
132- test_kwargs ["flow" ] = flow
133-
134- test_func (self , ** test_kwargs )
135-
136- wrapped_test ._name = test_name
137- wrapped_test ._flow = flow
138-
139- return wrapped_test
140-
141-
142- def _create_test_for_backend (
143- cls ,
144- test_func : Callable ,
145- flow : TestFlow ,
146- ):
147- test_type = getattr (test_func , "test_type" , TestType .STANDARD )
148-
149- if test_type == TestType .STANDARD :
150- wrapped_test = _make_wrapped_test (test_func , test_func .__name__ , flow )
151- test_name = f"{ test_func .__name__ } _{ flow .name } "
152- setattr (cls , test_name , wrapped_test )
153- elif test_type == TestType .DTYPE :
154- for dtype in DTYPES :
155- wrapped_test = _make_wrapped_test (
156- test_func ,
157- test_func .__name__ ,
158- flow ,
159- {"dtype" : dtype },
160- )
161- dtype_name = str (dtype )[6 :] # strip "torch."
162- test_name = f"{ test_func .__name__ } _{ dtype_name } _{ flow .name } "
163- setattr (cls , test_name , wrapped_test )
164- else :
165- raise NotImplementedError (f"Unknown test type { test_type } ." )
166-
167-
16858def load_tests (loader , suite , pattern ):
16959 package_dir = os .path .dirname (__file__ )
17060 discovered_suite = loader .discover (
@@ -174,32 +64,5 @@ def load_tests(loader, suite, pattern):
17464 return suite
17565
17666
177- class OperatorTest (unittest .TestCase ):
178- def _test_op (self , model , inputs , flow : TestFlow ):
179- context = get_active_test_context ()
180-
181- # This should be set in the wrapped test. See _make_wrapped_test above.
182- assert context is not None , "Missing test context."
183-
184- run_summary = run_test (
185- model ,
186- inputs ,
187- flow ,
188- context .test_name ,
189- context .params ,
190- )
191-
192- log_test_summary (run_summary )
193-
194- if not run_summary .result .is_success ():
195- if run_summary .result .is_backend_failure ():
196- raise RuntimeError ("Test failure." ) from run_summary .error
197- else :
198- # Non-backend failure indicates a bad test. Mark as skipped.
199- raise unittest .SkipTest (
200- f"Test failed for reasons other than backend failure. Error: { run_summary .error } "
201- )
202-
203-
20467if __name__ == "__main__" :
20568 runner_main ()
0 commit comments