Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion examples/config_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,17 @@ def generate_inputs():
spec = SpecDictDB["add.Tensor"]

config = TensorConfig(
device="mps",
empty=False,
transposed=False,
permuted=True,
strided=True,
).set_probability(0.7)

print("Generating inputs with the following config:")
print("\tdevice: ", config.device)
print("\tconditions: ", config.conditions)

generator = ArgumentTupleGenerator(spec, config=config)
for ix, tup in enumerate(generator.gen()):
posargs, inkwargs, outargs = tup
Expand All @@ -72,7 +77,8 @@ def test_add_op():
for posargs, inkwargs, outargs in generate_inputs():
try:
op(*posargs, **inkwargs, **outargs)
except Exception:
except Exception as e:
print(f"Failed with error: {e}")
return False
return True

Expand Down
12 changes: 11 additions & 1 deletion facto/inputgen/argument/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,13 @@ def __init__(
dtype: Optional[torch.dtype],
structure: Tuple,
space: VariableSpace,
device: str = "cpu",
transformation: Optional[TensorTransformation] = None,
):
self.dtype = dtype
self.structure = structure
self.space = space
self.device = device
self.transformation = transformation

def gen(self):
Expand All @@ -141,6 +143,7 @@ def gen(self):
tensor = self.get_random_tensor(
size=underlying_shape, dtype=self.dtype, high=max_val, low=min_val
)
tensor = tensor.to(self.device)

# Apply transformations as instructed
tensor = self._apply_transformation(tensor)
Expand Down Expand Up @@ -192,15 +195,17 @@ def _apply_noncontiguity(self, tensor):

return tensor[indices]

def get_random_tensor(self, size, dtype, high=None, low=None):
def get_random_tensor(self, size, dtype, high=None, low=None) -> torch.Tensor:
torch_rng = seeded_random_manager.get_torch()

if low is None and high is None:
low = -100
high = 100
elif low is None:
assert high is not None
low = high - 100
elif high is None:
assert low is not None
high = low + 100
size = tuple(size)
if dtype == torch.bool:
Expand Down Expand Up @@ -262,6 +267,7 @@ def get_random_tensor(self, size, dtype, high=None, low=None):
return t
if dtype in floating_types():
return t / FLOAT_RESOLUTION
raise ValueError(f"Unsupported Dtype: {dtype}")


class ArgumentGenerator:
Expand All @@ -270,6 +276,8 @@ def __init__(self, meta: MetaArg, config=None):
self.config = config

def gen(self):
device = "cpu" if self.config is None else self.config.device

if self.meta.optional:
return None
elif self.meta.argtype.is_tensor():
Expand All @@ -284,6 +292,7 @@ def gen(self):
dtype=self.meta.dtype,
structure=self.meta.structure,
space=self.meta.value,
device=device,
transformation=transformation,
).gen()
elif self.meta.argtype.is_tensor_list():
Expand All @@ -302,6 +311,7 @@ def gen(self):
dtype=self.meta.dtype[i],
structure=self.meta.structure[i],
space=self.meta.value,
device=device,
transformation=transformation,
).gen()
tensors.append(tensor)
Expand Down
3 changes: 2 additions & 1 deletion facto/inputgen/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class Condition(str, Enum):


class TensorConfig:
def __init__(self, **conditions):
def __init__(self, device="cpu", **conditions):
self.device = device
self.conditions = {condition: False for condition in Condition}
for condition, value in conditions.items():
if condition in self.conditions:
Expand Down