Skip to content

Commit 9341313

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: introduce attribute engine
Reviewed By: SS-JIA Differential Revision: D52047186 fbshipit-source-id: 19bb6e665c5f0a395a82575d8cff37f88cf839ee
1 parent b75f70f commit 9341313

File tree

2 files changed

+116
-0
lines changed

2 files changed

+116
-0
lines changed

inputgen/attribute/engine.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import copy
8+
from typing import List, Optional
9+
10+
from inputgen.argument.type import ArgType
11+
from inputgen.attribute.model import Attribute
12+
from inputgen.attribute.solve import AttributeSolver
13+
from inputgen.specs.model import Constraint
14+
from inputgen.variable.gen import VariableGenerator
15+
from inputgen.variable.type import ScalarDtype
16+
17+
18+
class AttributeEngine(AttributeSolver):
19+
def __init__(
20+
self,
21+
attribute: Attribute,
22+
constraints: List[Constraint],
23+
valid: bool,
24+
argtype: Optional[ArgType] = None,
25+
scalar_dtype: Optional[ScalarDtype] = None,
26+
):
27+
super().__init__(attribute, argtype, scalar_dtype)
28+
self.constraints = constraints
29+
self.valid = valid
30+
31+
def gen(self, focus: Attribute, *args):
32+
if self.attribute == Attribute.OPTIONAL:
33+
num = 2
34+
elif self.attribute == focus:
35+
if self.attribute == Attribute.DTYPE:
36+
num = 8
37+
else:
38+
num = 6
39+
else:
40+
num = 1
41+
gen_vals = set()
42+
for variable in self.solve(self.constraints, focus, self.valid, *args):
43+
vals = []
44+
if variable.vtype in [bool, int, float]:
45+
limits = self.attribute.get_custom_limits(self.argtype)
46+
if limits is not None:
47+
v_copy = copy.deepcopy(variable)
48+
v_copy.Ge(limits[0])
49+
v_copy.Le(limits[1])
50+
vals = VariableGenerator(v_copy.space).gen(num)
51+
if len(vals) == 0:
52+
vals = VariableGenerator(variable.space).gen(num)
53+
gen_vals.update(vals)
54+
return gen_vals
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from inputgen.argument.type import ArgType
11+
from inputgen.attribute.engine import AttributeEngine
12+
from inputgen.attribute.model import Attribute
13+
from inputgen.specs.model import ConstraintProducer as cp
14+
from inputgen.variable.type import ScalarDtype
15+
16+
17+
class TestAttributeEngine(unittest.TestCase):
18+
def test_engine(self):
19+
constraints = [
20+
cp.Value.Ge(lambda x, y: 1),
21+
cp.Value.Le(lambda x, y: x + 3),
22+
cp.Value.Ne(lambda x, y: y + 1),
23+
]
24+
x = 2
25+
y = 1
26+
27+
engine = AttributeEngine(
28+
Attribute.VALUE, constraints, True, ArgType.Scalar, ScalarDtype.float
29+
)
30+
values = engine.gen(Attribute.VALUE, x, y)
31+
self.assertEqual(len(values), 6)
32+
self.assertTrue(all(v >= 1 for v in values))
33+
self.assertTrue(all(v <= 5 for v in values))
34+
self.assertTrue(all(v != 2 for v in values))
35+
36+
values = engine.gen(Attribute.DTYPE, x, y)
37+
self.assertEqual(len(values), 1)
38+
39+
engine = AttributeEngine(
40+
Attribute.VALUE, constraints, False, ArgType.Scalar, ScalarDtype.float
41+
)
42+
values = engine.gen(Attribute.VALUE, x, y)
43+
self.assertEqual(len(values), 9)
44+
self.assertTrue(float("-inf") in values)
45+
self.assertTrue(0.9999999999999999 in values)
46+
self.assertTrue(2.0 in values)
47+
self.assertTrue(5.000000000000001 in values)
48+
self.assertTrue(float("inf") in values)
49+
50+
def test_scalar_type(self):
51+
engine = AttributeEngine(Attribute.VALUE, [], True, ArgType.ScalarType)
52+
values = engine.gen(Attribute.VALUE)
53+
self.assertTrue(len(values) > 0)
54+
self.assertTrue(all(isinstance(v, torch.dtype) for v in values))
55+
56+
engine = AttributeEngine(Attribute.VALUE, [], False, ArgType.ScalarType)
57+
values = engine.gen(Attribute.VALUE)
58+
self.assertTrue(len(values) == 0)
59+
60+
61+
if __name__ == "__main__":
62+
unittest.main()

0 commit comments

Comments
 (0)