Skip to content

Commit b4093d4

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
InputGen: Introduce specs
Reviewed By: SS-JIA Differential Revision: D52308291 fbshipit-source-id: 4d4cb8d54ea6585e8e7a6403a700a1f243f3e17c
1 parent 34ece3a commit b4093d4

File tree

2 files changed

+133
-1
lines changed

2 files changed

+133
-1
lines changed

inputgen/specs/model.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66

77
from dataclasses import dataclass
88
from enum import Enum
9-
from typing import Callable
9+
from typing import Callable, List, Optional
1010

11+
from inputgen.argument.type import ArgType
1112
from inputgen.attribute.model import Attribute
1213

1314

@@ -63,3 +64,86 @@ class ConstraintProducer:
6364
Rank = ConstraintAttributeSuffixes(Attribute.RANK)
6465
Size = ConstraintAttributeSuffixes(Attribute.SIZE)
6566
Value = ConstraintAttributeSuffixes(Attribute.VALUE)
67+
68+
69+
class BaseArg:
70+
def __init__(
71+
self,
72+
argtype: ArgType,
73+
name: str,
74+
deps: Optional[List[int]] = None,
75+
constraints: Optional[List[Constraint]] = None,
76+
):
77+
self.name: str = name
78+
self.type: ArgType = argtype
79+
self.deps = () if deps is None else tuple(deps)
80+
self.constraints = [] if constraints is None else constraints
81+
self._kw: bool = False
82+
self._out: bool = False
83+
self._ret: bool = False
84+
85+
@property
86+
def kw(self):
87+
return self._kw
88+
89+
@kw.setter
90+
def kw(self, v):
91+
if not isinstance(v, bool):
92+
raise ValueError("kw property should be boolean")
93+
self._kw = v
94+
95+
@property
96+
def out(self):
97+
return self._out
98+
99+
@out.setter
100+
def out(self, v):
101+
if not isinstance(v, bool):
102+
raise ValueError("out property should be boolean")
103+
self._out = v
104+
105+
@property
106+
def ret(self):
107+
return self._ret
108+
109+
@ret.setter
110+
def ret(self, v):
111+
if not isinstance(v, bool):
112+
raise ValueError("ret property should be boolean")
113+
self._ret = v
114+
115+
116+
class InArg(BaseArg):
117+
def __init__(self, *args, **kwargs):
118+
BaseArg.__init__(self, *args, **kwargs)
119+
120+
121+
class InPosArg(InArg):
122+
def __init__(self, *args, **kwargs):
123+
BaseArg.__init__(self, *args, **kwargs)
124+
125+
126+
class InKwArg(InArg):
127+
def __init__(self, *args, **kwargs):
128+
BaseArg.__init__(self, *args, **kwargs)
129+
self._kw = True
130+
131+
132+
class OutArg(BaseArg):
133+
def __init__(self, argtype: ArgType, name: str = "out", *args, **kwargs):
134+
BaseArg.__init__(self, argtype, name, *args, **kwargs)
135+
self._kw = True
136+
self._out = True
137+
138+
139+
class Return(BaseArg):
140+
def __init__(self, argtype: ArgType, name: str = "__ret", *args, **kwargs):
141+
BaseArg.__init__(self, argtype, name, *args, **kwargs)
142+
self._ret = True
143+
144+
145+
class Spec:
146+
def __init__(self, op: str, inspec: List[InArg], outspec: List[OutArg]):
147+
self.op = op
148+
self.inspec = inspec
149+
self.outspec = outspec

test/inputgen/test_specs.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
from inputgen.argument.type import ArgType
10+
from inputgen.specs.model import InKwArg, InPosArg, OutArg, Return
11+
12+
13+
class TestArgSpecs(unittest.TestCase):
14+
def test_inpos(self):
15+
arg = InPosArg(ArgType.Tensor, name="self")
16+
self.assertEqual(arg.name, "self")
17+
self.assertEqual(arg.type, ArgType.Tensor)
18+
self.assertFalse(arg.kw)
19+
self.assertFalse(arg.out)
20+
self.assertFalse(arg.ret)
21+
22+
def test_inkw(self):
23+
arg = InKwArg(ArgType.Scalar, name="alpha")
24+
self.assertEqual(arg.name, "alpha")
25+
self.assertEqual(arg.type, ArgType.Scalar)
26+
self.assertTrue(arg.kw)
27+
self.assertFalse(arg.out)
28+
self.assertFalse(arg.ret)
29+
30+
def test_out(self):
31+
arg = OutArg(ArgType.TensorList)
32+
self.assertEqual(arg.name, "out")
33+
self.assertEqual(arg.type, ArgType.TensorList)
34+
self.assertTrue(arg.kw)
35+
self.assertTrue(arg.out)
36+
self.assertFalse(arg.ret)
37+
38+
def test_ret(self):
39+
arg = Return(ArgType.Tensor)
40+
self.assertEqual(arg.name, "__ret")
41+
self.assertEqual(arg.type, ArgType.Tensor)
42+
self.assertFalse(arg.kw)
43+
self.assertFalse(arg.out)
44+
self.assertTrue(arg.ret)
45+
46+
47+
if __name__ == "__main__":
48+
unittest.main()

0 commit comments

Comments
 (0)