Skip to content

Commit 9f37f74

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: Add support for out tensor arguments
Summary: Enables out tensor generation, for out variants. In the case of out tensors, only dtype information is relevant, since shape is determined by inputs. Reviewed By: SS-JIA Differential Revision: D52919719 fbshipit-source-id: 20ad8c59bbdf1e0d30243f672ed09fa4fe2e94f1
1 parent 4530318 commit 9f37f74

File tree

4 files changed

+39
-15
lines changed

4 files changed

+39
-15
lines changed

inputgen/argtuple/engine.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,11 @@ def inverse_permutation(permutation):
3939
class MetaArgTupleEngine:
4040
def __init__(self, spec: Spec, out: bool = False):
4141
if out:
42-
raise NotImplementedError("out=True is not supported yet")
43-
self.args = spec.inspec
42+
for arg in spec.outspec:
43+
arg.deps = list(range(len(spec.inspec)))
44+
self.args = spec.inspec + spec.outspec
45+
else:
46+
self.args = spec.inspec
4447
self.order = self._sort_dependencies()
4548
self.order_inverse_perm = inverse_permutation(self.order)
4649

@@ -77,7 +80,9 @@ def gen_meta_tuples(self, valid: bool, focus_ix: int):
7780
for focus in focuses:
7881
for meta_tuple in tuples:
7982
deps = self._get_deps(meta_tuple, arg.deps)
80-
engine = MetaArgEngine(arg.type, arg.constraints, deps, valid)
83+
engine = MetaArgEngine(
84+
arg.out, arg.type, arg.constraints, deps, valid
85+
)
8186
for meta_arg in engine.gen(focus):
8287
new_tuples.append(meta_tuple + (meta_arg,))
8388
tuples = new_tuples
@@ -98,7 +103,7 @@ def gen_invalid_from_valid(self, valid_tuple):
98103
# Generating invalid argument {ix} {arg.type}
99104
deps = tuple(valid_value_tuple[i] for i in arg.deps)
100105
for focus in Attribute.hierarchy(arg.type):
101-
engine = MetaArgEngine(arg.type, arg.constraints, deps, False)
106+
engine = MetaArgEngine(arg.out, arg.type, arg.constraints, deps, False)
102107
for meta_arg in engine.gen(focus):
103108
invalid_tuple = (
104109
valid_tuple[:ix] + (meta_arg,) + valid_tuple[ix + 1 :]

inputgen/argtuple/gen.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,22 @@ def __init__(self, spec: Spec):
1919
def gen_tuple(
2020
self, meta_tuple: Tuple[MetaArg], *, out: bool = False
2121
) -> Tuple[List[Any], OrderedDict[str, Any]]:
22-
args = []
23-
kwargs = OrderedDict()
22+
posargs = []
23+
inkwargs = OrderedDict()
24+
outargs = OrderedDict()
2425
for ix, arg in enumerate(self.spec.inspec):
2526
m = meta_tuple[ix]
2627
val = ArgumentGenerator(m).gen()
2728
if arg.kw:
28-
kwargs[arg.name] = val
29+
inkwargs[arg.name] = val
2930
else:
30-
args.append(val)
31+
posargs.append(val)
3132
if out:
3233
for ix, arg in enumerate(self.spec.outspec):
3334
m = meta_tuple[ix + len(self.spec.inspec)]
3435
val = ArgumentGenerator(m).gen()
35-
kwargs[arg.name] = val
36-
return args, kwargs
36+
outargs[arg.name] = val
37+
return posargs, inkwargs, outargs
3738

3839
def gen(
3940
self, *, valid: bool = True, out: bool = False

inputgen/argument/engine.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,13 @@ def rank(self, ix=None):
188188
class MetaArgEngine:
189189
def __init__(
190190
self,
191+
out: bool,
191192
argtype: ArgType,
192193
constraints: List[Constraint],
193194
deps: List[Any],
194195
valid: bool,
195196
):
197+
self.out = out
196198
self.argtype = argtype
197199
self.constraints = constraints
198200
self.deps = deps
@@ -244,6 +246,21 @@ def gen_value_spaces(self, focus, dtype, struct):
244246
def gen(self, focus):
245247
# TODO(mcandales): Enable Tensor List generation
246248

249+
if self.out:
250+
if self.argtype.is_tensor():
251+
if focus in [None, Attribute.DTYPE]:
252+
struct = (0,)
253+
for dtype in self.gen_dtypes(focus):
254+
for space in self.gen_value_spaces(focus, dtype, struct):
255+
yield MetaArg(
256+
self.argtype, dtype=dtype, structure=struct, value=space
257+
)
258+
return
259+
elif self.argtype.is_tensor_list():
260+
raise NotImplementedError("Tensor List output not implemented yet")
261+
else:
262+
raise ValueError("Output argtype must be tensor or tensor list")
263+
247264
if focus in [None, Attribute.OPTIONAL]:
248265
if self.argtype.is_optional() and self.gen_optional():
249266
yield MetaArg(self.argtype, optional=True)

test/inputgen/test_argtuple_generator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@ def test_gen(self):
3939
outspec=[],
4040
)
4141

42-
for args, kwargs in ArgumentTupleGenerator(spec).gen():
43-
self.assertEqual(len(args), 2)
44-
self.assertEqual(kwargs, {})
45-
t = args[0]
46-
dim = args[1]
42+
for posargs, inkwargs, outargs in ArgumentTupleGenerator(spec).gen():
43+
self.assertEqual(len(posargs), 2)
44+
self.assertEqual(inkwargs, {})
45+
self.assertEqual(outargs, {})
46+
t = posargs[0]
47+
dim = posargs[1]
4748
self.assertTrue(isinstance(t, torch.Tensor))
4849
self.assertTrue(isinstance(dim, int))
4950
if t.dim() == 0:

0 commit comments

Comments
 (0)