Skip to content

Commit 493aab1

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: introduce argument tuple generator
Reviewed By: SS-JIA Differential Revision: D52308290 fbshipit-source-id: c4e50f7f13e574e0c32f797d9d4f5e4e298751be
1 parent fd18020 commit 493aab1

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

inputgen/argtuple/gen.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, List, OrderedDict, Tuple
8+
9+
from inputgen.argtuple.engine import MetaArgTupleEngine
10+
from inputgen.argument.engine import MetaArg
11+
from inputgen.argument.gen import ArgumentGenerator
12+
from inputgen.specs.model import Spec
13+
14+
15+
class ArgumentTupleGenerator:
16+
def __init__(self, spec: Spec):
17+
self.spec = spec
18+
19+
def gen_tuple(
20+
self, meta_tuple: Tuple[MetaArg], *, out: bool = False
21+
) -> Tuple[List[Any], OrderedDict[str, Any]]:
22+
args = []
23+
kwargs = OrderedDict()
24+
for ix, arg in enumerate(self.spec.inspec):
25+
m = meta_tuple[ix]
26+
val = ArgumentGenerator(m).gen()
27+
if arg.kw:
28+
kwargs[arg.name] = val
29+
else:
30+
args.append(val)
31+
if out:
32+
for ix, arg in enumerate(self.spec.outspec):
33+
m = meta_tuple[ix + len(self.spec.inspec)]
34+
val = ArgumentGenerator(m).gen()
35+
kwargs[arg.name] = val
36+
return args, kwargs
37+
38+
def gen(
39+
self, *, valid: bool = True, out: bool = False
40+
) -> Tuple[List[Any], OrderedDict[str, Any]]:
41+
engine = MetaArgTupleEngine(self.spec, out=out)
42+
for meta_tuple in engine.gen(valid=valid):
43+
yield self.gen_tuple(meta_tuple, out=out)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from inputgen.argtuple.gen import ArgumentTupleGenerator
11+
from inputgen.argument.type import ArgType
12+
from inputgen.specs.model import ConstraintProducer as cp, InPosArg, Spec
13+
14+
15+
class TestArgumentTupleGenerator(unittest.TestCase):
16+
def test_gen(self):
17+
spec = Spec(
18+
op="test_size", # (Tensor self, int dim) -> int
19+
inspec=[
20+
InPosArg(ArgType.Tensor, name="self"),
21+
InPosArg(
22+
ArgType.Dim,
23+
name="dim",
24+
deps=[0],
25+
constraints=[
26+
cp.Value.Ge(
27+
lambda deps: -deps[0].dim() if deps[0].dim() > 0 else None
28+
),
29+
cp.Value.Ge(lambda deps: -1 if deps[0].dim() == 0 else None),
30+
cp.Value.Le(
31+
lambda deps: deps[0].dim() - 1
32+
if deps[0].dim() > 0
33+
else None
34+
),
35+
cp.Value.Le(lambda deps: 0 if deps[0].dim() == 0 else None),
36+
],
37+
),
38+
],
39+
outspec=[],
40+
)
41+
42+
for args, kwargs in ArgumentTupleGenerator(spec).gen():
43+
self.assertEqual(len(args), 2)
44+
self.assertEqual(kwargs, {})
45+
t = args[0]
46+
dim = args[1]
47+
self.assertTrue(isinstance(t, torch.Tensor))
48+
self.assertTrue(isinstance(dim, int))
49+
if t.dim() == 0:
50+
self.assertTrue(dim in [-1, 0])
51+
else:
52+
self.assertTrue(dim >= -t.dim())
53+
self.assertTrue(dim < t.dim())
54+
55+
56+
if __name__ == "__main__":
57+
unittest.main()

0 commit comments

Comments
 (0)