Merged
161 changes: 126 additions & 35 deletions BackendBench/eval.py
@@ -4,18 +4,26 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch


try:
import triton.testing
if torch.cuda.is_available():
import triton.testing

TRITON_AVAILABLE = True
TRITON_AVAILABLE = True
else:
TRITON_AVAILABLE = False
except ImportError:
TRITON_AVAILABLE = False

from BackendBench.utils import serialize_args, uses_cuda_stream
from BackendBench.utils import serialize_args, uses_cuda_stream, compute_errors

logger = logging.getLogger(__name__)

@@ -31,34 +39,71 @@ def format_exception(e, op, args, kwargs):
return EXC_MSG.format(op=op_name, args=serialize_args(args, kwargs), exc=e)


def allclose(a, b):
if isinstance(a, torch.Tensor):
torch.testing.assert_close(a, b, equal_nan=True, atol=1e-2, rtol=1e-2)
def _allclose(a, b, atol=1e-2, rtol=1e-2):
# using a stack to avoid recursion overflow issues
stack = [(a, b)]

while len(stack) > 0:
curr_a, curr_b = stack.pop()

if isinstance(curr_a, torch.Tensor):
torch.testing.assert_close(curr_a, curr_b, equal_nan=True, atol=atol, rtol=rtol)
elif isinstance(curr_a, (list, tuple)):
assert len(curr_a) == len(curr_b)
# Add pairs to stack in reverse order to maintain left-to-right checking
stack.extend(reversed(list(zip(curr_a, curr_b))))
else:
assert curr_a == curr_b


def allclose(a, b, atol=1e-2, rtol=1e-2):
try:
_allclose(a, b)
return True
if isinstance(a, (list, tuple)):
if len(a) != len(b):
raise ValueError(f"Length mismatch: {len(a)} vs {len(b)}")
return all(allclose(x, y) for x, y in zip(a, b))
return a == b
except Exception:
return False


def eval_correctness_test(op, impl, test):
"""Evaluate impl of op against test."""
def eval_correctness_test(
op, impl, test
) -> Tuple[bool, Optional[str], Optional[float], Optional[float]]:
"""Evaluate impl of op against test.

Returns:
Tuple of (is_correct, error_message, absolute_error, relative_error)
"""
args, kwargs = test.args, test.kwargs
ref = op(*args, **kwargs)
try:
res = impl(*args, **kwargs)
return allclose(ref, res)
is_correct = allclose(ref, res)

# Compute errors even if test passes (for verbose mode)
abs_error, rel_error = compute_errors(ref, res)

return is_correct, None, abs_error, rel_error
except Exception as e:
logger.warning(format_exception(e, op, args, kwargs))
return False
error_msg = format_exception(e, op, args, kwargs)
logger.warning(error_msg)
return False, str(e), None, None


def eval_correctness(op, impl, tests):
def eval_correctness(op, impl, tests, test_data: defaultdict = defaultdict(dict)):
"""Evaluate correctness of impl against tests."""
correct, total = 0, 0
for test in tests:
logging.debug(f"Testing {op.__name__} with args {serialize_args(test.args, test.kwargs)}")
if eval_correctness_test(op, impl, test):
args_str = serialize_args(test.args, test.kwargs)
logging.debug(f"Testing {op.__name__} with args {args_str}")
is_correct, error_msg, abs_error, rel_error = eval_correctness_test(op, impl, test)

test_data[args_str] = {
"correctness_score": 1 if is_correct else 0,
"correctness_errors": error_msg or "",
"absolute_error": str(abs_error) if abs_error is not None else "",
"relative_error": str(rel_error) if rel_error is not None else "",
}

if is_correct:
correct += 1
total += 1
return correct / total
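
A quick sketch of what the iterative comparison above accepts (the operands are illustrative, not taken from any test suite): nested lists and tuples are compared element-wise with the same tolerances, scalars fall back to equality, and any mismatch, including a length mismatch, makes allclose return False.

import torch
from BackendBench.eval import allclose

a = [torch.ones(3), (torch.zeros(2), 5)]
b = [torch.ones(3) + 1e-3, (torch.zeros(2), 5)]
assert allclose(a, b)                    # differences within atol=1e-2 / rtol=1e-2
assert not allclose(a, [torch.ones(3)])  # length mismatch -> False
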
@@ -77,34 +122,80 @@ def cpu_bench(fn, num_runs=100):
return (time.perf_counter() - start) / num_runs


def eval_performance(op, impl, tests):
def eval_performance(op, impl, tests, test_data: defaultdict = defaultdict(dict)):
"""Evaluate performance of impl against tests."""
bench_fn = (
triton.testing.do_bench if TRITON_AVAILABLE and torch.cuda.is_available() else cpu_bench
)
base_times = []
test_times = []
args_strs = []

for test in tests:
logging.debug(
f"Benchmarking {op.__name__} with args {serialize_args(test.args, test.kwargs)}"
)
base_times.append(bench_fn(lambda: op(*test.args, **test.kwargs)))
args_str = serialize_args(test.args, test.kwargs)
args_strs.append(args_str)
logging.debug(f"Benchmarking {op.__name__} with args {args_str}")
base_time = bench_fn(lambda: op(*test.args, **test.kwargs))
base_times.append(base_time)
test_time = base_time
try:
allclose(op(*test.args, **test.kwargs), impl(*test.args, **test.kwargs))
ref = op(*test.args, **test.kwargs)
res = impl(*test.args, **test.kwargs)
if not allclose(
ref,
res,
):
raise ValueError(f"Reference and result tensors are not close: {ref} vs {res}")
test_time = bench_fn(lambda: impl(*test.args, **test.kwargs))
except Exception:
test_times.append(base_times[-1])
continue
test_times.append(bench_fn(lambda: impl(*test.args, **test.kwargs)))
pass
finally:
test_times.append(test_time)
test_data[args_str]["benchmark_time"] = str(test_time)

speedups = torch.tensor(base_times) / torch.tensor(test_times)

# Update test_data with speedups from the tensor
for i, args_str in enumerate(args_strs):
test_data[args_str]["speedup"] = str(speedups[i].item())

return speedups.log().mean().exp()


def eval_one_op(op, impl, correctness_tests, performance_tests):
"""Evaluate impl of op against correctness_tests and performance_tests."""
# TODO: We should have proper error reporting instead of just saying this is 0,
# but that should be a separate PR.
"""Evaluate impl of op against correctness_tests and performance_tests.

Returns:
Tuple of (correctness_score, performance_score, test_data)
"""
test_data = defaultdict(dict)

if uses_cuda_stream(impl):
logger.warning(f"Skipping {op.__name__} because it uses CUDA stream")
return 0.0, 1.0
return eval_correctness(op, impl, correctness_tests), eval_performance(
op, impl, performance_tests
)
for test in correctness_tests + performance_tests:
args_str = serialize_args(test.args, test.kwargs)
test_data[args_str] = {
"correctness_score": 0,
"benchmark_time": "",
"speedup": "",
"correctness_errors": "Skipped: uses CUDA stream",
"absolute_error": "",
"relative_error": "",
}
return 0, 1.0, test_data

correctness_score = eval_correctness(op, impl, correctness_tests, test_data)
performance_score = eval_performance(op, impl, performance_tests, test_data)
test_data = dict(test_data)
return correctness_score, performance_score, test_data


def save_verbose_results(
results: List[Dict[str, Any]],
output_path: str = "backendbench_verbose_results.json",
):
"""Save verbose results to a JSON file."""
with open(Path(output_path), "w") as f:
json.dump(results, f, indent=2)

logger.info(f"Verbose results saved to {output_path}")
12 changes: 11 additions & 1 deletion BackendBench/multiprocessing_eval.py
@@ -51,6 +51,7 @@ class EvalResult:
task_id: int
correctness_score: float
performance_score: float
test_data: Optional[dict] = None
error: Optional[str] = None


@@ -99,13 +100,14 @@ def _worker_process(worker_id, task_queue, result_queue):
if isinstance(impl, str):
impl = get_operator(impl)

correctness_score, performance_score = eval_one_op(
correctness_score, performance_score, test_data = eval_one_op(
op, impl, task.correctness_tests, task.performance_tests
)
result = EvalResult(
task_id=task.task_id,
correctness_score=correctness_score,
performance_score=performance_score,
test_data=test_data,
)
except Exception as e:
error_msg = f"Error in eval_one_op: {str(e)}\n{traceback.format_exc()}"
@@ -121,6 +123,14 @@ def _worker_process(worker_id, task_queue, result_queue):
task_id=task.task_id,
correctness_score=0.0,
performance_score=1.0,
test_data={
"correctness_score": 0.0,
"benchmark_time": "",
"speedup": "",
"correctness_errors": f"{error_msg}",
"absolute_error": "",
"relative_error": "",
},
error=error_msg,
)

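A sketch of how a consumer of EvalResult might flatten test_data into verbose rows, mirroring what scripts/main.py does below; note that on the worker's error path test_data is a single flat record rather than a per-args mapping (the function and variable names here are assumptions):

def flatten_results(results, task_to_op_name):
    # results: iterable of EvalResult; task_to_op_name: {task_id: op_name}.
    rows = []
    for result in results:
        if not result.test_data:
            continue
        op_name = task_to_op_name.get(result.task_id, "unknown")
        if "correctness_score" in result.test_data:
            # Error path: one flat record describing the whole op.
            rows.append({"op_name": op_name, "args": "", **result.test_data})
        else:
            # Success path: one record per serialized-args string.
            for args_str, data in result.test_data.items():
                rows.append({"op_name": op_name, "args": args_str, **data})
    return rows
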
39 changes: 35 additions & 4 deletions BackendBench/scripts/main.py
@@ -105,11 +105,17 @@ def setup_logging(log_level):
type=str,
help="Path to directory containing generated kernels",
)
@click.option(
"--output-path",
default=None,
type=str,
help="Path for JSON output file with detailed results (if not specified, no JSON output)",
)
@click.option(
"--num-workers",
default=None,
type=int,
help="Number of workers to use for multiprocessing, default to None to disable multiprocessing)",
help="Number of workers to use for multiprocessing, default to None to disable multiprocessing",
)
def cli(
log_level,
@@ -123,6 +129,7 @@ def cli(
kernel_agent_max_rounds,
torchbench_data_path,
ops_directory,
output_path,
num_workers,
):
setup_logging(log_level)
@@ -184,6 +191,7 @@ def cli(

overall_correctness = []
overall_performance = []
verbose_results = []

if num_workers is None:
for test in suite:
@@ -192,7 +200,7 @@

logger.debug(test.op)

correctness, perf = eval.eval_one_op(
correctness, perf, op_test_data = eval.eval_one_op(
test.op,
backend[test.op],
test.correctness_tests,
@@ -201,19 +209,29 @@
overall_correctness.append(correctness)
overall_performance.append(perf)

# Convert dict to list entries with op_name
op_name = getattr(test.op, "__name__", str(test.op))
for args_str, data in op_test_data.items():
entry = {"op_name": op_name, "args": args_str}
entry.update(data)
verbose_results.append(entry)

logger.debug(f"max memory allocated: {torch.cuda.max_memory_allocated():,}")
else:
with multiprocessing_eval.MultiprocessingEvaluator(num_workers) as evaluator:
# Submit all tasks
# Submit all tasks and track op names
task_to_op_name = {}
for test in suite:
if test.op not in backend:
continue

logger.debug(test.op)

evaluator.submit_task(
task_id = evaluator.submit_task(
test.op, backend[test.op], test.correctness_tests, test.performance_tests
)
op_name = getattr(test.op, "__name__", str(test.op))
task_to_op_name[task_id] = op_name

# Start evaluation
evaluator.start_evaluation()
@@ -227,11 +245,24 @@
overall_correctness.append(correctness_score)
overall_performance.append(performance_score)

# Handle verbose data if present
if result.test_data and result.task_id in task_to_op_name:
op_name = task_to_op_name[result.task_id]
for args_str, data in result.test_data.items():
entry = {"op_name": op_name, "args": args_str}
entry.update(data)
verbose_results.append(entry)

mean_correctness = torch.tensor(overall_correctness).mean().item()
geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
print(f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}")
print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")

# Save verbose results if output path is specified
if output_path and verbose_results:
eval.save_verbose_results(verbose_results, output_path)
print(f"Detailed results saved to: {output_path}")


def setup_llm_backend(llm_backend, llm_client, suite, max_attempts=5):
"""Setup LLM backend by generating kernels for all operations in the suite."""
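Finally, as an assumed end-to-end usage, passing --output-path makes the CLI write one JSON row per (op, args) pair via save_verbose_results; the row below only illustrates the schema (the args string and all values are made up):

# e.g. adding --output-path results.json to an otherwise normal invocation of
# BackendBench/scripts/main.py produces entries shaped like this:
example_row = {
    "op_name": "add",
    "args": "((T([16, 16], f32), T([16, 16], f32)), {})",  # serialize_args output; format assumed
    "correctness_score": 1,
    "correctness_errors": "",
    "absolute_error": "0.0",
    "relative_error": "0.0",
    "benchmark_time": "0.0123",  # units depend on the benchmark function used
    "speedup": "1.05",
}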