Skip to content

Commit 63cbfc8

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: introduce argument attributes
Reviewed By: SS-JIA Differential Revision: D52047184 fbshipit-source-id: 4a7e81a7965f791ffe6c793235f59c671730f8df
1 parent bb9735f commit 63cbfc8

File tree

3 files changed

+211
-0
lines changed

3 files changed

+211
-0
lines changed

inputgen/attribute/__init__.py

Whitespace-only changes.

inputgen/attribute/model.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from enum import Enum
8+
from typing import List, Optional, Tuple
9+
10+
import torch
11+
from inputgen.argument.type import ArgType
12+
from inputgen.variable.type import ScalarDtype
13+
14+
15+
class Attribute(str, Enum):
16+
OPTIONAL = "optional"
17+
LENGTH = "len"
18+
DTYPE = "dtype"
19+
RANK = "rank"
20+
SIZE = "size"
21+
VALUE = "value"
22+
23+
@staticmethod
24+
def hierarchy(argtype: ArgType) -> List["Attribute"]:
25+
if argtype.is_tensor_list():
26+
if argtype == ArgType.TensorOptList:
27+
return [
28+
Attribute.LENGTH,
29+
Attribute.OPTIONAL,
30+
Attribute.DTYPE,
31+
Attribute.RANK,
32+
Attribute.SIZE,
33+
Attribute.VALUE,
34+
]
35+
else:
36+
return [
37+
Attribute.LENGTH,
38+
Attribute.DTYPE,
39+
Attribute.RANK,
40+
Attribute.SIZE,
41+
Attribute.VALUE,
42+
]
43+
opt = [Attribute.OPTIONAL] if argtype.is_optional() else []
44+
if argtype.is_tensor():
45+
return opt + [
46+
Attribute.DTYPE,
47+
Attribute.RANK,
48+
Attribute.SIZE,
49+
Attribute.VALUE,
50+
]
51+
elif argtype.is_scalar():
52+
return opt + [Attribute.DTYPE, Attribute.VALUE]
53+
elif argtype.is_list():
54+
return opt + [Attribute.LENGTH, Attribute.VALUE]
55+
else:
56+
return opt + [Attribute.VALUE]
57+
58+
def get_vtype(
59+
self,
60+
argtype: Optional[ArgType] = None,
61+
scalar_dtype: Optional[ScalarDtype] = None,
62+
) -> type:
63+
if self == Attribute.OPTIONAL:
64+
return bool
65+
if self == Attribute.DTYPE:
66+
if argtype is None:
67+
raise ValueError(f"Attribute {self} requires an argtype")
68+
if argtype.is_scalar():
69+
assert isinstance(ScalarDtype, type)
70+
return ScalarDtype
71+
else:
72+
assert isinstance(torch.dtype, type)
73+
return torch.dtype
74+
if self in [Attribute.LENGTH, Attribute.RANK, Attribute.SIZE]:
75+
return int
76+
if self == Attribute.VALUE:
77+
if argtype is None:
78+
raise ValueError(f"Attribute {self} requires an argtype")
79+
if argtype.has_integer_value():
80+
return int
81+
if argtype.is_bool():
82+
return bool
83+
if argtype.is_float():
84+
return float
85+
if argtype.is_string():
86+
return str
87+
if argtype.is_memory_format():
88+
return str
89+
if argtype.is_scalar():
90+
if scalar_dtype is None:
91+
raise ValueError(
92+
"Attribute value for argtype scalar requires a scalar_dtype"
93+
)
94+
assert isinstance(scalar_dtype, ScalarDtype)
95+
assert isinstance(scalar_dtype.value, type)
96+
return scalar_dtype.value
97+
if argtype.is_scalar_type():
98+
assert isinstance(torch.dtype, type)
99+
return torch.dtype
100+
return float
101+
102+
def get_custom_limits(
103+
self, argtype: Optional[ArgType] = None
104+
) -> Optional[Tuple[int, int]]:
105+
RANK_MAX = 6
106+
SIZE_MAX = 8
107+
TL_LEN_MAX = 6
108+
LIST_LEN_MAX = 8
109+
VALUE_LENGTH_MIN = -9
110+
VALUE_LENGTH_MAX = 9
111+
VALUE_MIN = -20
112+
VALUE_MAX = 20
113+
114+
if self == Attribute.LENGTH:
115+
if argtype is None:
116+
raise ValueError(f"Attribute {self} requires an argtype")
117+
if argtype.is_tensor_list():
118+
return (0, TL_LEN_MAX)
119+
if argtype.is_shape():
120+
return (0, RANK_MAX)
121+
return (0, LIST_LEN_MAX)
122+
elif self == Attribute.RANK:
123+
return (0, RANK_MAX)
124+
elif self == Attribute.SIZE:
125+
return (0, SIZE_MAX)
126+
elif self == Attribute.VALUE:
127+
if argtype is None:
128+
raise ValueError(f"Attribute {self} requires an argtype")
129+
if argtype.is_shape():
130+
return (-SIZE_MAX, SIZE_MAX)
131+
if argtype.is_length() or argtype.is_length_list():
132+
return (VALUE_LENGTH_MIN, VALUE_LENGTH_MAX)
133+
return None
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from inputgen.argument.type import ArgType
11+
from inputgen.attribute.model import Attribute
12+
from inputgen.variable.type import ScalarDtype
13+
14+
15+
class TestAttribute(unittest.TestCase):
16+
def test_hierarchy(self):
17+
self.assertEqual(
18+
Attribute.hierarchy(ArgType.ScalarOpt),
19+
[Attribute.OPTIONAL, Attribute.DTYPE, Attribute.VALUE],
20+
)
21+
self.assertEqual(
22+
Attribute.hierarchy(ArgType.TensorOpt),
23+
[
24+
Attribute.OPTIONAL,
25+
Attribute.DTYPE,
26+
Attribute.RANK,
27+
Attribute.SIZE,
28+
Attribute.VALUE,
29+
],
30+
)
31+
32+
def test_vtype(self):
33+
attr = Attribute.OPTIONAL
34+
self.assertEqual(attr.get_vtype(), bool)
35+
36+
attr = Attribute.LENGTH
37+
self.assertEqual(attr.get_vtype(), int)
38+
39+
attr = Attribute.RANK
40+
self.assertEqual(attr.get_vtype(), int)
41+
42+
attr = Attribute.SIZE
43+
self.assertEqual(attr.get_vtype(), int)
44+
45+
attr = Attribute.DTYPE
46+
self.assertEqual(attr.get_vtype(ArgType.Tensor), torch.dtype)
47+
48+
attr = Attribute.DTYPE
49+
self.assertEqual(attr.get_vtype(ArgType.Scalar), ScalarDtype)
50+
51+
attr = Attribute.VALUE
52+
self.assertEqual(attr.get_vtype(ArgType.Dim), int)
53+
54+
attr = Attribute.VALUE
55+
self.assertEqual(attr.get_vtype(ArgType.Float), float)
56+
57+
attr = Attribute.VALUE
58+
self.assertEqual(attr.get_vtype(ArgType.String), str)
59+
60+
attr = Attribute.VALUE
61+
self.assertEqual(attr.get_vtype(ArgType.ScalarType), torch.dtype)
62+
63+
attr = Attribute.VALUE
64+
self.assertEqual(attr.get_vtype(ArgType.Scalar, ScalarDtype.bool), bool)
65+
66+
attr = Attribute.VALUE
67+
self.assertEqual(attr.get_vtype(ArgType.Scalar, ScalarDtype.int), int)
68+
69+
attr = Attribute.VALUE
70+
self.assertEqual(attr.get_vtype(ArgType.Scalar, ScalarDtype.float), float)
71+
72+
def test_custom_limits(self):
73+
attr = Attribute.OPTIONAL
74+
self.assertEqual(attr.get_custom_limits(), None)
75+
76+
77+
if __name__ == "__main__":
78+
unittest.main()

0 commit comments

Comments
 (0)