Skip to content

Commit 34ece3a

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: introduce argument generator
Reviewed By: SS-JIA Differential Revision: D52047187 fbshipit-source-id: 48d0c65f78d6f3f2b1225e02fe4436211b2bf6ce
1 parent 68d25b1 commit 34ece3a

File tree

2 files changed

+210
-0
lines changed

2 files changed

+210
-0
lines changed

inputgen/argument/gen.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import math
8+
from typing import Optional, Tuple
9+
10+
import torch
11+
from inputgen.argument.engine import MetaArg
12+
from inputgen.variable.gen import VariableGenerator
13+
from inputgen.variable.space import VariableSpace
14+
from torch.testing._internal.common_dtype import floating_types, integral_types
15+
16+
17+
FLOAT_RESOLUTION = 8
18+
19+
20+
class TensorGenerator:
21+
def __init__(
22+
self, dtype: Optional[torch.dtype], structure: Tuple, space: VariableSpace
23+
):
24+
self.dtype = dtype
25+
self.structure = structure
26+
self.space = space
27+
28+
def gen(self):
29+
if self.dtype is None:
30+
return None
31+
vg = VariableGenerator(self.space)
32+
min_val = vg.gen_min()
33+
max_val = vg.gen_max()
34+
if min_val == float("-inf"):
35+
min_val = None
36+
if max_val == float("inf"):
37+
max_val = None
38+
# TODO(mcandales): Implement a generator that actually supports any given space
39+
return self.get_random_tensor(
40+
size=self.structure, dtype=self.dtype, high=max_val, low=min_val
41+
)
42+
43+
def get_random_tensor(self, size, dtype, high=None, low=None):
44+
if low is None and high is None:
45+
low = -100
46+
high = 100
47+
elif low is None:
48+
low = high - 100
49+
elif high is None:
50+
high = low + 100
51+
size = tuple(size)
52+
if dtype == torch.bool:
53+
if not self.space.contains(0):
54+
return torch.full(size, True, dtype=dtype)
55+
elif not self.space.contains(1):
56+
return torch.full(size, False, dtype=dtype)
57+
else:
58+
return torch.randint(low=0, high=2, size=size, dtype=dtype)
59+
60+
if dtype in integral_types():
61+
low = math.ceil(low)
62+
high = math.floor(high) + 1
63+
elif dtype in floating_types():
64+
low = math.ceil(FLOAT_RESOLUTION * low)
65+
high = math.floor(FLOAT_RESOLUTION * high) + 1
66+
else:
67+
raise ValueError(f"Unsupported Dtype: {dtype}")
68+
69+
if dtype == torch.uint8:
70+
if not self.space.contains(0):
71+
return torch.randint(low=max(1, low), high=high, size=size, dtype=dtype)
72+
else:
73+
return torch.randint(low=max(0, low), high=high, size=size, dtype=dtype)
74+
75+
t = torch.randint(low=low, high=high, size=size, dtype=dtype)
76+
if not self.space.contains(0):
77+
if high > 0:
78+
pos = torch.randint(low=max(1, low), high=high, size=size, dtype=dtype)
79+
else:
80+
pos = torch.randint(low=low, high=0, size=size, dtype=dtype)
81+
t = torch.where(t == 0, pos, t)
82+
83+
if dtype in integral_types():
84+
return t
85+
if dtype in floating_types():
86+
return t / FLOAT_RESOLUTION
87+
88+
89+
class ArgumentGenerator:
90+
def __init__(self, meta: MetaArg):
91+
self.meta = meta
92+
93+
def gen(self):
94+
if self.meta.optional:
95+
return None
96+
elif self.meta.argtype.is_tensor():
97+
return TensorGenerator(
98+
dtype=self.meta.dtype,
99+
structure=self.meta.structure,
100+
space=self.meta.value,
101+
).gen()
102+
elif self.meta.argtype.is_tensor_list():
103+
return [
104+
TensorGenerator(
105+
dtype=self.meta.dtype[i],
106+
structure=self.meta.structure[i],
107+
space=self.meta.value,
108+
).gen()
109+
for i in range(len(self.meta.dtype))
110+
]
111+
else:
112+
return self.meta.value
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from inputgen.argument.engine import MetaArg
11+
from inputgen.argument.gen import ArgumentGenerator, TensorGenerator
12+
from inputgen.argument.type import ArgType
13+
from inputgen.variable.solve import SolvableVariable
14+
15+
16+
class TestTensorGenerator(unittest.TestCase):
17+
def test_gen(self):
18+
v = SolvableVariable(float)
19+
v.Ge(13)
20+
v.Le(51)
21+
tg = TensorGenerator(dtype=torch.float64, structure=(2, 3), space=v.space)
22+
tensor = tg.gen()
23+
24+
self.assertEqual(tensor.shape, (2, 3))
25+
self.assertEqual(tensor.dtype, torch.float64)
26+
self.assertGreaterEqual(tensor.min(), 13)
27+
self.assertLessEqual(tensor.max(), 51)
28+
29+
def test_zero_tensor(self):
30+
v = SolvableVariable(float)
31+
v.Eq(0)
32+
tg = TensorGenerator(dtype=torch.float64, structure=(2, 3), space=v.space)
33+
tensor = tg.gen()
34+
35+
self.assertEqual(tensor.shape, (2, 3))
36+
self.assertEqual(tensor.dtype, torch.float64)
37+
self.assertGreaterEqual(tensor.min(), 0)
38+
self.assertLessEqual(tensor.max(), 0)
39+
40+
41+
class TestArgumentGenerator(unittest.TestCase):
42+
def test_gen_optional(self):
43+
m = MetaArg(argtype=ArgType.TensorOpt, optional=True)
44+
tensor = ArgumentGenerator(m).gen()
45+
self.assertEqual(tensor, None)
46+
47+
def test_gen_scalar(self):
48+
m = MetaArg(argtype=ArgType.Scalar, value=True)
49+
scalar = ArgumentGenerator(m).gen()
50+
self.assertIs(scalar, True)
51+
52+
def test_gen_dim_list(self):
53+
m = MetaArg(argtype=ArgType.DimList, structure=(2, 3))
54+
dim_list = ArgumentGenerator(m).gen()
55+
self.assertEqual(dim_list, [2, 3])
56+
57+
def test_gen_tensor(self):
58+
v = SolvableVariable(float)
59+
v.Ge(13)
60+
v.Le(51)
61+
m = MetaArg(
62+
argtype=ArgType.Tensor, dtype=torch.float64, structure=(2, 3), value=v.space
63+
)
64+
tensor = ArgumentGenerator(m).gen()
65+
66+
self.assertEqual(tensor.shape, (2, 3))
67+
self.assertEqual(tensor.dtype, torch.float64)
68+
self.assertGreaterEqual(tensor.min(), 13)
69+
self.assertLessEqual(tensor.max(), 51)
70+
71+
def test_gen_tensor_list(self):
72+
v = SolvableVariable(float)
73+
v.Ge(13)
74+
v.Le(51)
75+
m = MetaArg(
76+
argtype=ArgType.TensorOptList,
77+
dtype=[torch.float64, torch.int32, None],
78+
structure=((2, 3), (3,), None),
79+
value=v.space,
80+
)
81+
tensors = ArgumentGenerator(m).gen()
82+
83+
self.assertEqual(len(tensors), 3)
84+
self.assertEqual(tensors[0].shape, (2, 3))
85+
self.assertEqual(tensors[0].dtype, torch.float64)
86+
self.assertGreaterEqual(tensors[0].min(), 13)
87+
self.assertLessEqual(tensors[0].max(), 51)
88+
89+
self.assertEqual(tensors[1].shape, (3,))
90+
self.assertEqual(tensors[1].dtype, torch.int32)
91+
self.assertGreaterEqual(tensors[1].min(), 13)
92+
self.assertLessEqual(tensors[1].max(), 51)
93+
94+
self.assertEqual(tensors[2], None)
95+
96+
97+
if __name__ == "__main__":
98+
unittest.main()

0 commit comments

Comments
 (0)