Skip to content

Commit 5e1bf0f

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: introduce structural engine
Reviewed By: SS-JIA Differential Revision: D52047181 fbshipit-source-id: 925bbcd95f969dd1373f9df8e3de192f245d55b3
1 parent 9341313 commit 5e1bf0f

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

inputgen/argument/engine.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import random
8+
from typing import Any, List, Optional
9+
10+
from inputgen.argument.type import ArgType
11+
from inputgen.attribute.engine import AttributeEngine
12+
from inputgen.attribute.model import Attribute
13+
from inputgen.specs.model import Constraint
14+
15+
16+
class StructuralEngine:
17+
def __init__(
18+
self,
19+
argtype: ArgType,
20+
constraints: List[Constraint],
21+
deps: List[Any],
22+
valid: bool,
23+
):
24+
self.argtype = argtype
25+
self.constraints = constraints
26+
self.deps = deps
27+
self.valid = valid
28+
self.hierarchy = StructuralEngine.hierarchy(argtype)
29+
30+
@staticmethod
31+
def hierarchy(argtype) -> List[Attribute]:
32+
"""Return the structural hierarchy for a given argument type"""
33+
if argtype.is_tensor_list():
34+
return [Attribute.LENGTH, Attribute.RANK, Attribute.SIZE]
35+
elif argtype.is_tensor():
36+
return [Attribute.RANK, Attribute.SIZE]
37+
elif argtype.is_list():
38+
return [Attribute.LENGTH, Attribute.VALUE]
39+
else:
40+
return [Attribute.VALUE]
41+
42+
def gen_structure_with_depth_and_length(
43+
self, depth: int, length: int, focus: Attribute
44+
):
45+
if length == 0:
46+
yield ()
47+
return
48+
49+
attr = self.hierarchy[-(depth + 1)]
50+
focus_ixs = range(length) if focus == attr else (random.choice(range(length)),)
51+
for focus_ix in focus_ixs:
52+
values = [()]
53+
for ix in range(length):
54+
if ix == focus_ix:
55+
elements = self.gen_structure_with_depth(depth, focus, length, ix)
56+
else:
57+
elements = self.gen_structure_with_depth(depth, None, length, ix)
58+
new_values = []
59+
for elem in elements:
60+
new_values += [t + (elem,) for t in values]
61+
values = new_values
62+
yield from values
63+
64+
def gen_structure_with_depth(
65+
self,
66+
depth: int,
67+
focus: Attribute,
68+
length: Optional[int] = None,
69+
ix: Optional[int] = None,
70+
):
71+
attr = self.hierarchy[-(depth + 1)]
72+
73+
if ix is not None:
74+
args = (self.deps, length, ix)
75+
elif length is not None:
76+
args = (
77+
self.deps,
78+
length,
79+
)
80+
else:
81+
args = (self.deps,)
82+
83+
values = AttributeEngine(attr, self.constraints, self.valid, self.argtype).gen(
84+
focus, *args
85+
)
86+
87+
for v in values:
88+
if depth == 0:
89+
yield v
90+
else:
91+
yield from self.gen_structure_with_depth_and_length(depth - 1, v, focus)
92+
93+
def gen(self, focus: Attribute):
94+
depth = len(self.hierarchy) - 1
95+
yield from self.gen_structure_with_depth(depth, focus)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from inputgen.argument.engine import StructuralEngine
10+
from inputgen.argument.type import ArgType
11+
from inputgen.attribute.model import Attribute
12+
from inputgen.specs.model import ConstraintProducer as cp
13+
14+
15+
class TestStructuralEngine(unittest.TestCase):
16+
def test_engine(self):
17+
constraints = [
18+
cp.Rank.Le(lambda deps: deps[0] + 2),
19+
cp.Size.NotIn(lambda deps, length, ix: [1, 3]),
20+
cp.Size.Le(lambda deps, length, ix: 5),
21+
cp.Value.Ne(lambda deps: 0),
22+
]
23+
deps = [2]
24+
25+
engine = StructuralEngine(ArgType.Tensor, constraints, deps, True)
26+
for s in engine.gen(Attribute.VALUE):
27+
self.assertTrue(1 <= len(s) <= 4)
28+
self.assertTrue(all(v in [2, 4, 5] for v in s))
29+
30+
31+
if __name__ == "__main__":
32+
unittest.main()

0 commit comments

Comments
 (0)