diff --git a/examples/config_example.py b/examples/config_example.py index 06dc711..1553129 100644 --- a/examples/config_example.py +++ b/examples/config_example.py @@ -53,12 +53,17 @@ def generate_inputs(): spec = SpecDictDB["add.Tensor"] config = TensorConfig( + device="mps", empty=False, transposed=False, permuted=True, strided=True, ).set_probability(0.7) + print("Generating inputs with the following config:") + print("\tdevice: ", config.device) + print("\tconditions: ", config.conditions) + generator = ArgumentTupleGenerator(spec, config=config) for ix, tup in enumerate(generator.gen()): posargs, inkwargs, outargs = tup @@ -72,7 +77,8 @@ def test_add_op(): for posargs, inkwargs, outargs in generate_inputs(): try: op(*posargs, **inkwargs, **outargs) - except Exception: + except Exception as e: + print(f"Failed with error: {e}") return False return True diff --git a/facto/inputgen/argument/gen.py b/facto/inputgen/argument/gen.py index c66c27f..caffb31 100644 --- a/facto/inputgen/argument/gen.py +++ b/facto/inputgen/argument/gen.py @@ -112,11 +112,13 @@ def __init__( dtype: Optional[torch.dtype], structure: Tuple, space: VariableSpace, + device: str = "cpu", transformation: Optional[TensorTransformation] = None, ): self.dtype = dtype self.structure = structure self.space = space + self.device = device self.transformation = transformation def gen(self): @@ -141,6 +143,7 @@ def gen(self): tensor = self.get_random_tensor( size=underlying_shape, dtype=self.dtype, high=max_val, low=min_val ) + tensor = tensor.to(self.device) # Apply transformations as instructed tensor = self._apply_transformation(tensor) @@ -192,15 +195,17 @@ def _apply_noncontiguity(self, tensor): return tensor[indices] - def get_random_tensor(self, size, dtype, high=None, low=None): + def get_random_tensor(self, size, dtype, high=None, low=None) -> torch.Tensor: torch_rng = seeded_random_manager.get_torch() if low is None and high is None: low = -100 high = 100 elif low is None: + assert high is not None low = high - 100 elif high is None: + assert low is not None high = low + 100 size = tuple(size) if dtype == torch.bool: @@ -262,6 +267,7 @@ def get_random_tensor(self, size, dtype, high=None, low=None): return t if dtype in floating_types(): return t / FLOAT_RESOLUTION + raise ValueError(f"Unsupported Dtype: {dtype}") class ArgumentGenerator: @@ -270,6 +276,8 @@ def __init__(self, meta: MetaArg, config=None): self.config = config def gen(self): + device = "cpu" if self.config is None else self.config.device + if self.meta.optional: return None elif self.meta.argtype.is_tensor(): @@ -284,6 +292,7 @@ def gen(self): dtype=self.meta.dtype, structure=self.meta.structure, space=self.meta.value, + device=device, transformation=transformation, ).gen() elif self.meta.argtype.is_tensor_list(): @@ -302,6 +311,7 @@ def gen(self): dtype=self.meta.dtype[i], structure=self.meta.structure[i], space=self.meta.value, + device=device, transformation=transformation, ).gen() tensors.append(tensor) diff --git a/facto/inputgen/utils/config.py b/facto/inputgen/utils/config.py index 17046d2..bb16f84 100644 --- a/facto/inputgen/utils/config.py +++ b/facto/inputgen/utils/config.py @@ -15,7 +15,8 @@ class Condition(str, Enum): class TensorConfig: - def __init__(self, **conditions): + def __init__(self, device="cpu", **conditions): + self.device = device self.conditions = {condition: False for condition in Condition} for condition, value in conditions.items(): if condition in self.conditions: