Skip to content

Commit af31cb0

Browse files
Test all specs in SpecDB for valid inputs
ghstack-source-id: d5a7106 Pull-Request: #33
1 parent 20a10f4 commit af31cb0

File tree

4 files changed

+160
-0
lines changed

4 files changed

+160
-0
lines changed

facto/inputgen/argtuple/gen.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,51 @@ def gen(
8383
engine = MetaArgTupleEngine(self._modified_spec, out=out)
8484
for meta_tuple in engine.gen(valid=valid):
8585
yield self.gen_tuple(meta_tuple, out=out)
86+
87+
def gen_errors(
88+
self, op, *, valid: bool = True, out: bool = False
89+
) -> Generator[
90+
Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any
91+
]:
92+
"""
93+
Generate input tuples and yield only those that don't behave as expected.
94+
95+
This function takes the same signature as gen() but with an additional
96+
op parameter. It filters inputs based on whether they behave as expected:
97+
- When valid=True: yields inputs that should be valid but DO error
98+
- When valid=False: yields inputs that should be invalid but DON'T error
99+
100+
Args:
101+
op: The operation/function to test the inputs against
102+
valid: Whether to generate valid or invalid inputs (same as gen())
103+
out: Whether to include output arguments (same as gen())
104+
105+
Yields:
106+
Tuples of (posargs, inkwargs, outargs) that don't behave as expected
107+
"""
108+
for posargs, inkwargs, outargs in self.gen(valid=valid, out=out):
109+
try:
110+
# Try to execute the operation with the generated inputs
111+
if out:
112+
# If there are output arguments, include them in the call
113+
op(*posargs, **inkwargs, **outargs)
114+
else:
115+
# Otherwise, just call with positional and keyword arguments
116+
op(*posargs, **inkwargs)
117+
118+
# If execution succeeds:
119+
if valid:
120+
# When valid=True, we expect success, so this is NOT a bug
121+
continue
122+
else:
123+
# When valid=False, we expect failure, so success IS a bug
124+
yield posargs, inkwargs, outargs
125+
126+
except Exception:
127+
# If execution fails:
128+
if valid:
129+
# When valid=True, we expect success, so failure IS a bug
130+
yield posargs, inkwargs, outargs
131+
else:
132+
# When valid=False, we expect failure, so this is NOT a bug
133+
continue

facto/utils/__init__.py

Whitespace-only changes.

facto/utils/ops.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
9+
10+
def get_op_overload(op_name: str):
11+
"""
12+
Get the torch operation overload from an operation name.
13+
14+
Args:
15+
op_name: Operation name in the format "op_base.overload" (e.g., "add.Tensor")
16+
17+
Returns:
18+
The torch operation overload (e.g., torch.ops.aten.add.Tensor)
19+
20+
Raises:
21+
AttributeError: If the operation is not found
22+
ValueError: If the operation name format is invalid
23+
"""
24+
if "." not in op_name:
25+
raise ValueError(
26+
f"Operation name '{op_name}' must contain a '.' to separate base and overload"
27+
)
28+
29+
parts = op_name.split(".")
30+
if len(parts) != 2:
31+
raise ValueError(
32+
f"Operation name '{op_name}' must be in format 'op_base.overload'"
33+
)
34+
35+
op_base, overload = parts
36+
37+
# Get the operation from torch.ops.aten
38+
if not hasattr(torch.ops.aten, op_base):
39+
raise AttributeError(f"Operation base '{op_base}' not found in torch.ops.aten")
40+
41+
op_obj = getattr(torch.ops.aten, op_base)
42+
43+
if not hasattr(op_obj, overload):
44+
raise AttributeError(
45+
f"Overload '{overload}' not found for operation '{op_base}'"
46+
)
47+
48+
return getattr(op_obj, overload)

test/specdb/test_specdb.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
11+
from facto.specdb.db import SpecDictDB
12+
from facto.utils.ops import get_op_overload
13+
14+
15+
class TestSpecDBOperations(unittest.TestCase):
16+
"""Test class for validating all specs in SpecDB using gen_errors."""
17+
18+
def test_all_ops(self):
19+
"""
20+
Test all ops in SpecDB.
21+
22+
This test iterates through all operations in SpecDB and calls
23+
ArgumentTupleGenerator.gen_errors with valid=True, out=False
24+
for each operation. Each operation is tested as a subtest.
25+
"""
26+
# Get all operation names from SpecDB
27+
op_names = list(SpecDictDB.keys())
28+
29+
skip_ops = [
30+
"_native_batch_norm_legit_no_training.default",
31+
"addmm.default",
32+
"arange.default",
33+
"arange.start_step",
34+
"constant_pad_nd.default",
35+
"split_with_sizes_copy.default",
36+
]
37+
38+
for op_name in op_names:
39+
if op_name in skip_ops:
40+
continue
41+
with self.subTest(op=op_name):
42+
try:
43+
# Get the spec and operation
44+
spec = SpecDictDB[op_name]
45+
op = get_op_overload(op_name)
46+
generator = ArgumentTupleGenerator(spec)
47+
except Exception as e:
48+
# If we can't resolve the operation or there's another issue,
49+
# fail this subtest with a descriptive message
50+
self.fail(f"Failed to test operation {op_name}: {e}")
51+
52+
try:
53+
errors = list(generator.gen_errors(op, valid=True, out=False))
54+
except Exception as e:
55+
self.fail(f"Failed while testing operation {op_name}: {e}")
56+
57+
if len(errors) > 0:
58+
self.fail(
59+
f"Found {len(errors)} errors for {op_name} with valid=True, out=False"
60+
)
61+
62+
63+
if __name__ == "__main__":
64+
unittest.main()

0 commit comments

Comments
 (0)