Skip to content

Commit 3ba6b08

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: introduce solvable variable
Summary: This diff introduces the SolvableVariable class, responsible for solving the space of values of a variable subjected to constraints. The supported constraints are: - Eq (equal to) - Ne (not equal to) - In (in list of values) - NotIn (not in list of values) - Le (less than or equal to) - Lt (less than) - Ge (greater than or equal to) - Gt (greater than) Note: A tentative future constraint not included here yet will be St (such that) A SolvableVariable needs to be initialized with the variable type. It maintains a space of values for that variable type (VariableSpace object). The result of applying multiple constraints to a SolvableVariable, is the conjunction of those constraints. Consequently, every constraint reduces or maintains its space. We currently don't support constraint disjunctions. Sample usage: ``` v = SolvableVariable(float) v.Gt(0.5) v.Le(4.5) v.NotIn([1.5, 2.5]) ``` Reviewed By: SS-JIA Differential Revision: D51868719 fbshipit-source-id: 51f740358561930b40e0b11c210228c0b24fc6f3
1 parent a4468f9 commit 3ba6b08

File tree

2 files changed

+335
-0
lines changed

2 files changed

+335
-0
lines changed

inputgen/variable/solve.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import math
8+
from typing import Any, List, Union
9+
10+
from inputgen.variable.space import Discrete, VariableSpace
11+
from inputgen.variable.type import convert_to_vtype, invalid_vtype, is_integer
12+
13+
14+
class SolvableVariable:
15+
"""
16+
A solvable variable is a variable over which we can impose constraints.
17+
It needs to be initialized with the variable type. It maintains an internal state
18+
of the possible values of the variable, represented by a VariableSpace object.
19+
It supports the following constraints:
20+
- Eq: Equal to a specific value
21+
- Ne: Not equal to a specific value
22+
- In: Contained in a list of values
23+
- NotIn: Not contained in a list of values
24+
- Le: Less than or equal to a specific value
25+
- Lt: Less than a specific value
26+
- Ge: Greater than or equal to a specific value
27+
- Gt: Greater than a specific value
28+
The result of applying multiple constraints to a solvable variable, is the
29+
conjunction of those constraints.
30+
"""
31+
32+
def __init__(self, vtype: type):
33+
self.vtype = vtype
34+
self.space = VariableSpace(vtype)
35+
36+
def Eq(self, v: Any) -> None:
37+
if invalid_vtype(self.vtype, v):
38+
raise TypeError("Variable type mismatch")
39+
if self.space.empty():
40+
return
41+
if self.space.contains(v):
42+
self.space.discrete = Discrete([convert_to_vtype(self.vtype, v)])
43+
else:
44+
self.space.discrete = Discrete([])
45+
46+
def Ne(self, v: Any) -> None:
47+
if invalid_vtype(self.vtype, v):
48+
raise TypeError("Variable type mismatch")
49+
if self.space.empty():
50+
return
51+
self.space.remove(v)
52+
53+
def In(self, values: List[Any]) -> None:
54+
for v in values:
55+
if invalid_vtype(self.vtype, v):
56+
raise TypeError("Variable type mismatch")
57+
if self.space.empty():
58+
return
59+
self.space.discrete = Discrete(
60+
[convert_to_vtype(self.vtype, v) for v in values if self.space.contains(v)]
61+
)
62+
63+
def NotIn(self, values: List[Any]) -> None:
64+
for v in values:
65+
if invalid_vtype(self.vtype, v):
66+
raise TypeError("Variable type mismatch")
67+
if self.space.empty():
68+
return
69+
for v in values:
70+
self.space.remove(v)
71+
72+
def Le(self, upper: Union[bool, int, float]) -> None:
73+
if self.vtype not in [bool, int, float]:
74+
raise Exception(f"Le is not valid constraint on {self.vtype}")
75+
if invalid_vtype(self.vtype, upper):
76+
raise TypeError("Variable type mismatch")
77+
if self.space.empty():
78+
return
79+
elif self.space.discrete.initialized:
80+
self.space.discrete.filter(lambda v: v <= upper)
81+
elif self.vtype == int:
82+
if math.isfinite(upper):
83+
self.space.intervals.set_upper(math.ceil(upper), upper_open=False)
84+
else:
85+
self.space.intervals.set_upper(upper, upper_open=False)
86+
else:
87+
self.space.intervals.set_upper(float(upper), upper_open=False)
88+
89+
def Lt(self, upper: Union[bool, int, float]) -> None:
90+
if self.vtype not in [bool, int, float]:
91+
raise Exception(f"Lt is not valid constraint on {self.vtype}")
92+
if invalid_vtype(self.vtype, upper):
93+
raise TypeError("Variable type mismatch")
94+
if self.space.empty():
95+
return
96+
elif self.space.discrete.initialized:
97+
self.space.discrete.filter(lambda v: v < upper)
98+
elif self.vtype == int:
99+
if math.isfinite(upper):
100+
self.space.intervals.set_upper(
101+
math.floor(upper), upper_open=is_integer(upper)
102+
)
103+
else:
104+
self.space.intervals.set_upper(upper, upper_open=True)
105+
else:
106+
self.space.intervals.set_upper(float(upper), upper_open=True)
107+
108+
def Ge(self, lower: Union[bool, int, float]) -> None:
109+
if self.vtype not in [bool, int, float]:
110+
raise Exception(f"Ge is not valid constraint on {self.vtype}")
111+
if invalid_vtype(self.vtype, lower):
112+
raise TypeError("Variable type mismatch")
113+
if self.space.empty():
114+
return
115+
elif self.space.discrete.initialized:
116+
self.space.discrete.filter(lambda v: v >= lower)
117+
elif self.vtype == int:
118+
if math.isfinite(lower):
119+
self.space.intervals.set_lower(math.ceil(lower), lower_open=False)
120+
else:
121+
self.space.intervals.set_lower(lower, lower_open=False)
122+
else:
123+
self.space.intervals.set_lower(float(lower), lower_open=False)
124+
125+
def Gt(self, lower: Union[bool, int, float]) -> None:
126+
if self.vtype not in [bool, int, float]:
127+
raise Exception(f"Gt is not valid constraint on {self.vtype}")
128+
if invalid_vtype(self.vtype, lower):
129+
raise TypeError("Variable type mismatch")
130+
if self.space.empty():
131+
return
132+
elif self.space.discrete.initialized:
133+
self.space.discrete.filter(lambda v: v > lower)
134+
elif self.vtype == int:
135+
if math.isfinite(lower):
136+
self.space.intervals.set_lower(
137+
math.ceil(lower), lower_open=is_integer(lower)
138+
)
139+
else:
140+
self.space.intervals.set_lower(lower, lower_open=True)
141+
else:
142+
self.space.intervals.set_lower(float(lower), lower_open=True)
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from inputgen.variable.constants import INT64_MAX, INT64_MIN
11+
from inputgen.variable.solve import SolvableVariable
12+
from inputgen.variable.type import ScalarDtype, SUPPORTED_TENSOR_DTYPES
13+
14+
15+
class TestSolvableVariable(unittest.TestCase):
16+
def test_bool_solver(self):
17+
s = SolvableVariable(bool)
18+
self.assertTrue(s.space.discrete.initialized)
19+
self.assertEqual(s.space.discrete.values, {False, True})
20+
21+
s.Ne(0.0)
22+
self.assertEqual(str(s.space), "{True}")
23+
24+
s = SolvableVariable(bool)
25+
s.Ne(1)
26+
self.assertEqual(str(s.space), "{False}")
27+
28+
s = SolvableVariable(bool)
29+
s.Eq(1)
30+
self.assertEqual(str(s.space), "{True}")
31+
32+
s = SolvableVariable(bool)
33+
s.NotIn([0.0, 1.0])
34+
self.assertEqual(str(s.space), "{}")
35+
36+
s = SolvableVariable(bool)
37+
s.Gt(2)
38+
self.assertEqual(str(s.space), "{}")
39+
40+
def test_int_solver(self):
41+
s = SolvableVariable(int)
42+
self.assertFalse(s.space.empty())
43+
44+
s.Ne(0.0)
45+
self.assertEqual(str(s.space), "(-inf, 0) (0, inf)")
46+
self.assertFalse(s.space.contains(0))
47+
48+
s = SolvableVariable(int)
49+
s.Eq(1.5)
50+
self.assertEqual(str(s.space), "{}")
51+
52+
s = SolvableVariable(int)
53+
s.Ge(3.5)
54+
self.assertEqual(str(s.space), "[4, inf)")
55+
56+
s.Gt(5.0)
57+
self.assertEqual(str(s.space), "(5, inf)")
58+
59+
s = SolvableVariable(int)
60+
s.Gt(INT64_MAX)
61+
self.assertTrue(s.space.empty())
62+
63+
s = SolvableVariable(int)
64+
s.Lt(INT64_MIN)
65+
self.assertTrue(s.space.empty())
66+
67+
s = SolvableVariable(int)
68+
s.Gt(3)
69+
s.Lt(4)
70+
self.assertTrue(s.space.empty())
71+
72+
s = SolvableVariable(int)
73+
s.Lt(float("inf"))
74+
self.assertEqual(str(s.space), "(-inf, inf)")
75+
76+
s = SolvableVariable(int)
77+
s.Lt(float("-inf"))
78+
self.assertEqual(str(s.space), "{}")
79+
80+
def test_float_solver(self):
81+
s = SolvableVariable(float)
82+
s.Ge(1.7976931348623157e308)
83+
self.assertEqual(str(s.space), "[1.7976931348623157e+308, inf]")
84+
85+
s = SolvableVariable(float)
86+
s.Gt(float("inf"))
87+
self.assertEqual(str(s.space), "{}")
88+
self.assertTrue(s.space.empty())
89+
90+
s = SolvableVariable(float)
91+
s.Eq(2)
92+
self.assertEqual(str(s.space), "{2.0}")
93+
s.Eq(1)
94+
self.assertEqual(str(s.space), "{}")
95+
self.assertTrue(s.space.empty())
96+
97+
s = SolvableVariable(float)
98+
s.Eq(2)
99+
s.Lt(2)
100+
self.assertTrue(s.space.empty())
101+
102+
s = SolvableVariable(float)
103+
s.Le(2)
104+
self.assertEqual(str(s.space), "[-inf, 2.0]")
105+
s.Eq(2)
106+
self.assertEqual(str(s.space), "{2.0}")
107+
108+
s = SolvableVariable(float)
109+
s.Ne(3)
110+
self.assertEqual(str(s.space), "[-inf, 3.0) (3.0, inf]")
111+
s.Le(3.5)
112+
self.assertEqual(str(s.space), "[-inf, 3.0) (3.0, 3.5]")
113+
s.Ge(3)
114+
self.assertEqual(str(s.space), "(3.0, 3.5]")
115+
s.In([3, 3.5])
116+
self.assertEqual(str(s.space), "{3.5}")
117+
118+
s = SolvableVariable(float)
119+
s.Lt(float("inf"))
120+
self.assertEqual(str(s.space), "[-inf, inf)")
121+
122+
s = SolvableVariable(float)
123+
s.Lt(float("-inf"))
124+
self.assertEqual(str(s.space), "{}")
125+
126+
s = SolvableVariable(float)
127+
s.Le(float("-inf"))
128+
self.assertEqual(str(s.space), "[-inf, -inf]")
129+
130+
s = SolvableVariable(float)
131+
s.Ge(float("inf"))
132+
self.assertEqual(str(s.space), "[inf, inf]")
133+
134+
def test_str_solver(self):
135+
s = SolvableVariable(str)
136+
137+
s.In(["a", "b", "c", "d"])
138+
self.assertEqual(s.space.discrete.values, {"a", "b", "c", "d"})
139+
140+
s.Ne("a")
141+
self.assertEqual(s.space.discrete.values, {"b", "c", "d"})
142+
143+
s.Eq("b")
144+
self.assertEqual(s.space.discrete.values, {"b"})
145+
146+
s.NotIn(["b", "c"])
147+
self.assertTrue(s.space.empty())
148+
149+
with self.assertRaises(Exception):
150+
s.Le(3)
151+
152+
def test_tensor_dtype_solver(self):
153+
s = SolvableVariable(torch.dtype)
154+
self.assertTrue(s.space.discrete.initialized)
155+
self.assertEqual(s.space.discrete.values, set(SUPPORTED_TENSOR_DTYPES))
156+
157+
s.In([torch.bool, torch.uint8, torch.int8, torch.int32, torch.float32])
158+
self.assertEqual(
159+
s.space.discrete.values,
160+
{torch.bool, torch.uint8, torch.int8, torch.int32, torch.float32},
161+
)
162+
163+
s.Ne(torch.float32)
164+
self.assertEqual(
165+
s.space.discrete.values, {torch.bool, torch.uint8, torch.int8, torch.int32}
166+
)
167+
168+
s.NotIn([torch.bool, torch.uint8])
169+
self.assertEqual(s.space.discrete.values, {torch.int8, torch.int32})
170+
171+
s.Eq(torch.int32)
172+
self.assertEqual(s.space.discrete.values, {torch.int32})
173+
174+
with self.assertRaises(Exception):
175+
s.Ge(3)
176+
177+
def test_scalar_dtype_solver(self):
178+
s = SolvableVariable(ScalarDtype)
179+
self.assertTrue(s.space.discrete.initialized)
180+
self.assertEqual(s.space.discrete.values, set(ScalarDtype))
181+
182+
s.In([ScalarDtype.int, ScalarDtype.float])
183+
self.assertEqual(s.space.discrete.values, {ScalarDtype.int, ScalarDtype.float})
184+
185+
s.Ne(ScalarDtype.float)
186+
self.assertEqual(s.space.discrete.values, {ScalarDtype.int})
187+
188+
with self.assertRaises(Exception):
189+
s.Gt(3)
190+
191+
192+
if __name__ == "__main__":
193+
unittest.main()

0 commit comments

Comments
 (0)