5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import random
8
- from typing import Any , List , Optional
8
+ from typing import Any , List , Optional , Tuple , Union
9
9
10
+ import torch
10
11
from inputgen .argument .type import ArgType
11
12
from inputgen .attribute .engine import AttributeEngine
12
13
from inputgen .attribute .model import Attribute
13
- from inputgen .specs .model import Constraint
14
+ from inputgen .attribute .solve import AttributeSolver
15
+ from inputgen .specs .model import Constraint , ConstraintSuffix
16
+ from inputgen .variable .type import ScalarDtype
14
17
15
18
16
19
class StructuralEngine :
@@ -27,6 +30,11 @@ def __init__(
27
30
self .valid = valid
28
31
self .hierarchy = StructuralEngine .hierarchy (argtype )
29
32
33
+ self .gen_list_mode = set ()
34
+ for constraint in constraints :
35
+ if constraint .suffix == ConstraintSuffix .GEN :
36
+ self .gen_list_mode .add (constraint .attribute )
37
+
30
38
@staticmethod
31
39
def hierarchy (argtype ) -> List [Attribute ]:
32
40
"""Return the structural hierarchy for a given argument type"""
@@ -47,6 +55,11 @@ def gen_structure_with_depth_and_length(
47
55
return
48
56
49
57
attr = self .hierarchy [- (depth + 1 )]
58
+
59
+ if attr in self .gen_list_mode :
60
+ yield from self .gen_structure_with_depth (depth , focus , length )
61
+ return
62
+
50
63
focus_ixs = range (length ) if focus == attr else (random .choice (range (length )),)
51
64
for focus_ix in focus_ixs :
52
65
values = [()]
@@ -93,3 +106,167 @@ def gen_structure_with_depth(
93
106
def gen (self , focus : Attribute ):
94
107
depth = len (self .hierarchy ) - 1
95
108
yield from self .gen_structure_with_depth (depth , focus )
109
+
110
+
111
+ class MetaArg :
112
+ def __init__ (
113
+ self ,
114
+ argtype : ArgType ,
115
+ * ,
116
+ optional : bool = False ,
117
+ dtype : Optional [
118
+ Union [torch .dtype , List [Optional [torch .dtype ]], ScalarDtype ]
119
+ ] = None ,
120
+ structure : Optional [Tuple ] = None ,
121
+ value : Optional [Any ] = None ,
122
+ ):
123
+ self .argtype = argtype
124
+ self .optional = optional
125
+ self .dtype = dtype
126
+ self .structure = structure
127
+ self .value = value
128
+
129
+ if not self .argtype .is_optional () and self .optional :
130
+ raise ValueError ("Only optional argtypes can have optional instances" )
131
+
132
+ if self .argtype .is_tensor_list ():
133
+ if len (self .structure ) != len (self .dtype ):
134
+ raise ValueError (
135
+ "Structure and dtype must be same length when tensor list"
136
+ )
137
+ if self .argtype == ArgType .TensorList and any (
138
+ d is None for d in self .dtype
139
+ ):
140
+ raise ValueError ("Only TensorOptList can have None in list of dtypes" )
141
+
142
+ if not self .optional and Attribute .DTYPE not in Attribute .hierarchy (
143
+ self .argtype
144
+ ):
145
+ if argtype .is_list ():
146
+ self .value = list (self .structure )
147
+ else :
148
+ self .value = self .structure
149
+
150
+ def __str__ (self ):
151
+ if self .optional :
152
+ strval = "None"
153
+ elif self .argtype .is_tensor_list ():
154
+ strval = (
155
+ "["
156
+ + ", " .join (
157
+ [
158
+ f"{ self .dtype [i ]} { self .structure [i ]} "
159
+ for i in range (len (self .dtype ))
160
+ ]
161
+ )
162
+ + "]"
163
+ )
164
+ elif self .argtype .is_tensor ():
165
+ strval = f"{ self .dtype } { self .structure } "
166
+ else :
167
+ strval = str (self .value )
168
+ return f"{ self .argtype } { strval } "
169
+
170
+ def length (self ):
171
+ if self .argtype .is_list ():
172
+ return len (self .structure )
173
+ else :
174
+ return None
175
+
176
+ def rank (self , ix = None ):
177
+ if self .argtype .is_tensor ():
178
+ return len (self .structure )
179
+ elif self .argtype .is_tensor_list ():
180
+ if ix is None :
181
+ return (len (s ) for s in self .structure )
182
+ else :
183
+ return len (self .structure [ix ])
184
+ else :
185
+ return None
186
+
187
+
188
+ class MetaArgEngine :
189
+ def __init__ (
190
+ self ,
191
+ argtype : ArgType ,
192
+ constraints : List [Constraint ],
193
+ deps : List [Any ],
194
+ valid : bool ,
195
+ ):
196
+ self .argtype = argtype
197
+ self .constraints = constraints
198
+ self .deps = deps
199
+ self .valid = valid
200
+
201
+ def gen_structures (self , focus ):
202
+ if self .argtype .is_scalar ():
203
+ yield None
204
+ else :
205
+ yield from StructuralEngine (
206
+ self .argtype , self .constraints , self .deps , self .valid
207
+ ).gen (focus )
208
+
209
+ def gen_dtypes (self , focus ):
210
+ if Attribute .DTYPE not in Attribute .hierarchy (self .argtype ):
211
+ return {None }
212
+ engine = AttributeEngine (
213
+ Attribute .DTYPE , self .constraints , self .valid , self .argtype
214
+ )
215
+ if self .argtype .is_scalar () and focus == Attribute .VALUE :
216
+ # if focused on a scalar value, must generate all dtypes too
217
+ focus = Attribute .DTYPE
218
+ return engine .gen (focus , self .deps )
219
+
220
+ def gen_optional (self ):
221
+ engine = AttributeEngine (
222
+ Attribute .OPTIONAL , self .constraints , self .valid , self .argtype
223
+ )
224
+ return True in engine .gen (Attribute .OPTIONAL , self .deps )
225
+
226
+ def gen_scalars (self , scalar_dtype , focus ):
227
+ engine = AttributeEngine (
228
+ Attribute .VALUE , self .constraints , self .valid , self .argtype , scalar_dtype
229
+ )
230
+ return engine .gen (focus , self .deps , scalar_dtype )
231
+
232
+ def gen_value_spaces (self , focus , dtype , struct ):
233
+ if not self .argtype .is_tensor () and not self .argtype .is_tensor_list ():
234
+ return [None ]
235
+ solver = AttributeSolver (Attribute .VALUE , self .argtype )
236
+ variables = list (
237
+ solver .solve (self .constraints , focus , self .valid , self .deps , dtype , struct )
238
+ )
239
+ if focus == Attribute .VALUE :
240
+ return [v .space for v in variables ]
241
+ else :
242
+ return [random .choice (variables ).space ]
243
+
244
+ def gen (self , focus ):
245
+ # TODO(mcandales): Enable Tensor List generation
246
+
247
+ if focus in [None , Attribute .OPTIONAL ]:
248
+ if self .argtype .is_optional () and self .gen_optional ():
249
+ yield MetaArg (self .argtype , optional = True )
250
+ if focus == Attribute .OPTIONAL :
251
+ return
252
+
253
+ if self .argtype .is_scalar ():
254
+ scalar_dtypes = self .gen_dtypes (focus )
255
+ for scalar_dtype in scalar_dtypes :
256
+ for value in self .gen_scalars (scalar_dtype , focus ):
257
+ yield MetaArg (self .argtype , dtype = scalar_dtype , value = value )
258
+ else :
259
+ if focus == Attribute .DTYPE :
260
+ for dtype in self .gen_dtypes (focus ):
261
+ for struct in self .gen_structures (focus ):
262
+ for space in self .gen_value_spaces (focus , dtype , struct ):
263
+ yield MetaArg (
264
+ self .argtype , dtype = dtype , structure = struct , value = space
265
+ )
266
+ else :
267
+ for struct in self .gen_structures (focus ):
268
+ for dtype in self .gen_dtypes (focus ):
269
+ for space in self .gen_value_spaces (focus , dtype , struct ):
270
+ yield MetaArg (
271
+ self .argtype , dtype = dtype , structure = struct , value = space
272
+ )
0 commit comments