
Commit cd66e79

Op Input Generation via FACTO (#62)
1 parent 115a865 commit cd66e79

9 files changed: 367 additions, 3 deletions

.github/workflows/smoke-test.yml

Lines changed: 9 additions & 0 deletions
@@ -24,8 +24,17 @@ jobs:
       - name: Install package and dependencies
         run: uv sync --dev
 
+      - name: Clone FACTO source
+        run: git clone https://github.com/pytorch-labs/FACTO.git
+
+      - name: Build and install FACTO
+        run: cd FACTO && uv pip install .
+
       - name: Run smoke test
         run: uv run python -m BackendBench.scripts.main --suite smoke --backend aten
 
+      - name: Run FACTO test
+        run: uv run python -m BackendBench.scripts.main --suite facto --backend aten --ops "add.Tensor"
+
       - name: Run pytest tests
         run: uv run pytest test/

BackendBench/backends/flag_gems.py

Lines changed: 5 additions & 0 deletions
@@ -1,5 +1,7 @@
 import torch
 
+from BackendBench.opregistry import register_operator
+
 from .base import Backend
 
 try:
@@ -284,6 +286,9 @@ def __init__(self) -> None:
             torch.ops.aten.eye.m: flag_gems.ops.eye_m,
             torch.ops.aten.to.dtype: flag_gems.ops.to_dtype,
         }
+        # Register all operators in the global registry to ensure consistent object identity
+        for op in self.ops.keys():
+            register_operator(op)
 
     def __getitem__(self, key):
         return self.ops[key]
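Registering every key matters because BackendBench uses operator objects as dict keys, so the backend and the test suites must hold the same canonical instance. A minimal sketch of the identity guarantee, assuming a fresh registry (illustration only, not part of this diff):

import torch

from BackendBench.opregistry import get_operator, register_operator

# register_operator keys the object by its spec name ("relu.default"),
# so later lookups by name or by object return the identical instance.
op = register_operator(torch.ops.aten.relu.default)
assert get_operator("relu.default") is op
assert get_operator(torch.ops.aten.relu.default) is op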

BackendBench/facto_suite.py

Lines changed: 143 additions & 0 deletions
import logging
from collections import defaultdict

import torch
from torch.utils._python_dispatch import TorchDispatchMode

try:
    from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
    from facto.inputgen.utils.config import TensorConfig
    from facto.specdb.db import SpecDictDB
except ImportError:
    ArgumentTupleGenerator = None
    TensorConfig = None
    SpecDictDB = None


from .eval import allclose
from .opregistry import get_operator
from .suite import OpTest, TestSuite

logger = logging.getLogger(__name__)


class FactoTest:
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


class FactoOpTest(OpTest):
    def __init__(self, op, correctness_tests):
        self.op = op
        self._correctness_tests = correctness_tests
        self.performance_tests = []

    @property
    def correctness_tests(self):
        for test in self._correctness_tests:
            yield FactoTest(*test.args, **test.kwargs)


class OpTracerMode(TorchDispatchMode):
    def __init__(self):
        self.ops = []
        self.args = []
        self.kwargs = []

    def __torch_dispatch__(self, fn, types, args=(), kwargs={}):
        self.ops.append(fn)
        self.args.append(args)
        self.kwargs.append(kwargs)
        return fn(*args, **kwargs)


def build_facto_op_tests(device, dtype, filter=None, num_runs=10, empty=False, probability=1.0):
    facto_op_tests = []
    failed = []
    for spec_name in SpecDictDB:
        try:
            if filter and spec_name not in filter:
                continue

            # Get canonical operator from registry
            op = get_operator(spec_name)
            if op is None:
                logger.debug(f"Skipping {spec_name}: operator resolution failed")
                continue

            config = TensorConfig(
                empty=empty,
            ).set_probability(probability)

            spec = SpecDictDB[spec_name]
            generator = ArgumentTupleGenerator(spec, config)

            op_tests = defaultdict(list)

            for idx, (posargs, inkwargs, outargs) in enumerate(generator.gen()):
                if idx >= num_runs:
                    break

                # Filter arguments to target device/dtype
                filtered_posargs = []
                for arg in posargs:
                    if isinstance(arg, torch.Tensor):
                        arg = arg.to(device=device, dtype=dtype)
                    filtered_posargs.append(arg)

                filtered_inkwargs = {}
                for k, v in inkwargs.items():
                    if isinstance(v, torch.Tensor):
                        v = v.to(device=device, dtype=dtype)
                    filtered_inkwargs[k] = v

                filtered_outargs = {}
                for k, v in outargs.items():
                    if isinstance(v, torch.Tensor):
                        v = v.to(device=device, dtype=dtype)
                    filtered_outargs[k] = v

                all_kwargs = {**filtered_inkwargs, **filtered_outargs}

                try:
                    # Trace execution to find underlying PyTorch ops
                    with OpTracerMode() as tracer:
                        ref = op(*filtered_posargs, **all_kwargs)
                except Exception:
                    logger.debug(f"FACTO spec {spec_name} couldn't run underlying op {op}")
                    continue

                # Check if we captured exactly one op (clean mapping)
                if len(tracer.ops) == 1:
                    try:
                        # Verify the traced op produces the same result
                        res = tracer.ops[0](*filtered_posargs, **all_kwargs)
                        if allclose(ref, res):
                            op_tests[tracer.ops[0]].append(
                                FactoTest(*filtered_posargs, **all_kwargs)
                            )
                    except Exception:
                        logger.debug(
                            f"FACTO spec {spec_name} couldn't run underlying op {tracer.ops[0]}"
                        )
                else:
                    logger.debug(f"FACTO spec {spec_name} has {len(tracer.ops)} ops")

            for traced_op, tests in op_tests.items():
                if len(tests) > 0:
                    facto_op_tests.append(FactoOpTest(traced_op, tests))
        except Exception:
            logger.debug(f"FACTO spec {spec_name} failed")
            failed.append(spec_name)

    logger.debug(f"Failed specs: {failed}")

    return facto_op_tests


class FactoTestSuite(TestSuite):
    def __init__(self, name, device, dtype, filter=None, num_runs=10, empty=False, probability=1.0):
        super().__init__(
            name, build_facto_op_tests(device, dtype, filter, num_runs, empty, probability)
        )
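The suite keeps a FACTO-generated input only when tracing the spec's op records exactly one dispatched aten op whose result matches the reference. A hypothetical standalone use of OpTracerMode, to show what the trace captures:

import torch

from BackendBench.facto_suite import OpTracerMode

x = torch.randn(4)
with OpTracerMode() as tracer:
    y = torch.relu(x)

# tracer.ops holds the aten ops dispatched inside the block; an op
# that decomposes into several kernels would record more than one
# entry and be skipped by build_facto_op_tests.
print(tracer.ops)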

BackendBench/opregistry.py

Lines changed: 104 additions & 0 deletions
import logging

import torch

logger = logging.getLogger(__name__)


def _extract_spec_name_from_op(op_obj):
    try:
        # PyTorch operator objects have a _name attribute that contains the full name
        if hasattr(op_obj, "_name"):
            full_name = op_obj._name
            # full_name is typically like "aten::add.Tensor"
            if "::" in full_name:
                # Remove the "aten::" prefix
                spec_name = full_name.split("::", 1)[1]
                return spec_name
        return None

    except Exception as e:
        logger.debug(f"Failed to extract spec name from operator {op_obj}: {e}")
        return None


class OpRegistry:
    def __init__(self):
        self._registry = {}

    def get_operator(self, input_obj):
        if isinstance(input_obj, str):
            return self._get_operator_from_spec_name(input_obj)
        else:
            return self._get_operator_from_object(input_obj)

    def _get_operator_from_spec_name(self, spec_name):
        # Return cached operator if available
        if spec_name in self._registry:
            return self._registry[spec_name]

        # Parse spec name
        op_parts = spec_name.split(".")
        op_name = op_parts[0]
        overload = op_parts[1] if len(op_parts) > 1 else "default"

        try:
            # Resolve operator using PyTorch's API
            op = getattr(torch.ops.aten, op_name).__getattr__(overload)

            # Cache the resolved operator
            self._registry[spec_name] = op
            return op

        except AttributeError as e:
            logger.warning(f"Failed to resolve operator {spec_name}: {e}")
            return None

    def _get_operator_from_object(self, op_obj):
        # Extract spec name from the operator object
        spec_name = _extract_spec_name_from_op(op_obj)

        # Check if we already have this operator registered
        if spec_name in self._registry:
            return self._registry[spec_name]

        # Register the provided operator object
        self._registry[spec_name] = op_obj
        return op_obj

    def register_operator(self, op_obj):
        return self._get_operator_from_object(op_obj)

    def get_all_registered_ops(self):
        return self._registry.copy()

    def clear(self):
        self._registry.clear()

    def __len__(self):
        return len(self._registry)

    def __contains__(self, spec_name):
        """Check if operator is registered."""
        return spec_name in self._registry

    def __repr__(self):
        return f"OpRegistry({len(self._registry)} ops)"


# Global operator registry instance
_op_registry = OpRegistry()


def get_operator(input_obj):
    return _op_registry.get_operator(input_obj)


def register_operator(op_obj):
    return _op_registry.register_operator(op_obj)


def get_registry():
    return _op_registry
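Spec names resolve by splitting on the dot: the part before it picks the op packet under torch.ops.aten, the part after picks the overload, falling back to "default" when there is none. A short sketch, assuming a fresh registry (illustration only):

from BackendBench.opregistry import get_operator, get_registry

add_op = get_operator("add.Tensor")   # resolves torch.ops.aten.add.Tensor
relu_op = get_operator("relu")        # no overload given -> "default"

# Both resolutions are now cached in the global registry.
assert len(get_registry()) == 2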

BackendBench/scripts/main.py

Lines changed: 8 additions & 1 deletion
@@ -7,6 +7,7 @@
 import BackendBench.eval as eval
 import click
 import torch
+from BackendBench.facto_suite import FactoTestSuite
 from BackendBench.llm_client import ClaudeKernelGenerator
 from BackendBench.opinfo_suite import OpInfoTestSuite
 from BackendBench.suite import SmokeTestSuite
@@ -38,7 +39,7 @@ def setup_logging(log_level):
 @click.option(
     "--suite",
     default="smoke",
-    type=click.Choice(["smoke", "opinfo", "torchbench"]),
+    type=click.Choice(["smoke", "opinfo", "torchbench", "facto"]),
     help="Which suite to run",
 )
 @click.option(
@@ -128,6 +129,12 @@ def cli(
             filter=ops,
             topn=topn_inputs,
         ),
+        "facto": lambda: FactoTestSuite(
+            "facto_cuda_bfloat16",
+            "cuda",
+            torch.bfloat16,
+            filter=ops,
+        ),
     }[suite]()
 
     # For LLM backend, we need to generate kernels first
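The new suite entry generates FACTO inputs in bfloat16 on CUDA. A hypothetical direct construction mirroring the "facto" lambda above (requires a CUDA device and a source install of FACTO; the iteration relies on TestSuite.__iter__ from suite.py):

import torch

from BackendBench.facto_suite import FactoTestSuite

suite = FactoTestSuite(
    "facto_cuda_bfloat16",
    "cuda",
    torch.bfloat16,
    filter=["add.Tensor"],
)
for op_test in suite:
    print(op_test.op)  # the traced aten overload behind each spec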

BackendBench/suite.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,7 @@
 import torch
 
+from BackendBench.opregistry import get_operator
+
 
 def randn(*args, **kwargs):
     return lambda: torch.randn(*args, **kwargs)
@@ -40,7 +42,7 @@ def __iter__(self):
     "smoke",
     [
         OpTest(
-            torch.ops.aten.relu.default,
+            get_operator(torch.ops.aten.relu.default),
             [
                 Test(randn(2, device="cpu")),
             ],

pyproject.toml

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@ dependencies = [
 flaggems = [
     # flag_gems must be installed from source: https://github.com/FlagOpen/FlagGems
 ]
+facto = [
+    # facto must be installed from source: https://github.com/pytorch-labs/FACTO
+]
 
 [project.scripts]
 backendbench = "BackendBench.scripts.main:cli"
