Skip to content

Commit b9acd4e

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: introduce variable types (#2)
Summary: Pull Request resolved: #2 This is the first in a stack of diffs aimed at landing InputGen, a test generation engine for pytorch ops This diff introduces the most basic variable types that are generated. These are the building blocks for generating more complex structures. These variable types are: - Bool (True or False) - Int (any integral value) - Float (any floating point value) - String (any string) - ScalarDtype (currently this can take only 3 values: bool, int or float) - TensorDtype - Tuple (any tuple of values) Reviewed By: digantdesai Differential Revision: D52047188 fbshipit-source-id: c8b06f71037433f89148fa70d95fa0c769c186ae
1 parent 30f05d5 commit b9acd4e

File tree

3 files changed

+154
-0
lines changed

3 files changed

+154
-0
lines changed

inputgen/variable/__init__.py

Whitespace-only changes.

inputgen/variable/type.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import math
2+
from enum import Enum
3+
from typing import Any
4+
5+
import torch
6+
7+
8+
class ScalarDtype(Enum):
9+
bool = bool
10+
int = int
11+
float = float
12+
13+
def __lt__(self, other):
14+
if self.__class__ is other.__class__:
15+
order = [bool, int, float]
16+
return order.index(self.value) < order.index(other.value)
17+
return NotImplemented
18+
19+
def __str__(self):
20+
return self.value.__name__
21+
22+
23+
SUPPORTED_TENSOR_DTYPES = [
24+
torch.bool,
25+
torch.uint8,
26+
torch.int8,
27+
torch.int16,
28+
torch.int32,
29+
torch.int64,
30+
torch.float32,
31+
torch.float64,
32+
# The following types are not supported yet, but we should support them soon:
33+
# torch.float16,
34+
# torch.complex32,
35+
# torch.complex64,
36+
# torch.complex128,
37+
# torch.bfloat16,
38+
]
39+
40+
41+
class VariableType(Enum):
42+
"""
43+
These are the most basic variable types generated by Inputgen.
44+
More complex structures can be created using these as building blocks.
45+
"""
46+
47+
Bool = bool
48+
Int = int
49+
Float = float
50+
String = str
51+
ScalarDtype = ScalarDtype
52+
TensorDtype = torch.dtype
53+
Tuple = tuple
54+
55+
@staticmethod
56+
def contains(v: Any) -> bool:
57+
return v in [member.value for member in VariableType]
58+
59+
60+
def invalid_vtype(vtype: type, v: Any) -> bool:
61+
if v is None:
62+
return False
63+
if vtype in [bool, int, float]:
64+
if type(v) not in [bool, int, float]:
65+
return True
66+
else:
67+
if not isinstance(v, vtype):
68+
return True
69+
return False
70+
71+
72+
def is_integer(v: Any) -> bool:
73+
if type(v) not in [bool, int, float]:
74+
return False
75+
if math.isnan(v) or not math.isfinite(v):
76+
return False
77+
return bool(int(v) == v)
78+
79+
80+
def convert_to_vtype(vtype: type, v: Any) -> Any:
81+
if vtype == bool:
82+
return bool(v)
83+
if vtype == int:
84+
if not is_integer(v):
85+
return v
86+
return int(v)
87+
if vtype == float:
88+
return float(v)
89+
return v

test/inputgen/test_variable_types.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import unittest
2+
3+
import torch
4+
from inputgen.variable.type import (
5+
convert_to_vtype,
6+
invalid_vtype,
7+
is_integer,
8+
ScalarDtype,
9+
VariableType,
10+
)
11+
12+
13+
class TestVariableType(unittest.TestCase):
14+
def test_variable_type(self):
15+
self.assertEqual(VariableType.Bool.value, bool)
16+
self.assertEqual(VariableType.Int.value, int)
17+
self.assertEqual(VariableType.Float.value, float)
18+
self.assertEqual(VariableType.String.value, str)
19+
self.assertEqual(VariableType.Tuple.value, tuple)
20+
self.assertEqual(VariableType.ScalarDtype.value, ScalarDtype)
21+
self.assertEqual(VariableType.TensorDtype.value, torch.dtype)
22+
23+
def test_is_integer(self):
24+
self.assertTrue(is_integer(1))
25+
self.assertTrue(is_integer(5))
26+
self.assertTrue(is_integer(-5))
27+
self.assertTrue(is_integer(1.0))
28+
self.assertTrue(is_integer(0.0))
29+
self.assertTrue(is_integer(-3.0))
30+
self.assertTrue(is_integer(True))
31+
self.assertTrue(is_integer(False))
32+
self.assertFalse(is_integer(3.5))
33+
self.assertFalse(is_integer(float("inf")))
34+
self.assertFalse(is_integer(float("-inf")))
35+
self.assertFalse(is_integer(float("nan")))
36+
37+
def test_convert(self):
38+
self.assertEqual(convert_to_vtype(VariableType.Bool.value, 1.0), True)
39+
self.assertEqual(convert_to_vtype(VariableType.Int.value, False), 0)
40+
self.assertEqual(convert_to_vtype(VariableType.Float.value, 3), 3.0)
41+
self.assertEqual(
42+
convert_to_vtype(VariableType.Int.value, float("inf")), float("inf")
43+
)
44+
45+
def test_invalid_vtype(self):
46+
self.assertFalse(invalid_vtype(VariableType.Bool.value, 1.0))
47+
self.assertFalse(invalid_vtype(VariableType.Int.value, 1.0))
48+
self.assertFalse(invalid_vtype(VariableType.Float.value, 1.0))
49+
self.assertFalse(invalid_vtype(VariableType.String.value, "hello"))
50+
self.assertFalse(invalid_vtype(VariableType.Tuple.value, (1, 2)))
51+
self.assertFalse(
52+
invalid_vtype(VariableType.ScalarDtype.value, ScalarDtype.bool)
53+
)
54+
self.assertFalse(invalid_vtype(VariableType.ScalarDtype.value, ScalarDtype.int))
55+
self.assertFalse(
56+
invalid_vtype(VariableType.ScalarDtype.value, ScalarDtype.float)
57+
)
58+
self.assertFalse(invalid_vtype(VariableType.TensorDtype.value, torch.bool))
59+
self.assertTrue(invalid_vtype(VariableType.Float.value, "1.0"))
60+
self.assertTrue(invalid_vtype(VariableType.String.value, 1))
61+
self.assertTrue(invalid_vtype(VariableType.ScalarDtype.value, torch.int8))
62+
63+
64+
if __name__ == "__main__":
65+
unittest.main()

0 commit comments

Comments
 (0)