
Commit 028fb33

Add verbose mode (#92)
1 parent: 75d29af · commit: 028fb33

File tree

9 files changed: 346 additions & 61 deletions

BackendBench/eval.py

Lines changed: 126 additions & 35 deletions
@@ -4,18 +4,26 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+import json
 import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple

 import torch

+
 try:
-    import triton.testing
+    if torch.cuda.is_available():
+        import triton.testing

-    TRITON_AVAILABLE = True
+        TRITON_AVAILABLE = True
+    else:
+        TRITON_AVAILABLE = False
 except ImportError:
     TRITON_AVAILABLE = False

-from BackendBench.utils import serialize_args, uses_cuda_stream
+from BackendBench.utils import serialize_args, uses_cuda_stream, compute_errors

 logger = logging.getLogger(__name__)

@@ -31,34 +39,71 @@ def format_exception(e, op, args, kwargs):
     return EXC_MSG.format(op=op_name, args=serialize_args(args, kwargs), exc=e)


-def allclose(a, b):
-    if isinstance(a, torch.Tensor):
-        torch.testing.assert_close(a, b, equal_nan=True, atol=1e-2, rtol=1e-2)
+def _allclose(a, b, atol=1e-2, rtol=1e-2):
+    # using a stack to avoid recursion overflow issues
+    stack = [(a, b)]
+
+    while len(stack) > 0:
+        curr_a, curr_b = stack.pop()
+
+        if isinstance(curr_a, torch.Tensor):
+            torch.testing.assert_close(curr_a, curr_b, equal_nan=True, atol=atol, rtol=rtol)
+        elif isinstance(curr_a, (list, tuple)):
+            assert len(curr_a) == len(curr_b)
+            # Add pairs to stack in reverse order to maintain left-to-right checking
+            stack.extend(reversed(list(zip(curr_a, curr_b))))
+        else:
+            assert curr_a == curr_b
+
+
+def allclose(a, b, atol=1e-2, rtol=1e-2):
+    try:
+        _allclose(a, b)
         return True
-    if isinstance(a, (list, tuple)):
-        if len(a) != len(b):
-            raise ValueError(f"Length mismatch: {len(a)} vs {len(b)}")
-        return all(allclose(x, y) for x, y in zip(a, b))
-    return a == b
+    except Exception:
+        return False


-def eval_correctness_test(op, impl, test):
-    """Evaluate impl of op against test."""
+def eval_correctness_test(
+    op, impl, test
+) -> Tuple[bool, Optional[str], Optional[float], Optional[float]]:
+    """Evaluate impl of op against test.
+
+    Returns:
+        Tuple of (is_correct, error_message, absolute_error, relative_error)
+    """
     args, kwargs = test.args, test.kwargs
     ref = op(*args, **kwargs)
     try:
         res = impl(*args, **kwargs)
-        return allclose(ref, res)
+        is_correct = allclose(ref, res)
+
+        # Compute errors even if test passes (for verbose mode)
+        abs_error, rel_error = compute_errors(ref, res)
+
+        return is_correct, None, abs_error, rel_error
     except Exception as e:
-        logger.warning(format_exception(e, op, args, kwargs))
-        return False
+        error_msg = format_exception(e, op, args, kwargs)
+        logger.warning(error_msg)
+        return False, str(e), None, None


-def eval_correctness(op, impl, tests):
+def eval_correctness(op, impl, tests, test_data: defaultdict = defaultdict(dict)):
+    """Evaluate correctness of impl against tests."""
     correct, total = 0, 0
     for test in tests:
-        logging.debug(f"Testing {op.__name__} with args {serialize_args(test.args, test.kwargs)}")
-        if eval_correctness_test(op, impl, test):
+        args_str = serialize_args(test.args, test.kwargs)
+        logging.debug(f"Testing {op.__name__} with args {args_str}")
+        is_correct, error_msg, abs_error, rel_error = eval_correctness_test(op, impl, test)
+
+        test_data[args_str] = {
+            "correctness_score": 1 if is_correct else 0,
+            "correctness_errors": error_msg or "",
+            "absolute_error": str(abs_error) if abs_error is not None else "",
+            "relative_error": str(rel_error) if rel_error is not None else "",
+        }
+
+        if is_correct:
             correct += 1
         total += 1
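
As an aside, a minimal sketch of what the reworked comparison accepts (not part of the diff; the tensors and the small perturbation below are illustrative, and the tolerances are the defaults shown above):

import torch

from BackendBench.eval import allclose

# Nested lists/tuples are now walked iteratively, so deeply nested op outputs
# no longer risk hitting Python's recursion limit.
ref = [torch.ones(4), (torch.zeros(2), 3)]
res = [torch.ones(4) + 1e-3, (torch.zeros(2), 3)]  # within atol=1e-2 / rtol=1e-2
assert allclose(ref, res)

# Structure or value mismatches now simply return False instead of raising.
assert not allclose(ref, [torch.ones(4)])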

@@ -83,34 +128,80 @@ def cpu_bench(fn, num_runs=100):
     return (time.perf_counter() - start) / num_runs


-def eval_performance(op, impl, tests):
+def eval_performance(op, impl, tests, test_data: defaultdict = defaultdict(dict)):
+    """Evaluate performance of impl against tests."""
     bench_fn = (
         triton.testing.do_bench if TRITON_AVAILABLE and torch.cuda.is_available() else cpu_bench
     )
     base_times = []
     test_times = []
+    args_strs = []
+
     for test in tests:
-        logging.debug(
-            f"Benchmarking {op.__name__} with args {serialize_args(test.args, test.kwargs)}"
-        )
-        base_times.append(bench_fn(lambda: op(*test.args, **test.kwargs)))
+        args_str = serialize_args(test.args, test.kwargs)
+        args_strs.append(args_str)
+        logging.debug(f"Benchmarking {op.__name__} with args {args_str}")
+        base_time = bench_fn(lambda: op(*test.args, **test.kwargs))
+        base_times.append(base_time)
+        test_time = base_time
         try:
-            allclose(op(*test.args, **test.kwargs), impl(*test.args, **test.kwargs))
+            ref = op(*test.args, **test.kwargs)
+            res = impl(*test.args, **test.kwargs)
+            if not allclose(
+                ref,
+                res,
+            ):
+                raise ValueError(f"Reference and result tensors are not close: {ref} vs {res}")
+            test_time = bench_fn(lambda: impl(*test.args, **test.kwargs))
         except Exception:
-            test_times.append(base_times[-1])
-            continue
-        test_times.append(bench_fn(lambda: impl(*test.args, **test.kwargs)))
+            pass
+        finally:
+            test_times.append(test_time)
+            test_data[args_str]["benchmark_time"] = str(test_time)
+
     speedups = torch.tensor(base_times) / torch.tensor(test_times)
+
+    # Update test_data with speedups from the tensor
+    for i, args_str in enumerate(args_strs):
+        test_data[args_str]["speedup"] = str(speedups[i].item())
+
     return speedups.log().mean().exp()


 def eval_one_op(op, impl, correctness_tests, performance_tests):
-    """Evaluate impl of op against correctness_tests and performance_tests."""
-    # TODO: We should have proper error reporting instead of just saying this is 0,
-    # but that should be a separate PR.
+    """Evaluate impl of op against correctness_tests and performance_tests.
+
+    Returns:
+        Tuple of (correctness_score, performance_score, test_data)
+    """
+    test_data = defaultdict(dict)
+
     if uses_cuda_stream(impl):
         logger.warning(f"Skipping {op.__name__} because it uses CUDA stream")
-        return 0.0, 1.0
-    return eval_correctness(op, impl, correctness_tests), eval_performance(
-        op, impl, performance_tests
-    )
+        for test in correctness_tests + performance_tests:
+            args_str = serialize_args(test.args, test.kwargs)
+            test_data[args_str] = {
+                "correctness_score": 0,
+                "benchmark_time": "",
+                "speedup": "",
+                "correctness_errors": "Skipped: uses CUDA stream",
+                "absolute_error": "",
+                "relative_error": "",
+            }
+        return 0, 1.0, test_data
+
+    correctness_score = eval_correctness(op, impl, correctness_tests, test_data)
+    performance_score = eval_performance(op, impl, performance_tests, test_data)
+    test_data = dict(test_data)
+    return correctness_score, performance_score, test_data
+
+
+def save_verbose_results(
+    results: List[Dict[str, Any]],
+    output_path: str = "backendbench_verbose_results.json",
+):
+    """Save verbose results to a JSON file."""
+    with open(Path(output_path), "w") as f:
+        json.dump(results, f, indent=2)
+
+    logger.info(f"Verbose results saved to {output_path}")
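
To make the new eval_one_op contract concrete, here is a minimal sketch (not part of the commit). SimpleTest and noisy_relu are made-up stand-ins; the only assumption about suite test objects is that they expose .args and .kwargs, which is all the code above relies on.

import torch

from BackendBench import eval as bb_eval


class SimpleTest:
    # Hypothetical stand-in for a suite test case; eval.py only reads .args/.kwargs.
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


def noisy_relu(x):
    # Deliberately imperfect "backend" implementation of torch.relu.
    return torch.relu(x) + 1e-3


tests = [SimpleTest(torch.randn(32)), SimpleTest(torch.randn(1024))]

# eval_one_op now returns a third value: test_data maps serialized args to the
# per-test record (correctness_score, correctness_errors, absolute_error,
# relative_error, benchmark_time, speedup).
correctness, perf, test_data = bb_eval.eval_one_op(torch.relu, noisy_relu, tests, tests)

# Same flattening that scripts/main.py performs before writing JSON.
rows = [{"op_name": "relu", "args": args, **data} for args, data in test_data.items()]
bb_eval.save_verbose_results(rows, "relu_verbose.json")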

BackendBench/multiprocessing_eval.py

Lines changed: 11 additions & 1 deletion
@@ -51,6 +51,7 @@ class EvalResult:
     task_id: int
     correctness_score: float
     performance_score: float
+    test_data: Optional[dict] = None
     error: Optional[str] = None


@@ -99,13 +100,14 @@ def _worker_process(worker_id, task_queue, result_queue):
             if isinstance(impl, str):
                 impl = get_operator(impl)

-            correctness_score, performance_score = eval_one_op(
+            correctness_score, performance_score, test_data = eval_one_op(
                 op, impl, task.correctness_tests, task.performance_tests
             )
             result = EvalResult(
                 task_id=task.task_id,
                 correctness_score=correctness_score,
                 performance_score=performance_score,
+                test_data=test_data,
             )
         except Exception as e:
             error_msg = f"Error in eval_one_op: {str(e)}\n{traceback.format_exc()}"
@@ -121,6 +123,14 @@ def _worker_process(worker_id, task_queue, result_queue):
                 task_id=task.task_id,
                 correctness_score=0.0,
                 performance_score=1.0,
+                test_data={
+                    "correctness_score": 0.0,
+                    "benchmark_time": "",
+                    "speedup": "",
+                    "correctness_errors": f"{error_msg}",
+                    "absolute_error": "",
+                    "relative_error": "",
+                },
                 error=error_msg,
             )
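
For downstream consumers, a sketch of flattening EvalResult.test_data into the row format that scripts/main.py writes. This is an illustration, not code from the commit: how results and task_to_op_name are collected is left to the caller.

from typing import Dict, List

from BackendBench.multiprocessing_eval import EvalResult


def flatten_results(results: List[EvalResult], task_to_op_name: Dict[int, str]) -> List[dict]:
    rows = []
    for result in results:
        if not result.test_data:
            continue
        op_name = task_to_op_name.get(result.task_id, f"task_{result.task_id}")
        for args_str, data in result.test_data.items():
            # On worker failure test_data is a flat field dict rather than a
            # per-args mapping, so guard before merging.
            if not isinstance(data, dict):
                continue
            rows.append({"op_name": op_name, "args": args_str, **data})
    return rows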

BackendBench/scripts/main.py

Lines changed: 35 additions & 4 deletions
@@ -108,11 +108,17 @@ def setup_logging(log_level):
     type=str,
     help="Path to directory containing generated kernels",
 )
+@click.option(
+    "--output-path",
+    default=None,
+    type=str,
+    help="Path for JSON output file with detailed results (if not specified, no JSON output)",
+)
 @click.option(
     "--num-workers",
     default=None,
     type=int,
-    help="Number of workers to use for multiprocessing, default to None to disable multiprocessing)",
+    help="Number of workers to use for multiprocessing, default to None to disable multiprocessing",
 )
 def cli(
     log_level,
@@ -126,6 +132,7 @@ def cli(
     kernel_agent_max_rounds,
     torchbench_data_path,
     ops_directory,
+    output_path,
     num_workers,
 ):
     setup_logging(log_level)
@@ -187,6 +194,7 @@ def cli(

     overall_correctness = []
     overall_performance = []
+    verbose_results = []

     if num_workers is None:
         for test in suite:
@@ -195,7 +203,7 @@ def cli(

             logger.debug(test.op)

-            correctness, perf = eval.eval_one_op(
+            correctness, perf, op_test_data = eval.eval_one_op(
                 test.op,
                 backend[test.op],
                 test.correctness_tests,
@@ -204,19 +212,29 @@ def cli(
             overall_correctness.append(correctness)
             overall_performance.append(perf)

+            # Convert dict to list entries with op_name
+            op_name = getattr(test.op, "__name__", str(test.op))
+            for args_str, data in op_test_data.items():
+                entry = {"op_name": op_name, "args": args_str}
+                entry.update(data)
+                verbose_results.append(entry)
+
             logger.debug(f"max memory allocated: {torch.cuda.max_memory_allocated():,}")
     else:
         with multiprocessing_eval.MultiprocessingEvaluator(num_workers) as evaluator:
-            # Submit all tasks
+            # Submit all tasks and track op names
+            task_to_op_name = {}
             for test in suite:
                 if test.op not in backend:
                     continue

                 logger.debug(test.op)

-                evaluator.submit_task(
+                task_id = evaluator.submit_task(
                     test.op, backend[test.op], test.correctness_tests, test.performance_tests
                 )
+                op_name = getattr(test.op, "__name__", str(test.op))
+                task_to_op_name[task_id] = op_name

             # Start evaluation
             evaluator.start_evaluation()
@@ -230,11 +248,24 @@ def cli(
                 overall_correctness.append(correctness_score)
                 overall_performance.append(performance_score)

+                # Handle verbose data if present
+                if result.test_data and result.task_id in task_to_op_name:
+                    op_name = task_to_op_name[result.task_id]
+                    for args_str, data in result.test_data.items():
+                        entry = {"op_name": op_name, "args": args_str}
+                        entry.update(data)
+                        verbose_results.append(entry)
+
     mean_correctness = torch.tensor(overall_correctness).mean().item()
     geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
     print(f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}")
     print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")

+    # Save verbose results if output path is specified
+    if output_path and verbose_results:
+        eval.save_verbose_results(verbose_results, output_path)
+        print(f"Detailed results saved to: {output_path}")
+

 def setup_llm_backend(llm_backend, llm_client, suite, max_attempts=5):
     """Setup LLM backend by generating kernels for all operations in the suite."""

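Usage note: with the new flag, a run along the lines of "python -m BackendBench.scripts.main --num-workers 4 --output-path results.json" (module path assumed; launch the script however the repo normally does) writes one JSON row per test. Below is a small sketch of slicing that file; the keys mirror what eval.py records, timing and error fields are stored as strings, and the values are illustrative.

import json

with open("results.json") as f:
    rows = json.load(f)

# Failing correctness tests, with the captured exception text.
for r in rows:
    if r.get("correctness_score") == 0:
        print(r["op_name"], r["args"], r.get("correctness_errors", ""))

# Ten slowest performance tests by recorded speedup.
perf_rows = [r for r in rows if r.get("speedup")]
for r in sorted(perf_rows, key=lambda r: float(r["speedup"]))[:10]:
    print(r["op_name"], r.get("benchmark_time", ""), r["speedup"])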