Skip to content

Commit 4032f16

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Introduce tensor conditions (empty, transposed, permuted, strided) (#15)
Summary: This PR introduces support for `TensorConfig` in FACTO's tensor generation pipeline, enabling fine-grained control over tensor properties like empty tensors, transposed layouts, permuted dimensions, and strided memory patterns. The implementation features a clean separation between transformation generation and tensor creation through a new `TensorTransformationGenerator` class. Pull Request resolved: #15 Reviewed By: digantdesai Differential Revision: D79295653 Pulled By: manuelcandales fbshipit-source-id: 64e0f252f6ed2beb572f9369cb22f59b2e7894dc
1 parent 131d717 commit 4032f16

File tree

5 files changed

+604
-15
lines changed

5 files changed

+604
-15
lines changed

examples/config_example.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
9+
from facto.inputgen.utils.config import TensorConfig
10+
from facto.specdb.db import SpecDictDB
11+
12+
13+
def qualify_tensor(tensor):
14+
order = tensor.dim_order()
15+
dims = sum(i != order[i] for i in range(len(order)))
16+
empty = tensor.numel() == 0
17+
transposed = dims == 2
18+
permuted = dims > 2
19+
strided = len(tensor.storage()) - tensor.storage_offset() != tensor.numel()
20+
return empty, transposed, permuted, strided
21+
22+
23+
def qualify_tensor_string(tensor):
24+
empty, transposed, permuted, strided = qualify_tensor(tensor)
25+
s = "E" if empty else ""
26+
s += "P" if permuted else "T" if transposed else ""
27+
s += "S" if strided else ""
28+
return s
29+
30+
31+
def pretty_print_add_args(posargs, inkwargs, outargs):
32+
return "".join(
33+
[
34+
"Tensor{",
35+
qualify_tensor_string(posargs[0]),
36+
"}(",
37+
str(list(posargs[0].shape)),
38+
", ",
39+
str(posargs[0].dtype)[6:],
40+
") + Tensor{",
41+
qualify_tensor_string(posargs[1]),
42+
"}(",
43+
str(list(posargs[1].shape)),
44+
", dtype=",
45+
str(posargs[0].dtype)[6:],
46+
") alpha = ",
47+
str(inkwargs["alpha"]),
48+
]
49+
)
50+
51+
52+
def generate_inputs():
53+
spec = SpecDictDB["add.Tensor"]
54+
55+
config = TensorConfig(
56+
empty=False,
57+
transposed=False,
58+
permuted=True,
59+
strided=True,
60+
).set_probability(0.7)
61+
62+
generator = ArgumentTupleGenerator(spec, config=config)
63+
for ix, tup in enumerate(generator.gen()):
64+
posargs, inkwargs, outargs = tup
65+
# Pretty printing the inputs and outputs
66+
print(f"Tuple #{ix}: {pretty_print_add_args(posargs, inkwargs, outargs)}")
67+
yield posargs, inkwargs, outargs
68+
69+
70+
def test_add_op():
71+
op = torch.ops.aten.add.Tensor
72+
for posargs, inkwargs, outargs in generate_inputs():
73+
try:
74+
op(*posargs, **inkwargs, **outargs)
75+
except Exception:
76+
return False
77+
return True
78+
79+
80+
def main():
81+
print("Testing add.Tensor with the following input tuples:")
82+
success = test_add_op()
83+
if success:
84+
print("Success!")
85+
else:
86+
print("Failure!")
87+
88+
89+
if __name__ == "__main__":
90+
main()

facto/inputgen/argtuple/gen.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,55 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from collections import OrderedDict
8-
from typing import Any, Generator, List, Tuple
8+
from copy import deepcopy
9+
from typing import Any, Generator, List, Optional, Tuple
910

1011
from facto.inputgen.argtuple.engine import MetaArgTupleEngine
1112
from facto.inputgen.argument.engine import MetaArg
1213
from facto.inputgen.argument.gen import ArgumentGenerator
13-
from facto.inputgen.specs.model import Spec
14+
from facto.inputgen.specs.model import ConstraintProducer as cp, Spec
15+
from facto.inputgen.utils.config import Condition, TensorConfig
1416

1517

1618
class ArgumentTupleGenerator:
17-
def __init__(self, spec: Spec):
19+
def __init__(self, spec: Spec, config: Optional[TensorConfig] = None):
1820
self.spec = spec
21+
self.config = config
22+
self._modified_spec = self._apply_config_constraints(spec, config)
23+
24+
def _apply_config_constraints(
25+
self, spec: Spec, config: Optional[TensorConfig]
26+
) -> Spec:
27+
"""Apply TensorConfig constraints to the spec by modifying argument constraints."""
28+
29+
if config is None:
30+
return spec
31+
32+
# Create a copy of the spec with modified constraints
33+
modified_inspec = []
34+
for arg in spec.inspec:
35+
modified_arg = self._apply_constraints_to_arg(arg, config)
36+
modified_inspec.append(modified_arg)
37+
38+
modified_outspec = []
39+
for arg in spec.outspec:
40+
modified_arg = self._apply_constraints_to_arg(arg, config)
41+
modified_outspec.append(modified_arg)
42+
43+
return Spec(spec.op, modified_inspec, modified_outspec)
44+
45+
def _apply_constraints_to_arg(self, arg, config: TensorConfig):
46+
"""Apply config constraints to a single argument."""
47+
# Create a copy of the argument with potentially modified constraints
48+
modified_arg = deepcopy(arg)
49+
50+
# Add size constraints for tensor arguments when empty tensors are not allowed
51+
if not config.is_allowed(Condition.ALLOW_EMPTY):
52+
if arg.type.is_tensor() or arg.type.is_tensor_list():
53+
size_constraint = cp.Size.Ge(lambda deps, r, d: 1)
54+
modified_arg.constraints = modified_arg.constraints + [size_constraint]
55+
56+
return modified_arg
1957

2058
def gen_tuple(
2159
self, meta_tuple: Tuple[MetaArg], *, out: bool = False
@@ -25,15 +63,15 @@ def gen_tuple(
2563
outargs = OrderedDict()
2664
for ix, arg in enumerate(self.spec.inspec):
2765
m = meta_tuple[ix]
28-
val = ArgumentGenerator(m).gen()
66+
val = ArgumentGenerator(m, config=self.config).gen()
2967
if arg.kw:
3068
inkwargs[arg.name] = val
3169
else:
3270
posargs.append(val)
3371
if out:
3472
for ix, arg in enumerate(self.spec.outspec):
3573
m = meta_tuple[ix + len(self.spec.inspec)]
36-
val = ArgumentGenerator(m).gen()
74+
val = ArgumentGenerator(m, config=self.config).gen()
3775
outargs[arg.name] = val
3876
return posargs, inkwargs, outargs
3977

@@ -42,6 +80,6 @@ def gen(
4280
) -> Generator[
4381
Tuple[List[Any], OrderedDict[str, Any], OrderedDict[str, Any]], Any, Any
4482
]:
45-
engine = MetaArgTupleEngine(self.spec, out=out)
83+
engine = MetaArgTupleEngine(self._modified_spec, out=out)
4684
for meta_tuple in engine.gen(valid=valid):
4785
yield self.gen_tuple(meta_tuple, out=out)

0 commit comments

Comments
 (0)