
Commit e5fec18

[Backend Tester] Add FACTO operator test skeleton
ghstack-source-id: a082c5e ghstack-comment-id: 3003288787 Pull-Request: #11953
1 parent 91c9ffa commit e5fec18

4 files changed: +356 -9 lines changed

backends/test/operators/__init__.py

Whitespace-only changes.
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
import facto.specdb.function as fn
import torch

from facto.inputgen.argument.type import ArgType
from facto.inputgen.specs.model import (
    ConstraintProducer as cp,
    InKwArg,
    InPosArg,
    OutArg,
    Spec,
)

"""
This file contains FACTO operator specs for ops not in the standard FACTO db. This mainly
includes ops not in the Core ATen op set and preserved by a backend, such as linear.
"""

LINEAR_DEFAULT_SPEC = Spec(
    op="linear.default",  # (Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
    inspec=[
        InPosArg(
            ArgType.Tensor,
            name="input",
            deps=[1, 2],
            constraints=[
                cp.Dtype.Eq(lambda deps: deps[0].dtype),
                cp.Rank.Ge(lambda deps: 2),
                cp.Size.In(
                    lambda deps, r, d: fn.broadcast_to(
                        (fn.safe_size(deps[0], 0), fn.safe_size(deps[1], 1)), r, d
                    )
                ),
            ],
        ),
        InPosArg(
            ArgType.Tensor,
            name="weight",
            constraints=[
                cp.Dtype.Ne(lambda deps: torch.bool),
                cp.Rank.Eq(lambda deps: 2),
            ],
        ),
        InPosArg(
            ArgType.Tensor,
            name="bias",
            deps=[1],
            constraints=[
                cp.Dtype.Eq(lambda deps: deps[0].dtype),
                cp.Rank.Eq(lambda deps: 2),
                cp.Size.Eq(
                    lambda deps, r, d: fn.safe_size(deps[0], 1) if d == 0 else None
                ),
            ],
        ),
    ],
    outspec=[
        OutArg(ArgType.Tensor),
    ],
)

_extra_specs = [
    LINEAR_DEFAULT_SPEC,
]

ExtraSpecDB: dict[str, Spec] = {s.op: s for s in _extra_specs}
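
For context, a minimal sketch of how these extra specs can be consumed alongside the standard FACTO database (the generator usage mirrors the test driver below; it assumes the facto package is installed and that the import runs from the same package that provides facto_specs):

from facto.inputgen.argtuple.gen import ArgumentTupleGenerator

from .facto_specs import ExtraSpecDB

# Look up the custom linear spec and generate one sample argument tuple.
spec = ExtraSpecDB["linear.default"]
for posargs, inkwargs, _ in ArgumentTupleGenerator(spec).gen():
    # posargs follows the inspec above: (input, weight, bias).
    for arg in posargs:
        print(type(arg).__name__, getattr(arg, "shape", arg))
    break  # one generated case is enough for illustration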
Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import functools
import traceback
import unittest
from typing import Any, Callable, List, OrderedDict, Sequence, Tuple

import torch
from executorch.backends.apple.coreml.test.tester import CoreMLTester
from executorch.backends.test.harness.tester import Tester as TesterBase
from executorch.backends.xnnpack.test.tester.tester import (
    Tester as XnnpackTester,
    ToEdgeTransformAndLower,
)
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.inputgen.specs.model import Constraint, ConstraintProducer as cp, Spec
from facto.inputgen.utils.random_manager import random_manager
from facto.inputgen.variable.type import ScalarDtype
from facto.specdb.db import SpecDictDB
from torch._ops import OpOverload

from .facto_specs import ExtraSpecDB

CombinedSpecDB = SpecDictDB | ExtraSpecDB

COMMON_TENSOR_CONSTRAINTS = [
    cp.Rank.Ge(lambda deps: 1),
    cp.Rank.Le(lambda deps: 4),
    cp.Size.Ge(lambda deps, r, d: 1),
    cp.Size.Le(lambda deps, r, d: 2**9),
]

COMMON_SCALAR_CONSTRAINTS = [
    cp.Value.Ge(lambda deps, dtype: -1000),
    cp.Value.Le(lambda deps, dtype: 1000),
]

# Operator args are treated as runtime graph inputs if the argument name is
# in this list.
RUNTIME_INPUT_NAMES = {
    "self",
    "tensor",
    "other",
}


def _patch_spec(spec: Spec) -> Spec:
    spec = copy.deepcopy(spec)
    for inspec in spec.inspec:
        if inspec.type.is_tensor():
            inspec.constraints.extend(COMMON_TENSOR_CONSTRAINTS)
        elif inspec.type.is_scalar():
            inspec.constraints.extend(COMMON_SCALAR_CONSTRAINTS)
    return spec
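
# Note: after _patch_spec, a tensor argument carries both its op-specific
# constraints and the common bounds above (rank in [1, 4], every dimension
# size in [1, 2**9]); scalar arguments get their values clamped to
# [-1000, 1000].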


class OpModel(torch.nn.Module):
    """
    Wraps a single torch operator in an nn.Module.
    """

    def __init__(
        self,
        op: OpOverload,
        runtime_input_count: int,
        fixed_args: Sequence[Any],
        fixed_kwargs: dict[str, Any],
    ):
        super().__init__()
        self.op = op
        self.runtime_input_count = runtime_input_count
        self.fixed_kwargs = fixed_kwargs

        # Register parameters for fixed tensors. Some backends will choke on
        # constant tensor weights, for example.
        new_args = []
        for i, arg in enumerate(fixed_args):
            if isinstance(arg, torch.Tensor):
                param = torch.nn.Parameter(arg, requires_grad=False)
                param_name = f"arg_{i}_param"
                setattr(self, param_name, param)
                self.register_parameter(param_name, param)
                new_args.append(param)
            else:
                new_args.append(arg)
        self.fixed_args = tuple(new_args)

    def forward(self, *args, **kwargs):
        return self.op(*(args + self.fixed_args), **(kwargs | self.fixed_kwargs))


class ConvModel(OpModel):
    def forward(self, *args, **kwargs):
        weight, bias, stride, padding, dilation, transposed, output_padding, groups = (
            self.fixed_args
        )

        if not transposed:
            if len(weight.shape) == 3:
                op = torch.nn.functional.conv1d
            elif len(weight.shape) == 4:
                op = torch.nn.functional.conv2d
            elif len(weight.shape) == 5:
                op = torch.nn.functional.conv3d

            return op(args[0], weight, bias, stride, padding, dilation, groups)
        else:
            if len(weight.shape) == 3:
                op = torch.nn.functional.conv_transpose1d
            elif len(weight.shape) == 4:
                op = torch.nn.functional.conv_transpose2d
            elif len(weight.shape) == 5:
                op = torch.nn.functional.conv_transpose3d

            return op(
                args[0], weight, bias, stride, padding, output_padding, groups, dilation
            )


def get_module_for_op(op: OpOverload):
    if op == torch.ops.aten.convolution.default:
        return ConvModel
    else:
        return OpModel


class FactoTestsBase(unittest.TestCase):
    def __init__(self, tester_factory: Callable[[], TesterBase], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._tester_factory = tester_factory

    @staticmethod
    def _generate_test(op_name: str) -> None:
        # Find the torch op with the given name.
        sections = op_name.split(".")
        torch_op = functools.reduce(getattr, sections, torch.ops.aten)

        test_name = "test_" + op_name.replace(".", "_")
        test_body = lambda self: self._test_op(torch_op)

        setattr(FactoTestsBase, test_name, test_body)

    @staticmethod
    def get_runtime_input_count(spec: Spec):
        # Determine which inputs are fixed at tracing time (weights, for example)
        # vs inputs to the runtime graph. We currently assume that the runtime graph
        # inputs start at the beginning of the arg list and are contiguous.
        #
        # Args are considered to be runtime inputs if they are positional and are named
        # one of RUNTIME_INPUT_NAMES. If none match, we assume only the first arg is a
        # runtime input.
        runtime_input_count = 0
        for inspec in spec.inspec:
            is_runtime_input = (
                inspec.type.is_tensor() and inspec.name.lower() in RUNTIME_INPUT_NAMES
            )
            if is_runtime_input:
                runtime_input_count += 1
            else:
                break

        return max(1, runtime_input_count)
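
    # Worked example (illustrative): for a spec whose positional args are
    # named ("self", "other", "alpha"), with the first two being tensors, the
    # scan above matches "self" and "other" against RUNTIME_INPUT_NAMES and
    # stops at "alpha", so get_runtime_input_count returns 2; "alpha" is then
    # baked into the traced module as a fixed arg by OpModel.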

    def setUp(self):
        torch.set_printoptions(threshold=3)

    def _test_op(self, op: OpOverload) -> None:
        random_manager.seed(0)

        # Strip the namespace.
        op_name = op.name().split("::")[-1]

        # Default to the .default overload.
        if "." not in op_name:
            op_name += ".default"

        # Find and patch the op spec.
        if op_name not in CombinedSpecDB:
            raise ValueError(f"Operator {op_name} not found in CombinedSpecDB.")
        spec = _patch_spec(CombinedSpecDB[op_name])

        runtime_input_count = FactoTestsBase.get_runtime_input_count(spec)

        print(f"Op: {op_name}, {runtime_input_count} runtime inputs")

        # Run test cases.
        success_count_delegated = 0
        success_count_undelegated = 0
        fail_count = 0

        i = 0
        for posargs, inkwargs, _ in ArgumentTupleGenerator(spec).gen():
            i += 1

            try:
                if isinstance(posargs[0], torch.Tensor):
                    # Temporary workaround for XNNPACK crashes on non-float inputs.
                    if posargs[0].dtype not in {torch.float32, torch.float16}:
                        print("SKIPPING NON FLOAT CASE")
                        continue

                module_cls = get_module_for_op(op)
                model = module_cls(
                    op, runtime_input_count, posargs[runtime_input_count:], inkwargs
                )

                # Sanity check to make sure it runs in eager. This can present nicer error
                # messages sometimes compared to tracing.
                try:
                    model(*posargs[:runtime_input_count])
                except Exception as e:
                    print(f"Eager execution failed: {e}")
                    continue

                tester = (
                    self._tester_factory(model, tuple(posargs[:runtime_input_count]))
                    .export()
                    .dump_artifact()
                    # .to_edge_transform_and_lower(ToEdgeTransformAndLower(partitioners=[]))
                    .to_edge_transform_and_lower()
                    # .dump_artifact()
                )

                is_delegated = any(
                    n.target == torch._higher_order_ops.executorch_call_delegate
                    for n in tester.stages[tester.cur].graph_module.graph.nodes
                    if n.op == "call_function"
                )

                # Only run the runtime test if the op was delegated.
                if is_delegated:
                    (
                        tester.to_executorch()
                        .serialize()
                        .run_method_and_compare_outputs()
                    )

                if is_delegated:
                    success_count_delegated += 1
                else:
                    success_count_undelegated += 1
            except Exception as e:
                fail_count += 1
                print("Args:")
                for arg in posargs:
                    if isinstance(arg, torch.Tensor):
                        print(f" {arg.dtype} {arg.shape}")
                    else:
                        print(f" {arg}")

                traceback.print_exc()

        print(
            f"{success_count_delegated + success_count_undelegated} PASS, {fail_count} FAIL"
        )
        print(
            f" {success_count_delegated} DELEGATED, {success_count_undelegated} UNDELEGATED"
        )


# Programmatically generate tests for each operator.
for op_name in CombinedSpecDB.keys():
    FactoTestsBase._generate_test(op_name)


# TODO: Figure out where to put these.
class FactoTestsXNNPACK(FactoTestsBase):
    def __init__(self, *args, **kwargs):
        super().__init__(XnnpackTester, *args, **kwargs)


class FactoTestsCoreML(FactoTestsBase):
    def __init__(self, *args, **kwargs):
        super().__init__(CoreMLTester, *args, **kwargs)
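
A minimal sketch of how another backend could reuse this harness, following the XNNPACK and Core ML wiring above (MyBackendTester is a hypothetical Tester subclass, not part of this commit):

class FactoTestsMyBackend(FactoTestsBase):
    def __init__(self, *args, **kwargs):
        # Pass the backend-specific tester factory; FactoTestsBase drives
        # export, lowering, and (when delegated) runtime output comparison.
        super().__init__(MyBackendTester, *args, **kwargs)

The generated test methods, one per spec in CombinedSpecDB and named like test_add_Tensor for a spec keyed "add.Tensor", are inherited by the subclass and picked up by unittest discovery.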

backends/xnnpack/test/tester/__init__.py

Lines changed: 10 additions & 9 deletions
@@ -13,18 +13,19 @@
     Serialize,
     Tester,
     ToEdge,
+    ToEdge,
     ToEdgeTransformAndLower,
     ToExecutorch,
 )
 
 __all__ = [
-    Export,
-    ToEdge,
-    Partition,
-    Quantize,
-    RunPasses,
-    ToEdgeTransformAndLower,
-    Tester,
-    Serialize,
-    ToExecutorch,
+    "Export",
+    "ToEdge",
+    "Partition",
+    "Quantize",
+    "RunPasses",
+    "ToEdgeTransformAndLower",
+    "Tester",
+    "Serialize",
+    "ToExecutorch",
 ]
