@@ -112,11 +112,13 @@ def __init__(
112
112
dtype : Optional [torch .dtype ],
113
113
structure : Tuple ,
114
114
space : VariableSpace ,
115
+ device : str = "cpu" ,
115
116
transformation : Optional [TensorTransformation ] = None ,
116
117
):
117
118
self .dtype = dtype
118
119
self .structure = structure
119
120
self .space = space
121
+ self .device = device
120
122
self .transformation = transformation
121
123
122
124
def gen (self ):
@@ -141,6 +143,7 @@ def gen(self):
141
143
tensor = self .get_random_tensor (
142
144
size = underlying_shape , dtype = self .dtype , high = max_val , low = min_val
143
145
)
146
+ tensor = tensor .to (self .device )
144
147
145
148
# Apply transformations as instructed
146
149
tensor = self ._apply_transformation (tensor )
@@ -192,15 +195,17 @@ def _apply_noncontiguity(self, tensor):
192
195
193
196
return tensor [indices ]
194
197
195
- def get_random_tensor (self , size , dtype , high = None , low = None ):
198
+ def get_random_tensor (self , size , dtype , high = None , low = None ) -> torch . Tensor :
196
199
torch_rng = seeded_random_manager .get_torch ()
197
200
198
201
if low is None and high is None :
199
202
low = - 100
200
203
high = 100
201
204
elif low is None :
205
+ assert high is not None
202
206
low = high - 100
203
207
elif high is None :
208
+ assert low is not None
204
209
high = low + 100
205
210
size = tuple (size )
206
211
if dtype == torch .bool :
@@ -262,6 +267,7 @@ def get_random_tensor(self, size, dtype, high=None, low=None):
262
267
return t
263
268
if dtype in floating_types ():
264
269
return t / FLOAT_RESOLUTION
270
+ raise ValueError (f"Unsupported Dtype: { dtype } " )
265
271
266
272
267
273
class ArgumentGenerator :
@@ -270,6 +276,8 @@ def __init__(self, meta: MetaArg, config=None):
270
276
self .config = config
271
277
272
278
def gen (self ):
279
+ device = "cpu" if self .config is None else self .config .device
280
+
273
281
if self .meta .optional :
274
282
return None
275
283
elif self .meta .argtype .is_tensor ():
@@ -284,6 +292,7 @@ def gen(self):
284
292
dtype = self .meta .dtype ,
285
293
structure = self .meta .structure ,
286
294
space = self .meta .value ,
295
+ device = device ,
287
296
transformation = transformation ,
288
297
).gen ()
289
298
elif self .meta .argtype .is_tensor_list ():
@@ -302,6 +311,7 @@ def gen(self):
302
311
dtype = self .meta .dtype [i ],
303
312
structure = self .meta .structure [i ],
304
313
space = self .meta .value ,
314
+ device = device ,
305
315
transformation = transformation ,
306
316
).gen ()
307
317
tensors .append (tensor )
0 commit comments