Skip to content

Commit 1eca619

Browse files
authored
Operator Count Analysis (#79)
1 parent 72b8720 commit 1eca619

File tree

5 files changed

+225
-13
lines changed

5 files changed

+225
-13
lines changed

.gitignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,5 @@ CLAUDE.md
88
venv/
99
ops/
1010
uv.lock
11-
12-
# Pre-commit
11+
pytorch_operator_coverage.csv
1312
.pre-commit-cache/

BackendBench/eval.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
import torch
44

5-
import triton.testing
5+
try:
6+
import triton.testing
7+
8+
TRITON_AVAILABLE = True
9+
except ImportError:
10+
TRITON_AVAILABLE = False
611

712

813
from BackendBench.utils import uses_cuda_stream
@@ -69,7 +74,9 @@ def cpu_bench(fn, num_runs=100):
6974

7075

7176
def eval_performance(op, impl, tests):
72-
bench_fn = triton.testing.do_bench if torch.cuda.is_available() else cpu_bench
77+
bench_fn = (
78+
triton.testing.do_bench if TRITON_AVAILABLE and torch.cuda.is_available() else cpu_bench
79+
)
7380
base_times = []
7481
test_times = []
7582
for test in tests:

BackendBench/opinfo_suite.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,19 +57,24 @@ def build_op_tests(device, dtype, filter=None):
5757
continue
5858

5959
op_indices = defaultdict(list)
60-
for idx, test in enumerate(op.sample_inputs(device, dtype)):
60+
try:
61+
sample_inputs = list(op.sample_inputs(device, dtype))
62+
except Exception:
63+
continue
64+
65+
for idx, test in enumerate(sample_inputs):
6166
# print(f"{idx=} {test.input=} {test.args=} {test.kwargs=}")
62-
with OpTracerMode() as tracer:
63-
ref = op.op(test.input, *test.args, **test.kwargs)
64-
if len(tracer.ops) == 1:
65-
try:
67+
try:
68+
with OpTracerMode() as tracer:
69+
ref = op.op(test.input, *test.args, **test.kwargs)
70+
if len(tracer.ops) == 1:
6671
res = tracer.ops[0](test.input, *test.args, **test.kwargs)
6772
if allclose(ref, res):
6873
op_indices[tracer.ops[0]].append(idx)
69-
except Exception:
70-
logger.debug(f"opinfo {op.name} couldn't run underlying op {tracer.ops[0]}")
71-
else:
72-
logger.debug(f"opinfo {op.name} has {len(tracer.ops)} ops")
74+
else:
75+
logger.debug(f"opinfo {op.name} has {len(tracer.ops)} ops")
76+
except Exception:
77+
continue
7378

7479
for overload, indices in op_indices.items():
7580
if len(indices) > 0:
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
#!/usr/bin/env python3
2+
"""Generate comprehensive operator coverage CSV for BackendBench"""
3+
4+
import csv
5+
import torch
6+
7+
from torch.testing._internal.common_methods_invocations import op_db
8+
from BackendBench.scripts.pytorch_operators import (
9+
get_pytorch_operators,
10+
extract_aten_ops,
11+
extract_operator_name,
12+
)
13+
from BackendBench.opinfo_suite import OpInfoTestSuite
14+
from BackendBench.torchbench_suite import TorchBenchTestSuite
15+
16+
17+
def get_torchbench_ops():
    """Collect the base operator names exercised by the TorchBench suite.

    Returns:
        set: one entry per distinct operator, with overload suffixes removed
        by ``extract_operator_name``.
    """
    torchbench_suite = TorchBenchTestSuite("torchbench", None)
    # Each suite entry exposes its operator as ``.op``; its string form is
    # reduced to the overload-free base name before de-duplication.
    return {extract_operator_name(str(op_test.op)) for op_test in torchbench_suite}
26+
27+
28+
def generate_coverage_csv():
29+
"""Generate comprehensive operator coverage CSV"""
30+
print("Gathering operator data...")
31+
32+
# Get all operators and core operators in one call
33+
all_native_ops, core_ops = get_pytorch_operators()
34+
35+
# Get OpInfo operators
36+
print("Building OpInfo tests for device=cpu, dtype=torch.float32")
37+
suite = OpInfoTestSuite("opinfo", "cpu", torch.float32)
38+
opinfo_successful_ops = [str(optest.op) for optest in suite]
39+
print("\nOpInfo loading results:")
40+
print(f" Total ops in op_db: {len(op_db)}")
41+
print(f" Successful operations found: {len(opinfo_successful_ops)}")
42+
print(f" Unique successful ops: {len(set(opinfo_successful_ops))}")
43+
44+
opinfo_ops = set(extract_aten_ops(opinfo_successful_ops))
45+
torchbench_ops = get_torchbench_ops()
46+
47+
print("\nOperator counts:")
48+
print(f"- Total native functions: {len(all_native_ops)}")
49+
print(f"- Core operators: {len(core_ops)}")
50+
print(f"- OpInfo: {len(opinfo_ops)}")
51+
print(f"- TorchBench: {len(torchbench_ops)}")
52+
53+
# Create comprehensive operator list
54+
all_operators = set(all_native_ops) | set(core_ops) | opinfo_ops | torchbench_ops
55+
core_ops_set = set(core_ops)
56+
57+
# Generate CSV
58+
csv_data = [["op_name", "is_core", "is_in_opinfo", "is_in_torchbench"]]
59+
60+
for op in sorted(all_operators):
61+
row = [
62+
op,
63+
True if op in core_ops_set else False,
64+
True if op in opinfo_ops else False,
65+
True if op in torchbench_ops else False,
66+
]
67+
csv_data.append(row)
68+
69+
csv_filename = "pytorch_operator_coverage.csv"
70+
with open(csv_filename, "w", newline="") as csvfile:
71+
writer = csv.writer(csvfile)
72+
writer.writerows(csv_data)
73+
74+
print(f"\nCSV generated: {csv_filename}")
75+
76+
# Analysis
77+
core_in_opinfo = core_ops_set & opinfo_ops
78+
core_in_torchbench = core_ops_set & torchbench_ops
79+
core_in_either = core_ops_set & (opinfo_ops | torchbench_ops)
80+
core_missing_both = core_ops_set - (opinfo_ops | torchbench_ops)
81+
82+
print(
83+
f"\nCore in OpInfo: {len(core_in_opinfo)}/{len(core_ops)} ({len(core_in_opinfo) / len(core_ops) * 100:.1f}%)"
84+
)
85+
print(
86+
f"Core in TorchBench: {len(core_in_torchbench)}/{len(core_ops)} ({len(core_in_torchbench) / len(core_ops) * 100:.1f}%)"
87+
)
88+
print(
89+
f"Combined coverage: {len(core_in_either)}/{len(core_ops)} ({len(core_in_either) / len(core_ops) * 100:.1f}%)"
90+
)
91+
print(f"Missing from both: {sorted(core_missing_both)}")
92+
93+
return csv_filename
94+
95+
96+
if __name__ == "__main__":
97+
csv_file = generate_coverage_csv()
98+
print(f"\nAnalysis complete! CSV saved as: {csv_file}")
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#!/usr/bin/env python3
2+
"""PyTorch operator utilities for BackendBench analysis"""
3+
4+
import urllib.request
5+
import yaml
6+
from typing import List
7+
8+
9+
def extract_operator_name(op_str: str) -> str:
10+
"""Extract clean operator name from various operator string formats.
11+
12+
Note: We don't care about overloads - we treat all overloads of an operator
13+
(e.g., add.Tensor, add.Scalar, add.out) as the same base operator.
14+
15+
Examples:
16+
"aten.relu.default" -> "relu"
17+
"torch.ops.aten.add.Tensor" -> "add"
18+
"add.Tensor" -> "add"
19+
"relu" -> "relu"
20+
"""
21+
if "aten." in op_str:
22+
return op_str.split("aten.")[-1].split(".")[0]
23+
elif "." in op_str:
24+
return op_str.split(".")[0]
25+
else:
26+
return op_str
27+
28+
29+
def get_deprecated_operators():
30+
"""Get deprecated operators from PyTorch's deprecated.yaml"""
31+
url = "https://raw.githubusercontent.com/pytorch/pytorch/refs/heads/main/tools/autograd/deprecated.yaml"
32+
33+
deprecated_ops = set()
34+
try:
35+
print("Downloading deprecated.yaml...")
36+
with urllib.request.urlopen(url) as response:
37+
yaml_content = response.read().decode("utf-8")
38+
39+
deprecated_functions = yaml.safe_load(yaml_content)
40+
41+
if deprecated_functions:
42+
for func_def in deprecated_functions:
43+
if isinstance(func_def, dict) and "name" in func_def:
44+
func_name = func_def["name"]
45+
base_name = extract_operator_name(func_name)
46+
deprecated_ops.add(base_name)
47+
48+
print(f"Found {len(deprecated_ops)} deprecated operators")
49+
except Exception as e:
50+
print(f"Warning: Could not fetch deprecated operators: {e}")
51+
52+
return deprecated_ops
53+
54+
55+
def get_pytorch_operators():
56+
"""Get all operators and core operators from PyTorch's native_functions.yaml, excluding deprecated ones"""
57+
url = "https://raw.githubusercontent.com/pytorch/pytorch/refs/heads/main/aten/src/ATen/native/native_functions.yaml"
58+
59+
print("Downloading native_functions.yaml...")
60+
with urllib.request.urlopen(url) as response:
61+
yaml_content = response.read().decode("utf-8")
62+
63+
functions = yaml.safe_load(yaml_content)
64+
print(f"Found {len(functions)} function definitions")
65+
66+
# Get deprecated operators to exclude
67+
deprecated_ops = get_deprecated_operators()
68+
69+
all_ops = set()
70+
core_ops = set()
71+
72+
for func_def in functions:
73+
if isinstance(func_def, dict) and "func" in func_def:
74+
func_signature = func_def["func"]
75+
func_name = func_signature.split("(")[0].strip()
76+
77+
base_name = extract_operator_name(func_name)
78+
79+
# Skip deprecated operators
80+
if base_name in deprecated_ops:
81+
continue
82+
83+
all_ops.add(base_name)
84+
85+
if "core" in func_def.get("tags", []):
86+
core_ops.add(base_name)
87+
88+
all_ops_list = sorted([op for op in all_ops if op and not op.isspace()])
89+
core_ops_list = sorted([op for op in core_ops if op and not op.isspace()])
90+
91+
print(f"Extracted {len(all_ops_list)} unique operators (excluding deprecated)")
92+
print(f"Found {len(core_ops_list)} core operators (excluding deprecated)")
93+
94+
return all_ops_list, core_ops_list
95+
96+
97+
def extract_aten_ops(ops_list: List[str]) -> List[str]:
98+
"""Extract aten operation names from ops list"""
99+
aten_ops = set()
100+
for op_str in ops_list:
101+
if "aten." in op_str:
102+
aten_ops.add(extract_operator_name(op_str))
103+
return sorted(aten_ops)

0 commit comments

Comments
 (0)