Skip to content

Commit fd18020

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: introduce meta argument tuple engine
Reviewed By: SS-JIA Differential Revision: D52308289 fbshipit-source-id: 46acf37c93509f8d3b34ca9c068557154692f493
1 parent b4093d4 commit fd18020

File tree

3 files changed

+177
-0
lines changed

3 files changed

+177
-0
lines changed

inputgen/argtuple/__init__.py

Whitespace-only changes.

inputgen/argtuple/engine.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from inputgen.argument.engine import MetaArgEngine
8+
from inputgen.argument.gen import ArgumentGenerator
9+
from inputgen.attribute.model import Attribute
10+
from inputgen.specs.model import Spec
11+
12+
13+
def reverse_topological_sort(graph):
14+
def dfs(node, visited, strack):
15+
visited[node] = True
16+
for neig in graph[node]:
17+
if not visited[neig]:
18+
dfs(neig, visited, strack)
19+
stack.append(node)
20+
21+
visited = {node: False for node in graph}
22+
stack = []
23+
24+
for node in graph:
25+
if not visited[node]:
26+
dfs(node, visited, stack)
27+
28+
return stack
29+
30+
31+
def inverse_permutation(permutation):
32+
n = len(permutation)
33+
inverse = [0] * n
34+
for i in range(n):
35+
inverse[permutation[i]] = i
36+
return inverse
37+
38+
39+
class MetaArgTupleEngine:
40+
def __init__(self, spec: Spec, out: bool = False):
41+
if out:
42+
raise NotImplementedError("out=True is not supported yet")
43+
self.args = spec.inspec
44+
self.order = self._sort_dependencies()
45+
self.order_inverse_perm = inverse_permutation(self.order)
46+
47+
def _generate_dependency_dag(self):
48+
graph = {}
49+
for i, arg in enumerate(self.args):
50+
if arg.deps is None:
51+
graph[i] = []
52+
else:
53+
graph[i] = arg.deps
54+
return graph
55+
56+
def _sort_dependencies(self):
57+
graph = self._generate_dependency_dag()
58+
return reverse_topological_sort(graph)
59+
60+
def _sort_meta_tuple(self, meta_tuple):
61+
return tuple(
62+
meta_tuple[self.order_inverse_perm[i]] for i in range(len(self.args))
63+
)
64+
65+
def _get_deps(self, meta_tuple, arg_deps):
66+
value_tuple = tuple(ArgumentGenerator(m).gen() for m in meta_tuple)
67+
return tuple(value_tuple[self.order_inverse_perm[ix]] for ix in arg_deps)
68+
69+
def gen_meta_tuples(self, valid: bool, focus_ix: int):
70+
tuples = [()]
71+
for ix in self.order:
72+
arg = self.args[ix]
73+
new_tuples = []
74+
focuses = [None]
75+
if ix == focus_ix:
76+
focuses = Attribute.hierarchy(arg.type)
77+
for focus in focuses:
78+
for meta_tuple in tuples:
79+
deps = self._get_deps(meta_tuple, arg.deps)
80+
engine = MetaArgEngine(arg.type, arg.constraints, deps, valid)
81+
for meta_arg in engine.gen(focus):
82+
new_tuples.append(meta_tuple + (meta_arg,))
83+
tuples = new_tuples
84+
return map(self._sort_meta_tuple, tuples)
85+
86+
def gen_valid_meta_tuples(self):
87+
valid_tuples = []
88+
for ix in range(len(self.args)):
89+
valid_tuples += self.gen_meta_tuples(True, ix)
90+
return valid_tuples
91+
92+
def gen_invalid_from_valid(self, valid_tuple):
93+
# Valid [str(x) for x in valid_tuple]
94+
valid_value_tuple = tuple(ArgumentGenerator(m).gen() for m in valid_tuple)
95+
invalid_tuples = []
96+
for ix in range(len(self.args)):
97+
arg = self.args[ix]
98+
# Generating invalid argument {ix} {arg.type}
99+
deps = tuple(valid_value_tuple[i] for i in arg.deps)
100+
for focus in Attribute.hierarchy(arg.type):
101+
engine = MetaArgEngine(arg.type, arg.constraints, deps, False)
102+
for meta_arg in engine.gen(focus):
103+
invalid_tuple = (
104+
valid_tuple[:ix] + (meta_arg,) + valid_tuple[ix + 1 :]
105+
)
106+
# Invalid {ix} {focus} [str(x) for x in invalid_tuple]
107+
invalid_tuples.append(invalid_tuple)
108+
invalid_tuples = list(set(invalid_tuples))
109+
return invalid_tuples
110+
111+
def gen_invalid_meta_tuples(self):
112+
valid_tuples = self.gen_valid_meta_tuples()
113+
invalid_tuples = []
114+
for valid_tuple in valid_tuples:
115+
invalids = self.gen_invalid_from_valid(valid_tuple)
116+
invalid_tuples += invalids
117+
invalid_tuples = list(set(invalid_tuples))
118+
return invalid_tuples
119+
120+
def gen(self, valid: bool = True):
121+
if valid:
122+
return self.gen_valid_meta_tuples()
123+
else:
124+
return self.gen_invalid_meta_tuples()

test/inputgen/test_argtuple_engine.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from inputgen.argtuple.engine import MetaArgTupleEngine
10+
from inputgen.argument.type import ArgType
11+
from inputgen.specs.model import ConstraintProducer as cp, InPosArg, Spec
12+
13+
14+
class TestMetaArgTupleEngine(unittest.TestCase):
15+
def test_size(self):
16+
spec = Spec(
17+
op="test_size", # (Tensor self, int dim) -> int
18+
inspec=[
19+
InPosArg(ArgType.Tensor, name="self"),
20+
InPosArg(
21+
ArgType.Dim,
22+
name="dim",
23+
deps=[0],
24+
constraints=[
25+
cp.Value.Ge(
26+
lambda deps: -deps[0].dim() if deps[0].dim() > 0 else None
27+
),
28+
cp.Value.Ge(lambda deps: -1 if deps[0].dim() == 0 else None),
29+
cp.Value.Le(
30+
lambda deps: deps[0].dim() - 1
31+
if deps[0].dim() > 0
32+
else None
33+
),
34+
cp.Value.Le(lambda deps: 0 if deps[0].dim() == 0 else None),
35+
],
36+
),
37+
],
38+
outspec=[],
39+
)
40+
41+
for meta_tuple in MetaArgTupleEngine(spec).gen(True):
42+
t, dim = meta_tuple
43+
self.assertEqual(t.argtype, ArgType.Tensor)
44+
self.assertEqual(dim.argtype, ArgType.Dim)
45+
if t.rank() == 0:
46+
self.assertTrue(dim.value in [-1, 0])
47+
else:
48+
self.assertTrue(dim.value >= -t.rank())
49+
self.assertTrue(dim.value < t.rank())
50+
51+
52+
if __name__ == "__main__":
53+
unittest.main()

0 commit comments

Comments
 (0)