Skip to content

Commit b75f70f

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: introduce attribute solver
Reviewed By: SS-JIA Differential Revision: D52047185 fbshipit-source-id: 4e0854845dbf4f3b1d504c33f374b1548d592c77
1 parent 9f01abc commit b75f70f

File tree

2 files changed

+208
-0
lines changed

2 files changed

+208
-0
lines changed

inputgen/attribute/solve.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, List, Optional
8+
9+
from inputgen.argument.type import ArgType
10+
from inputgen.attribute.model import Attribute
11+
from inputgen.specs.model import Constraint, ConstraintSuffix
12+
from inputgen.variable.solve import SolvableVariable
13+
from inputgen.variable.type import ScalarDtype
14+
15+
16+
class AttributeSolver:
17+
def __init__(
18+
self,
19+
attribute: Attribute,
20+
argtype: ArgType,
21+
scalar_dtype: Optional[ScalarDtype] = None,
22+
):
23+
self.attribute = attribute
24+
if attribute == Attribute.VALUE and argtype.is_scalar():
25+
if scalar_dtype is None:
26+
raise ValueError(
27+
"Attribute value for argtype scalar requires a scalar_dtype"
28+
)
29+
self.argtype = argtype
30+
self.vtype = attribute.get_vtype(argtype, scalar_dtype)
31+
32+
def solve_hard_constraints(self, variable: SolvableVariable) -> None:
33+
if self.attribute in [Attribute.LENGTH, Attribute.RANK, Attribute.SIZE]:
34+
variable.Ge(0)
35+
36+
def solve_user_constraint(
37+
self,
38+
variable: SolvableVariable,
39+
suffix: ConstraintSuffix,
40+
res: Any,
41+
valid: bool = True,
42+
) -> bool:
43+
if res is None:
44+
return False
45+
if suffix == ConstraintSuffix.EQ:
46+
variable.Eq(res) if valid else variable.Ne(res)
47+
if suffix == ConstraintSuffix.NE:
48+
variable.Ne(res) if valid else variable.Eq(res)
49+
if suffix == ConstraintSuffix.IN:
50+
variable.In(res) if valid else variable.NotIn(res)
51+
if suffix == ConstraintSuffix.NOTIN:
52+
variable.NotIn(res) if valid else variable.In(res)
53+
if suffix == ConstraintSuffix.LE:
54+
variable.Le(res) if valid else variable.Gt(res)
55+
if suffix == ConstraintSuffix.LT:
56+
variable.Lt(res) if valid else variable.Ge(res)
57+
if suffix == ConstraintSuffix.GE:
58+
variable.Ge(res) if valid else variable.Lt(res)
59+
if suffix == ConstraintSuffix.GT:
60+
variable.Gt(res) if valid else variable.Le(res)
61+
# TODO(mcandales): Enable Such That
62+
# if suffix == ConstraintSuffix.ST:
63+
# variable.St(res) if valid else variable.St(lambda x: not res(x))
64+
if suffix == ConstraintSuffix.BE:
65+
if valid:
66+
variable.In(res)
67+
else:
68+
return False
69+
return True
70+
71+
def solve_focus_constraints(
72+
self, variable: SolvableVariable, focus: Attribute
73+
) -> None:
74+
if self.attribute in [Attribute.LENGTH, Attribute.RANK, Attribute.SIZE]:
75+
if focus in [
76+
Attribute.LENGTH,
77+
Attribute.RANK,
78+
Attribute.SIZE,
79+
Attribute.VALUE,
80+
]:
81+
attr_pos = Attribute.hierarchy(self.argtype).index(self.attribute)
82+
focus_pos = Attribute.hierarchy(self.argtype).index(focus)
83+
if attr_pos < focus_pos:
84+
variable.Ge(1)
85+
86+
def solve(
87+
self, constraints: List[Constraint], focus: Attribute, valid: bool, *args
88+
):
89+
applicable_constraints = []
90+
for constraint in constraints:
91+
if constraint.attribute != self.attribute:
92+
continue
93+
res = constraint.fn(*args)
94+
if res is None:
95+
continue
96+
applicable_constraints.append((constraint.suffix, res))
97+
98+
# TODO(mcandales) This is a hack:
99+
if constraint.suffix == ConstraintSuffix.GEN:
100+
valid_values, invalid_values = res
101+
variable = SolvableVariable(tuple)
102+
variable.In(valid_values if valid else invalid_values)
103+
yield variable
104+
return
105+
106+
if not valid and self.attribute == focus:
107+
for invalid_ix in range(len(applicable_constraints)):
108+
variable = SolvableVariable(self.vtype)
109+
self.solve_hard_constraints(variable)
110+
self.solve_focus_constraints(variable, focus)
111+
for ix, (suffix, res) in enumerate(applicable_constraints):
112+
if ix == invalid_ix:
113+
if not self.solve_user_constraint(variable, suffix, res, False):
114+
break
115+
else:
116+
self.solve_user_constraint(variable, suffix, res, True)
117+
else:
118+
yield variable
119+
else:
120+
variable = SolvableVariable(self.vtype)
121+
self.solve_hard_constraints(variable)
122+
self.solve_focus_constraints(variable, focus)
123+
for suffix, res in applicable_constraints:
124+
self.solve_user_constraint(variable, suffix, res, True)
125+
yield variable
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from inputgen.argument.type import ArgType
10+
from inputgen.attribute.model import Attribute
11+
from inputgen.attribute.solve import AttributeSolver
12+
from inputgen.specs.model import ConstraintProducer as cp
13+
from inputgen.variable.type import ScalarDtype
14+
15+
16+
class TestAttributeSolver(unittest.TestCase):
17+
def test_solver(self):
18+
solver = AttributeSolver(Attribute.VALUE, ArgType.Scalar, ScalarDtype.float)
19+
constraints = [
20+
cp.Value.Ge(lambda x, y: 1),
21+
cp.Value.Le(lambda x, y: x + 3),
22+
cp.Value.Ne(lambda x, y: y + 1),
23+
]
24+
x = 2
25+
y = 1
26+
variables = list(solver.solve(constraints, Attribute.VALUE, True, x, y))
27+
self.assertEqual(len(variables), 1)
28+
self.assertEqual(str(variables[0].space), "[1.0, 2.0) (2.0, 5.0]")
29+
30+
variables = list(solver.solve(constraints, Attribute.VALUE, False, x, y))
31+
self.assertEqual(len(variables), 3)
32+
self.assertEqual(str(variables[0].space), "[-inf, 1.0)")
33+
self.assertEqual(str(variables[1].space), "(5.0, inf]")
34+
self.assertEqual(str(variables[2].space), "{2.0}")
35+
36+
def test_hidden_constraints(self):
37+
solver = AttributeSolver(Attribute.RANK, ArgType.Tensor)
38+
constraints = [
39+
cp.Rank.Le(lambda x: x + 3),
40+
]
41+
x = 2
42+
43+
# valid case: test focus constraint
44+
variables = list(solver.solve(constraints, Attribute.SIZE, True, x))
45+
self.assertEqual(len(variables), 1)
46+
self.assertEqual(str(variables[0].space), "[1, 5]")
47+
48+
# valid case: test hard constraint
49+
variables = list(solver.solve(constraints, Attribute.RANK, True, x))
50+
self.assertEqual(len(variables), 1)
51+
self.assertEqual(str(variables[0].space), "[0, 5]")
52+
53+
# invalid case: test focus constraint
54+
variables = list(solver.solve(constraints, Attribute.SIZE, False, x))
55+
self.assertEqual(len(variables), 1)
56+
self.assertEqual(str(variables[0].space), "[1, 5]")
57+
58+
constraints = [
59+
cp.Rank.Ge(lambda x: x),
60+
cp.Rank.Le(lambda x: x + 4),
61+
]
62+
x = 3
63+
64+
# invalid case: test hard constraint
65+
variables = list(solver.solve(constraints, Attribute.RANK, False, x))
66+
self.assertEqual(len(variables), 2)
67+
self.assertEqual(str(variables[0].space), "[0, 3)")
68+
self.assertEqual(str(variables[1].space), "(7, inf)")
69+
70+
def test_scalar_type(self):
71+
solver = AttributeSolver(Attribute.VALUE, ArgType.ScalarType)
72+
constraints = []
73+
74+
variables = list(solver.solve(constraints, Attribute.VALUE, True))
75+
self.assertEqual(len(variables), 1)
76+
self.assertFalse(variables[0].space.empty())
77+
78+
variables = list(solver.solve(constraints, Attribute.VALUE, False))
79+
self.assertEqual(len(variables), 0)
80+
81+
82+
if __name__ == "__main__":
83+
unittest.main()

0 commit comments

Comments
 (0)