Skip to content

Commit 1220e8d

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Add device to TensorConfig (#34)
Summary: Pull Request resolved: #34 imported-using-ghimport Test Plan: Imported from OSS Rollback Plan: Reviewed By: digantdesai Differential Revision: D80468304 Pulled By: manuelcandales fbshipit-source-id: bbd928b6797ec45fee04a15161640c284611e1e6
1 parent ac8ce5f commit 1220e8d

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

examples/config_example.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,17 @@ def generate_inputs():
5353
spec = SpecDictDB["add.Tensor"]
5454

5555
config = TensorConfig(
56+
device="mps",
5657
empty=False,
5758
transposed=False,
5859
permuted=True,
5960
strided=True,
6061
).set_probability(0.7)
6162

63+
print("Generating inputs with the following config:")
64+
print("\tdevice: ", config.device)
65+
print("\tconditions: ", config.conditions)
66+
6267
generator = ArgumentTupleGenerator(spec, config=config)
6368
for ix, tup in enumerate(generator.gen()):
6469
posargs, inkwargs, outargs = tup
@@ -72,7 +77,8 @@ def test_add_op():
7277
for posargs, inkwargs, outargs in generate_inputs():
7378
try:
7479
op(*posargs, **inkwargs, **outargs)
75-
except Exception:
80+
except Exception as e:
81+
print(f"Failed with error: {e}")
7682
return False
7783
return True
7884

facto/inputgen/argument/gen.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,13 @@ def __init__(
112112
dtype: Optional[torch.dtype],
113113
structure: Tuple,
114114
space: VariableSpace,
115+
device: str = "cpu",
115116
transformation: Optional[TensorTransformation] = None,
116117
):
117118
self.dtype = dtype
118119
self.structure = structure
119120
self.space = space
121+
self.device = device
120122
self.transformation = transformation
121123

122124
def gen(self):
@@ -141,6 +143,7 @@ def gen(self):
141143
tensor = self.get_random_tensor(
142144
size=underlying_shape, dtype=self.dtype, high=max_val, low=min_val
143145
)
146+
tensor = tensor.to(self.device)
144147

145148
# Apply transformations as instructed
146149
tensor = self._apply_transformation(tensor)
@@ -192,15 +195,17 @@ def _apply_noncontiguity(self, tensor):
192195

193196
return tensor[indices]
194197

195-
def get_random_tensor(self, size, dtype, high=None, low=None):
198+
def get_random_tensor(self, size, dtype, high=None, low=None) -> torch.Tensor:
196199
torch_rng = seeded_random_manager.get_torch()
197200

198201
if low is None and high is None:
199202
low = -100
200203
high = 100
201204
elif low is None:
205+
assert high is not None
202206
low = high - 100
203207
elif high is None:
208+
assert low is not None
204209
high = low + 100
205210
size = tuple(size)
206211
if dtype == torch.bool:
@@ -262,6 +267,7 @@ def get_random_tensor(self, size, dtype, high=None, low=None):
262267
return t
263268
if dtype in floating_types():
264269
return t / FLOAT_RESOLUTION
270+
raise ValueError(f"Unsupported Dtype: {dtype}")
265271

266272

267273
class ArgumentGenerator:
@@ -270,6 +276,8 @@ def __init__(self, meta: MetaArg, config=None):
270276
self.config = config
271277

272278
def gen(self):
279+
device = "cpu" if self.config is None else self.config.device
280+
273281
if self.meta.optional:
274282
return None
275283
elif self.meta.argtype.is_tensor():
@@ -284,6 +292,7 @@ def gen(self):
284292
dtype=self.meta.dtype,
285293
structure=self.meta.structure,
286294
space=self.meta.value,
295+
device=device,
287296
transformation=transformation,
288297
).gen()
289298
elif self.meta.argtype.is_tensor_list():
@@ -302,6 +311,7 @@ def gen(self):
302311
dtype=self.meta.dtype[i],
303312
structure=self.meta.structure[i],
304313
space=self.meta.value,
314+
device=device,
305315
transformation=transformation,
306316
).gen()
307317
tensors.append(tensor)

facto/inputgen/utils/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ class Condition(str, Enum):
1515

1616

1717
class TensorConfig:
18-
def __init__(self, **conditions):
18+
def __init__(self, device="cpu", **conditions):
19+
self.device = device
1920
self.conditions = {condition: False for condition in Condition}
2021
for condition, value in conditions.items():
2122
if condition in self.conditions:

0 commit comments

Comments
 (0)