
Commit 114a555

[Backend Tester] Add FACTO operator test skeleton
ghstack-source-id: 917aa11 ghstack-comment-id: 3003288787 Pull-Request: #11953
1 parent 91c9ffa commit 114a555

File tree

8 files changed, +635 −10 lines changed

backends/test/harness/tester.py

Lines changed: 1 addition & 0 deletions
@@ -361,6 +361,7 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
        ref,
        atol=atol,
        rtol=rtol,
+       equal_nan=True,
    ), (
        f"Output {i} does not match reference output.\n"
        f"\tGiven atol: {atol}, rtol: {rtol}.\n"

backends/test/operators/__init__.py

Whitespace-only changes.
backends/test/operators/facto_specs.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
import facto.specdb.function as fn
import torch

from facto.inputgen.argument.type import ArgType
from facto.inputgen.specs.model import (
    ConstraintProducer as cp,
    InKwArg,
    InPosArg,
    OutArg,
    Spec,
)

"""
This file contains FACTO operator specs for ops not in the standard FACTO db. This mainly
includes ops that are not in the Core ATen op set but are preserved by a backend, such as linear.
"""

LINEAR_DEFAULT_SPEC = Spec(
    op="linear.default",  # (Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
    inspec=[
        InPosArg(
            ArgType.Tensor,
            name="input",
            deps=[1, 2],
            constraints=[
                cp.Dtype.Eq(lambda deps: deps[0].dtype),
                cp.Rank.Ge(lambda deps: 2),
                cp.Size.In(
                    lambda deps, r, d: fn.broadcast_to(
                        (fn.safe_size(deps[0], 0), fn.safe_size(deps[1], 1)), r, d
                    )
                ),
            ],
        ),
        InPosArg(
            ArgType.Tensor,
            name="weight",
            constraints=[
                cp.Dtype.Ne(lambda deps: torch.bool),
                cp.Rank.Eq(lambda deps: 2),
            ],
        ),
        InPosArg(
            ArgType.Tensor,
            name="bias",
            deps=[1],
            constraints=[
                cp.Dtype.Eq(lambda deps: deps[0].dtype),
                cp.Rank.Eq(lambda deps: 2),
                cp.Size.Eq(
                    lambda deps, r, d: fn.safe_size(deps[0], 1) if d == 0 else None
                ),
            ],
        ),
    ],
    outspec=[
        OutArg(ArgType.Tensor),
    ],
)

_extra_specs = [
    LINEAR_DEFAULT_SPEC,
]

ExtraSpecDB: dict[str, Spec] = {
    s.op: s for s in _extra_specs
}
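A short usage sketch (not part of the commit) of how these extra specs are expected to be consumed; it mirrors the merge and generation loop in the test harness added below, and assumes the module path executorch.backends.test.operators.facto_specs.

import torch
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.specdb.db import SpecDictDB

from executorch.backends.test.operators.facto_specs import ExtraSpecDB  # assumed module path

combined = SpecDictDB | ExtraSpecDB  # the extra specs extend the standard FACTO spec db
spec = combined["linear.default"]    # the spec defined above

# Print a few constraint-satisfying argument tuples for linear.default.
for i, (posargs, inkwargs, _) in enumerate(ArgumentTupleGenerator(spec).gen()):
    print([tuple(a.shape) if isinstance(a, torch.Tensor) else a for a in posargs])
    if i >= 2:
        break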
Lines changed: 276 additions & 0 deletions
@@ -0,0 +1,276 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import functools
import traceback
import unittest
from typing import Any, Callable, List, OrderedDict, Sequence, Tuple

import torch
from executorch.backends.test.harness.tester import Tester as TesterBase
from executorch.backends.xnnpack.test.tester.tester import (
    ToEdgeTransformAndLower,
    Tester as XnnpackTester,
)
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.inputgen.specs.model import Constraint, ConstraintProducer as cp, Spec
from facto.inputgen.utils.random_manager import random_manager
from facto.inputgen.variable.type import ScalarDtype
from facto.specdb.db import SpecDictDB
from torch._ops import OpOverload

from .facto_specs import ExtraSpecDB

CombinedSpecDB = SpecDictDB | ExtraSpecDB

COMMON_TENSOR_CONSTRAINTS = [
    cp.Rank.Ge(lambda deps: 1),
    cp.Rank.Le(lambda deps: 4),
    cp.Size.Ge(lambda deps, r, d: 1),
    cp.Size.Le(lambda deps, r, d: 2**9),
]

COMMON_SCALAR_CONSTRAINTS = [
    cp.Value.Ge(lambda deps, dtype: -1000),
    cp.Value.Le(lambda deps, dtype: 1000),
]

# Operator args are treated as runtime graph inputs if the argument name is
# in this list.
RUNTIME_INPUT_NAMES = {
    "self",
    "tensor",
    "other",
}


def _patch_spec(spec: Spec) -> Spec:
    spec = copy.deepcopy(spec)
    for inspec in spec.inspec:
        if inspec.type.is_tensor():
            inspec.constraints.extend(COMMON_TENSOR_CONSTRAINTS)
        elif inspec.type.is_scalar():
            inspec.constraints.extend(COMMON_SCALAR_CONSTRAINTS)
    return spec


class OpModel(torch.nn.Module):
    """
    Wraps a single torch operator in an nn.Module.
    """

    def __init__(
        self,
        op: OpOverload,
        runtime_input_count: int,
        fixed_args: Sequence[Any],
        fixed_kwargs: dict[str, Any],
    ):
        super().__init__()
        self.op = op
        self.runtime_input_count = runtime_input_count
        self.fixed_kwargs = fixed_kwargs

        # Register parameters for fixed tensors. Some backends will choke on
        # constant tensor weights, for example.
        new_args = []
        for i, arg in enumerate(fixed_args):
            if isinstance(arg, torch.Tensor):
                param = torch.nn.Parameter(arg, requires_grad=False)
                param_name = f"arg_{i}_param"
                setattr(self, param_name, param)
                self.register_parameter(param_name, param)
                new_args.append(param)
            else:
                new_args.append(arg)
        self.fixed_args = tuple(new_args)

    def forward(self, *args, **kwargs):
        return self.op(*(args + self.fixed_args), **(kwargs | self.fixed_kwargs))


class ConvModel(OpModel):
    def forward(self, *args, **kwargs):
        weight, bias, stride, padding, dilation, transposed, output_padding, groups = (
            self.fixed_args
        )

        if not transposed:
            if len(weight.shape) == 3:
                op = torch.nn.functional.conv1d
            elif len(weight.shape) == 4:
                op = torch.nn.functional.conv2d
            elif len(weight.shape) == 5:
                op = torch.nn.functional.conv3d

            return op(args[0], weight, bias, stride, padding, dilation, groups)
        else:
            if len(weight.shape) == 3:
                op = torch.nn.functional.conv_transpose1d
            elif len(weight.shape) == 4:
                op = torch.nn.functional.conv_transpose2d
            elif len(weight.shape) == 5:
                op = torch.nn.functional.conv_transpose3d

            return op(
                args[0], weight, bias, stride, padding, output_padding, groups, dilation
            )


def get_module_for_op(op: OpOverload):
    if op == torch.ops.aten.convolution.default:
        return ConvModel
    else:
        return OpModel


class FactoTestsBase(unittest.TestCase):
    def __init__(self, tester_factory: Callable[[], TesterBase], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._tester_factory = tester_factory

    @staticmethod
    def _generate_test(op_name: str) -> None:
        # Find the torch op with the given name.
        sections = op_name.split(".")
        torch_op = functools.reduce(getattr, sections, torch.ops.aten)

        test_name = "test_" + op_name.replace(".", "_")
        test_body = lambda self: self._test_op(torch_op)

        setattr(FactoTestsBase, test_name, test_body)

    @staticmethod
    def get_runtime_input_count(spec: Spec):
        # Determine which inputs are fixed at tracing time (weights, for example),
        # vs inputs to the runtime graph. We currently assume that the runtime graph
        # inputs start at the beginning of the arg list and are contiguous.
        #
        # Args are considered to be runtime inputs if they are positional and are named
        # one of RUNTIME_INPUT_NAMES. If none match, we assume only the first arg is a
        # runtime input.
        runtime_input_count = 0
        for inspec in spec.inspec:
            is_runtime_input = (
                inspec.type.is_tensor() and inspec.name.lower() in RUNTIME_INPUT_NAMES
            )
            if is_runtime_input:
                runtime_input_count += 1
            else:
                break

        return max(1, runtime_input_count)

    def setUp(self):
        torch.set_printoptions(threshold=3)

    def _test_op(self, op: OpOverload) -> None:
        random_manager.seed(0)

        # Strip namespace
        op_name = op.name().split("::")[-1]

        # Default to .default overload
        if "." not in op_name:
            op_name += ".default"

        # Find and patch op spec
        if op_name not in CombinedSpecDB:
            raise ValueError(f"Operator {op_name} not found in SpecDictDB.")
        spec = _patch_spec(CombinedSpecDB[op_name])

        runtime_input_count = FactoTestsBase.get_runtime_input_count(spec)

        print(f"Op: {op_name}, {runtime_input_count} runtime inputs")

        # Run test cases
        success_count_delegated = 0
        success_count_undelegated = 0
        fail_count = 0

        i = 0
        for posargs, inkwargs, _ in ArgumentTupleGenerator(spec).gen():
            i += 1

            try:
                if isinstance(posargs[0], torch.Tensor):
                    # Temporary for getting around XNN crashes (https://github.com/pytorch/executorch/issues/10960).
                    # TODO Re-enable when resolved.
                    if posargs[0].dtype in {torch.int8, torch.uint8}:
                        print("Skipping (u)int8 case.")
                        continue

                module_cls = get_module_for_op(op)
                model = module_cls(
                    op,
                    runtime_input_count,
                    posargs[runtime_input_count:],
                    inkwargs,
                )

                # Sanity check to make sure it runs in eager. This can present nicer error
                # messages sometimes compared to tracing.
                try:
                    model(*posargs[:runtime_input_count])
                except Exception as e:
                    print(f"Eager execution failed: {e}")
                    continue

                tester = self._tester_factory(
                    model,
                    tuple(posargs[:runtime_input_count]),
                )

                # Dynamo will also fail to handle some patterns that are valid in eager.
                try:
                    tester.export()
                except Exception as e:
                    print(f"Export failed: {e}")
                    continue

                tester.to_edge_transform_and_lower()

                is_delegated = any(
                    n.target == torch._higher_order_ops.executorch_call_delegate
                    for n in tester.stages[tester.cur].graph_module.graph.nodes
                    if n.op == "call_function"
                )

                # Only run the runtime test if the op was delegated.
                if is_delegated:
                    (
                        tester
                        .to_executorch()
                        .serialize()
                        .run_method_and_compare_outputs()
                    )

                if is_delegated:
                    success_count_delegated += 1
                else:
                    success_count_undelegated += 1
            except Exception as e:
                fail_count += 1
                print("Args:")
                for arg in posargs:
                    if isinstance(arg, torch.Tensor):
                        print(f"  {arg.dtype} {arg.shape}")
                    else:
                        print(f"  {arg}")

                traceback.print_exc()

        print(
            f"{success_count_delegated + success_count_undelegated} PASS, {fail_count} FAIL"
        )
        print(
            f"  {success_count_delegated} DELEGATED, {success_count_undelegated} UNDELEGATED"
        )


# Programmatically generate tests for each operator.
for op_name in CombinedSpecDB.keys():
    FactoTestsBase._generate_test(op_name)


# TODO Figure out where to put these
class FactoTestsXNNPACK(FactoTestsBase):
    def __init__(self, *args, **kwargs):
        super().__init__(XnnpackTester, *args, **kwargs)


try:
    from executorch.backends.apple.coreml.test.tester import CoreMLTester

    class FactoTestsCoreML(FactoTestsBase):
        def __init__(self, *args, **kwargs):
            super().__init__(CoreMLTester, *args, **kwargs)

except:
    print("Skipping Core ML facto tests as Core ML AOT is not available.")
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
add_executable(executorch-test-runner
  test_runner.cpp
  # TODO
  ../../../runtime/platform/runtime.cpp
)

target_link_libraries(
  executorch-test-runner
  PRIVATE executorch
          gflags
          extension_flat_tensor
          extension_flat_tensor_serialize
          extension_module
          extension_tensor
          optimized_native_cpu_ops_lib
          xnnpack_backend)
