Skip to content

Commit 68d25b1

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: introduce meta argument engine
Reviewed By: SS-JIA Differential Revision: D52047182 fbshipit-source-id: c9e2320bdb2dad23f8b0b1defca9ffdf11c7ff99
1 parent 5e1bf0f commit 68d25b1

File tree

2 files changed

+247
-2
lines changed

2 files changed

+247
-2
lines changed

inputgen/argument/engine.py

Lines changed: 179 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import random
8-
from typing import Any, List, Optional
8+
from typing import Any, List, Optional, Tuple, Union
99

10+
import torch
1011
from inputgen.argument.type import ArgType
1112
from inputgen.attribute.engine import AttributeEngine
1213
from inputgen.attribute.model import Attribute
13-
from inputgen.specs.model import Constraint
14+
from inputgen.attribute.solve import AttributeSolver
15+
from inputgen.specs.model import Constraint, ConstraintSuffix
16+
from inputgen.variable.type import ScalarDtype
1417

1518

1619
class StructuralEngine:
@@ -27,6 +30,11 @@ def __init__(
2730
self.valid = valid
2831
self.hierarchy = StructuralEngine.hierarchy(argtype)
2932

33+
self.gen_list_mode = set()
34+
for constraint in constraints:
35+
if constraint.suffix == ConstraintSuffix.GEN:
36+
self.gen_list_mode.add(constraint.attribute)
37+
3038
@staticmethod
3139
def hierarchy(argtype) -> List[Attribute]:
3240
"""Return the structural hierarchy for a given argument type"""
@@ -47,6 +55,11 @@ def gen_structure_with_depth_and_length(
4755
return
4856

4957
attr = self.hierarchy[-(depth + 1)]
58+
59+
if attr in self.gen_list_mode:
60+
yield from self.gen_structure_with_depth(depth, focus, length)
61+
return
62+
5063
focus_ixs = range(length) if focus == attr else (random.choice(range(length)),)
5164
for focus_ix in focus_ixs:
5265
values = [()]
@@ -93,3 +106,167 @@ def gen_structure_with_depth(
93106
def gen(self, focus: Attribute):
94107
depth = len(self.hierarchy) - 1
95108
yield from self.gen_structure_with_depth(depth, focus)
109+
110+
111+
class MetaArg:
112+
def __init__(
113+
self,
114+
argtype: ArgType,
115+
*,
116+
optional: bool = False,
117+
dtype: Optional[
118+
Union[torch.dtype, List[Optional[torch.dtype]], ScalarDtype]
119+
] = None,
120+
structure: Optional[Tuple] = None,
121+
value: Optional[Any] = None,
122+
):
123+
self.argtype = argtype
124+
self.optional = optional
125+
self.dtype = dtype
126+
self.structure = structure
127+
self.value = value
128+
129+
if not self.argtype.is_optional() and self.optional:
130+
raise ValueError("Only optional argtypes can have optional instances")
131+
132+
if self.argtype.is_tensor_list():
133+
if len(self.structure) != len(self.dtype):
134+
raise ValueError(
135+
"Structure and dtype must be same length when tensor list"
136+
)
137+
if self.argtype == ArgType.TensorList and any(
138+
d is None for d in self.dtype
139+
):
140+
raise ValueError("Only TensorOptList can have None in list of dtypes")
141+
142+
if not self.optional and Attribute.DTYPE not in Attribute.hierarchy(
143+
self.argtype
144+
):
145+
if argtype.is_list():
146+
self.value = list(self.structure)
147+
else:
148+
self.value = self.structure
149+
150+
def __str__(self):
151+
if self.optional:
152+
strval = "None"
153+
elif self.argtype.is_tensor_list():
154+
strval = (
155+
"["
156+
+ ", ".join(
157+
[
158+
f"{self.dtype[i]} {self.structure[i]}"
159+
for i in range(len(self.dtype))
160+
]
161+
)
162+
+ "]"
163+
)
164+
elif self.argtype.is_tensor():
165+
strval = f"{self.dtype} {self.structure}"
166+
else:
167+
strval = str(self.value)
168+
return f"{self.argtype} {strval}"
169+
170+
def length(self):
171+
if self.argtype.is_list():
172+
return len(self.structure)
173+
else:
174+
return None
175+
176+
def rank(self, ix=None):
177+
if self.argtype.is_tensor():
178+
return len(self.structure)
179+
elif self.argtype.is_tensor_list():
180+
if ix is None:
181+
return (len(s) for s in self.structure)
182+
else:
183+
return len(self.structure[ix])
184+
else:
185+
return None
186+
187+
188+
class MetaArgEngine:
189+
def __init__(
190+
self,
191+
argtype: ArgType,
192+
constraints: List[Constraint],
193+
deps: List[Any],
194+
valid: bool,
195+
):
196+
self.argtype = argtype
197+
self.constraints = constraints
198+
self.deps = deps
199+
self.valid = valid
200+
201+
def gen_structures(self, focus):
202+
if self.argtype.is_scalar():
203+
yield None
204+
else:
205+
yield from StructuralEngine(
206+
self.argtype, self.constraints, self.deps, self.valid
207+
).gen(focus)
208+
209+
def gen_dtypes(self, focus):
210+
if Attribute.DTYPE not in Attribute.hierarchy(self.argtype):
211+
return {None}
212+
engine = AttributeEngine(
213+
Attribute.DTYPE, self.constraints, self.valid, self.argtype
214+
)
215+
if self.argtype.is_scalar() and focus == Attribute.VALUE:
216+
# if focused on a scalar value, must generate all dtypes too
217+
focus = Attribute.DTYPE
218+
return engine.gen(focus, self.deps)
219+
220+
def gen_optional(self):
221+
engine = AttributeEngine(
222+
Attribute.OPTIONAL, self.constraints, self.valid, self.argtype
223+
)
224+
return True in engine.gen(Attribute.OPTIONAL, self.deps)
225+
226+
def gen_scalars(self, scalar_dtype, focus):
227+
engine = AttributeEngine(
228+
Attribute.VALUE, self.constraints, self.valid, self.argtype, scalar_dtype
229+
)
230+
return engine.gen(focus, self.deps, scalar_dtype)
231+
232+
def gen_value_spaces(self, focus, dtype, struct):
233+
if not self.argtype.is_tensor() and not self.argtype.is_tensor_list():
234+
return [None]
235+
solver = AttributeSolver(Attribute.VALUE, self.argtype)
236+
variables = list(
237+
solver.solve(self.constraints, focus, self.valid, self.deps, dtype, struct)
238+
)
239+
if focus == Attribute.VALUE:
240+
return [v.space for v in variables]
241+
else:
242+
return [random.choice(variables).space]
243+
244+
def gen(self, focus):
245+
# TODO(mcandales): Enable Tensor List generation
246+
247+
if focus in [None, Attribute.OPTIONAL]:
248+
if self.argtype.is_optional() and self.gen_optional():
249+
yield MetaArg(self.argtype, optional=True)
250+
if focus == Attribute.OPTIONAL:
251+
return
252+
253+
if self.argtype.is_scalar():
254+
scalar_dtypes = self.gen_dtypes(focus)
255+
for scalar_dtype in scalar_dtypes:
256+
for value in self.gen_scalars(scalar_dtype, focus):
257+
yield MetaArg(self.argtype, dtype=scalar_dtype, value=value)
258+
else:
259+
if focus == Attribute.DTYPE:
260+
for dtype in self.gen_dtypes(focus):
261+
for struct in self.gen_structures(focus):
262+
for space in self.gen_value_spaces(focus, dtype, struct):
263+
yield MetaArg(
264+
self.argtype, dtype=dtype, structure=struct, value=space
265+
)
266+
else:
267+
for struct in self.gen_structures(focus):
268+
for dtype in self.gen_dtypes(focus):
269+
for space in self.gen_value_spaces(focus, dtype, struct):
270+
yield MetaArg(
271+
self.argtype, dtype=dtype, structure=struct, value=space
272+
)

test/inputgen/test_meta_arg_engine.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from inputgen.argument.engine import MetaArgEngine
10+
from inputgen.argument.type import ArgType
11+
from inputgen.attribute.model import Attribute
12+
from inputgen.specs.model import ConstraintProducer as cp
13+
from inputgen.variable.type import SUPPORTED_TENSOR_DTYPES
14+
15+
16+
class TestMetaArgEngine(unittest.TestCase):
17+
def test_tensor(self):
18+
constraints = [
19+
cp.Rank.Le(lambda deps: deps[0] + 2),
20+
cp.Size.NotIn(lambda deps, length, ix: [1, 3]),
21+
cp.Size.Le(lambda deps, length, ix: 5),
22+
cp.Value.Ne(lambda deps, dtype, struct: 0),
23+
]
24+
deps = [2]
25+
26+
engine = MetaArgEngine(ArgType.Tensor, constraints, deps, True)
27+
ms = list(engine.gen(Attribute.DTYPE))
28+
self.assertEqual(len(ms), len(SUPPORTED_TENSOR_DTYPES))
29+
self.assertEqual({m.dtype for m in ms}, set(SUPPORTED_TENSOR_DTYPES))
30+
self.assertTrue(all(0 <= m.rank() <= 4 for m in ms))
31+
for m in ms:
32+
self.assertTrue(
33+
all(0 <= size <= 5 and size not in [1, 3] for size in m.structure)
34+
)
35+
for m in ms:
36+
self.assertEqual(str(m.value), "[-inf, 0.0) (0.0, inf]")
37+
38+
ms = list(engine.gen(Attribute.RANK))
39+
self.assertEqual(len(ms), 4)
40+
ranks = {len(m.structure) for m in ms}
41+
self.assertTrue(0 in ranks)
42+
self.assertTrue(4 in ranks)
43+
self.assertTrue(all(0 <= r <= 4 for r in ranks))
44+
45+
def test_dim_list(self):
46+
constraints = [
47+
cp.Length.Le(lambda deps: deps[0] + deps[1]),
48+
cp.Value.Gen(
49+
lambda deps, length: ({(deps[0],) * length}, {(deps[1],) * length})
50+
),
51+
]
52+
deps = [2, 3]
53+
54+
engine = MetaArgEngine(ArgType.DimList, constraints, deps, True)
55+
ms = list(engine.gen(Attribute.VALUE))
56+
self.assertEqual(len(ms), 1)
57+
self.assertTrue(1 <= len(ms[0].value) <= 5)
58+
self.assertTrue(all(v == 2 for v in ms[0].value))
59+
60+
engine = MetaArgEngine(ArgType.DimList, constraints, deps, False)
61+
ms = list(engine.gen(Attribute.VALUE))
62+
self.assertEqual(len(ms), 1)
63+
self.assertTrue(1 <= len(ms[0].value) <= 5)
64+
self.assertTrue(all(v == 3 for v in ms[0].value))
65+
66+
67+
if __name__ == "__main__":
68+
unittest.main()

0 commit comments

Comments
 (0)