Skip to content

Commit bb9735f

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: introduce argument types
Reviewed By: SS-JIA Differential Revision: D51868718 fbshipit-source-id: 96cf097505ece6052fca3ad4456bff5ee3596520
1 parent 06e1923 commit bb9735f

File tree

3 files changed

+162
-0
lines changed

3 files changed

+162
-0
lines changed

inputgen/argument/__init__.py

Whitespace-only changes.

inputgen/argument/type.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from enum import Enum
8+
9+
10+
class ArgType(str, Enum):
11+
Tensor = "Tensor"
12+
TensorOpt = "Tensor?"
13+
14+
TensorList = "Tensor[]"
15+
TensorOptList = "Tensor?[]"
16+
17+
Scalar = "Scalar"
18+
ScalarOpt = "Scalar?"
19+
20+
ScalarType = "ScalarType"
21+
ScalarTypeOpt = "ScalarType?"
22+
23+
Dim = "Dim"
24+
DimOpt = "Dim?"
25+
DimList = "Dim[]"
26+
DimListOpt = "Dim[]?"
27+
28+
Shape = "Shape"
29+
Index = "Index"
30+
IndexOpt = "Index?"
31+
Length = "Length"
32+
LengthOpt = "Length?"
33+
LengthList = "Length[]"
34+
35+
Bool = "Bool"
36+
Int = "Integer"
37+
IntOpt = "Integer?"
38+
Float = "Float"
39+
FloatOpt = "Float?"
40+
String = "String"
41+
StringOpt = "String?"
42+
MemoryFormat = "MemoryFormat"
43+
44+
def is_tensor(self) -> bool:
45+
return self in [ArgType.Tensor, ArgType.TensorOpt]
46+
47+
def is_tensor_list(self) -> bool:
48+
return self in [ArgType.TensorList, ArgType.TensorOptList]
49+
50+
def is_scalar(self) -> bool:
51+
return self in [ArgType.Scalar, ArgType.ScalarOpt]
52+
53+
def is_scalar_type(self) -> bool:
54+
return self in [ArgType.ScalarType, ArgType.ScalarTypeOpt]
55+
56+
def is_dim(self) -> bool:
57+
return self in [ArgType.Dim, ArgType.DimOpt]
58+
59+
def is_dim_list(self) -> bool:
60+
return self in [ArgType.DimList, ArgType.DimListOpt]
61+
62+
def is_shape(self) -> bool:
63+
return self in [ArgType.Shape]
64+
65+
def is_index(self) -> bool:
66+
return self in [ArgType.Index, ArgType.IndexOpt]
67+
68+
def is_length(self) -> bool:
69+
return self in [ArgType.Length, ArgType.LengthOpt]
70+
71+
def is_length_list(self) -> bool:
72+
return self in [ArgType.LengthList]
73+
74+
def is_bool(self) -> bool:
75+
return self in [ArgType.Bool]
76+
77+
def is_int(self) -> bool:
78+
return self in [ArgType.Int, ArgType.IntOpt]
79+
80+
def is_float(self) -> bool:
81+
return self in [ArgType.Float, ArgType.FloatOpt]
82+
83+
def is_string(self) -> bool:
84+
return self in [ArgType.String, ArgType.StringOpt]
85+
86+
def is_memory_format(self) -> bool:
87+
return self in [ArgType.MemoryFormat]
88+
89+
def is_optional(self) -> bool:
90+
return self in [
91+
ArgType.TensorOpt,
92+
ArgType.ScalarOpt,
93+
ArgType.ScalarTypeOpt,
94+
ArgType.DimOpt,
95+
ArgType.DimListOpt,
96+
ArgType.FloatOpt,
97+
ArgType.IndexOpt,
98+
ArgType.IntOpt,
99+
ArgType.LengthOpt,
100+
]
101+
102+
def is_list(self) -> bool:
103+
return self in [
104+
ArgType.TensorList,
105+
ArgType.TensorOptList,
106+
ArgType.DimList,
107+
ArgType.DimListOpt,
108+
ArgType.LengthList,
109+
ArgType.Shape,
110+
]
111+
112+
def has_integer_value(self) -> bool:
113+
return self in [
114+
ArgType.Dim,
115+
ArgType.DimOpt,
116+
ArgType.DimList,
117+
ArgType.DimListOpt,
118+
ArgType.Shape,
119+
ArgType.Index,
120+
ArgType.IndexOpt,
121+
ArgType.Length,
122+
ArgType.LengthOpt,
123+
ArgType.LengthList,
124+
ArgType.Int,
125+
ArgType.IntOpt,
126+
]
127+
128+
def has_dtype(self) -> bool:
129+
return (
130+
self.is_tensor()
131+
or self.is_tensor_list()
132+
or self.is_scalar()
133+
or self.is_scalar_type()
134+
)

test/inputgen/test_argument_types.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from inputgen.argument.type import ArgType
10+
11+
12+
class TestArgType(unittest.TestCase):
13+
def test_methods(self):
14+
argtype = ArgType.Tensor
15+
self.assertTrue(argtype.is_tensor())
16+
17+
argtype = ArgType.TensorList
18+
self.assertTrue(argtype.is_tensor_list())
19+
20+
argtype = ArgType.Scalar
21+
self.assertTrue(argtype.is_scalar())
22+
23+
argtype = ArgType.DimList
24+
self.assertTrue(argtype.is_dim_list())
25+
26+
27+
if __name__ == "__main__":
28+
unittest.main()

0 commit comments

Comments
 (0)