Merged
198 changes: 166 additions & 32 deletions BackendBench/eval.py
@@ -4,14 +4,22 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch


try:
    if torch.cuda.is_available():
        import triton.testing

        TRITON_AVAILABLE = True
    else:
        TRITON_AVAILABLE = False
except ImportError:
    TRITON_AVAILABLE = False

@@ -31,34 +39,119 @@ def format_exception(e, op, args, kwargs):
    return EXC_MSG.format(op=op_name, args=serialize_args(args, kwargs), exc=e)


def _allclose(a, b):
    # due to sparse tensors we check by error checking
    if isinstance(a, torch.Tensor):
        torch.testing.assert_close(a, b, equal_nan=True, atol=1e-2, rtol=1e-2)
    elif isinstance(a, (list, tuple)):
        assert len(a) == len(b)
        for ele_a, ele_b in zip(a, b):
            _allclose(ele_a, ele_b)
    else:
        assert a == b


def allclose(a, b):
    try:
        _allclose(a, b)
        return True
    except Exception:
        return False
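# A quick usage sketch with illustrative values: tensors within the 1e-2 tolerances above
# compare as close, nested lists/tuples are compared element by element, and non-tensor
# leaves fall back to ==.
allclose(torch.tensor([1.000, 2.000]), torch.tensor([1.005, 2.005]))  # -> True
allclose([torch.ones(2), 3], [torch.ones(2), 3])  # -> True
allclose(torch.tensor([1.0]), torch.tensor([2.0]))  # -> False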


def compute_errors(ref, res, eps=1e-10) -> Tuple[Optional[float], Optional[float]]:
    """Compute absolute and relative errors between reference and result tensors.

    Returns:
        Tuple of (absolute_error, relative_error) or (None, None) if not tensors/list of tensors
    """
    if isinstance(ref, torch.Tensor) and isinstance(res, torch.Tensor):
        if ref.shape != res.shape:
            return None, None

        if ref.is_sparse and res.is_sparse:
            # todo: create note that we don't calculate errors for sparse tensors / results
            return None, None

        # Convert to float for error calculation
        ref_float = ref.float()
        res_float = res.float()

        # Absolute error
        abs_error = (ref_float - res_float).abs().mean().item()

        # Relative error (avoid division by zero)
        ref_abs = ref_float.abs()
        rel_error = ((ref_float - res_float).abs() / (ref_abs + eps)).mean().item()

        return abs_error, rel_error
    elif isinstance(ref, (list, tuple)) and isinstance(res, (list, tuple)):
        if len(ref) != len(res):
            return None, None

        # if we have no tensors just return None
        if not any(isinstance(x, torch.Tensor) for x in ref) or not any(
            isinstance(x, torch.Tensor) for x in res
        ):
            return None, None

        # For lists/tuples, compute the mean error across all elements
        # and return the mean of these means.
        mean_abs_error = 0.0
        mean_rel_error = 0.0

        for r, s in zip(ref, res):
            abs_err, rel_err = compute_errors(r, s)
            if abs_err is None or rel_err is None:
                continue
            mean_abs_error += abs_err
            mean_rel_error += rel_err

        return mean_abs_error / len(ref), mean_rel_error / len(ref)
    else:
        return None, None

def eval_correctness_test(
    op, impl, test
) -> Tuple[bool, Optional[str], Optional[float], Optional[float]]:
    """Evaluate impl of op against test.

    Returns:
        Tuple of (is_correct, error_message, absolute_error, relative_error)
    """
    args, kwargs = test.args, test.kwargs
    ref = op(*args, **kwargs)
    try:
        res = impl(*args, **kwargs)
        is_correct = allclose(ref, res)

        # Compute errors even if test passes (for verbose mode)
        abs_error, rel_error = compute_errors(ref, res)

        return is_correct, None, abs_error, rel_error
    except Exception as e:
        error_msg = format_exception(e, op, args, kwargs)
        logger.warning(error_msg)
        return False, str(e), None, None


def eval_correctness(op, impl, tests, verbose_data: defaultdict = defaultdict(dict)):
    """Evaluate correctness of impl against tests."""
    correct, total = 0, 0
    for test in tests:
        args_str = serialize_args(test.args, test.kwargs)
        logging.debug(f"Testing {op.__name__} with args {args_str}")
        is_correct, error_msg, abs_error, rel_error = eval_correctness_test(op, impl, test)

        verbose_data[args_str] = {
            "correctness_score": 1 if is_correct else 0,
            "correctness_errors": error_msg or "",
            "absolute_error": str(abs_error) if abs_error is not None else "",
            "relative_error": str(rel_error) if rel_error is not None else "",
        }

        if is_correct:
            correct += 1
        total += 1
    return correct / total
@@ -77,34 +170,75 @@ def cpu_bench(fn, num_runs=100):
    return (time.perf_counter() - start) / num_runs


def eval_performance(op, impl, tests, verbose_data: defaultdict = defaultdict(dict)):
    """Evaluate performance of impl against tests."""
    bench_fn = (
        triton.testing.do_bench if TRITON_AVAILABLE and torch.cuda.is_available() else cpu_bench
    )
    base_times = []
    test_times = []

    for test in tests:
        args_str = serialize_args(test.args, test.kwargs)
        logging.debug(f"Benchmarking {op.__name__} with args {args_str}")
        base_time = bench_fn(lambda: op(*test.args, **test.kwargs))
        base_times.append(base_time)
        test_time = base_time
        try:
            ref = op(*test.args, **test.kwargs)
            res = impl(*test.args, **test.kwargs)
            if not allclose(ref, res):
                raise ValueError(f"Reference and result tensors are not close: {ref} vs {res}")
            test_time = bench_fn(lambda: impl(*test.args, **test.kwargs))
        except Exception:
            pass
        finally:
            test_times.append(test_time)
            verbose_data[args_str]["benchmark_time"] = str(test_time)
            speedup = base_time / test_time if test_time > 0 else float("inf")
            verbose_data[args_str]["speedup"] = str(speedup)

    speedups = torch.tensor(base_times) / torch.tensor(test_times)
    return speedups.log().mean().exp()
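# Scoring sketch with illustrative values: the geometric mean makes symmetric speedups and
# slowdowns cancel, e.g. per-test speedups of 2.0x and 0.5x score 1.0x overall
# (an arithmetic mean would report 1.25x).
torch.tensor([2.0, 0.5]).log().mean().exp()  # -> tensor(1.)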


def eval_one_op(op, impl, correctness_tests, performance_tests):
    """Evaluate impl of op against correctness_tests and performance_tests.

    Returns:
        Tuple of (correctness_score, performance_score, verbose_data)
    """
    verbose_data = defaultdict(dict)

    if uses_cuda_stream(impl):
        logger.warning(f"Skipping {op.__name__} because it uses CUDA stream")
        for test in correctness_tests + performance_tests:
            args_str = serialize_args(test.args, test.kwargs)
            verbose_data[args_str] = {
                "correctness_score": 0,
                "benchmark_time": "",
                "speedup": "",
                "correctness_errors": "Skipped: uses CUDA stream",
                "absolute_error": "",
                "relative_error": "",
            }
        return 0, 1.0, verbose_data

    correctness_score = eval_correctness(op, impl, correctness_tests, verbose_data)
    performance_score = eval_performance(op, impl, performance_tests, verbose_data)
    verbose_data = dict(verbose_data)
    return correctness_score, performance_score, verbose_data


def save_verbose_results(
    results: List[Dict[str, Any]],
    output_path: str = "backendbench_verbose_results.json",
):
    """Save verbose results to a JSON file."""
    with open(Path(output_path), "w") as f:
        json.dump(results, f, indent=2)

    logger.info(f"Verbose results saved to {output_path}")
12 changes: 11 additions & 1 deletion BackendBench/multiprocessing_eval.py
@@ -51,6 +51,7 @@ class EvalResult:
    task_id: int
    correctness_score: float
    performance_score: float
    verbose_data: Optional[dict] = None
    error: Optional[str] = None


@@ -99,13 +100,14 @@ def _worker_process(worker_id, task_queue, result_queue):
            if isinstance(impl, str):
                impl = get_operator(impl)

            correctness_score, performance_score, verbose_data = eval_one_op(
                op, impl, task.correctness_tests, task.performance_tests
            )
            result = EvalResult(
                task_id=task.task_id,
                correctness_score=correctness_score,
                performance_score=performance_score,
                verbose_data=verbose_data,
            )
        except Exception as e:
            error_msg = f"Error in eval_one_op: {str(e)}\n{traceback.format_exc()}"
@@ -121,6 +123,14 @@ def _worker_process(worker_id, task_queue, result_queue):
                task_id=task.task_id,
                correctness_score=0.0,
                performance_score=1.0,
                verbose_data={
                    "correctness_score": 0.0,
                    "benchmark_time": "",
                    "speedup": "",
                    "correctness_errors": f"{error_msg}",
                    "absolute_error": "",
                    "relative_error": "",
                },
                error=error_msg,
            )

39 changes: 35 additions & 4 deletions BackendBench/scripts/main.py
@@ -105,11 +105,17 @@ def setup_logging(log_level):
    type=str,
    help="Path to directory containing generated kernels",
)
@click.option(
    "--output-path",
    default=None,
    type=str,
    help="Path for JSON output file with detailed results (if not specified, no JSON output)",
)
@click.option(
    "--num-workers",
    default=None,
    type=int,
    help="Number of workers to use for multiprocessing, default to None to disable multiprocessing",
)
def cli(
    log_level,
@@ -123,6 +129,7 @@ def cli(
    kernel_agent_max_rounds,
    torchbench_data_path,
    ops_directory,
    output_path,
    num_workers,
):
    setup_logging(log_level)
@@ -184,6 +191,7 @@ def cli(

    overall_correctness = []
    overall_performance = []
    verbose_results = []

    if num_workers is None:
        for test in suite:
@@ -192,7 +200,7 @@
            logger.debug(test.op)

            correctness, perf, op_verbose_data = eval.eval_one_op(
                test.op,
                backend[test.op],
                test.correctness_tests,
@@ -201,19 +209,29 @@
            overall_correctness.append(correctness)
            overall_performance.append(perf)

            # Convert dict to list entries with op_name
            op_name = getattr(test.op, "__name__", str(test.op))
            for args_str, data in op_verbose_data.items():
                entry = {"op_name": op_name, "args": args_str}
                entry.update(data)
                verbose_results.append(entry)

            logger.debug(f"max memory allocated: {torch.cuda.max_memory_allocated():,}")
    else:
        with multiprocessing_eval.MultiprocessingEvaluator(num_workers) as evaluator:
            # Submit all tasks and track op names
            task_to_op_name = {}
            for test in suite:
                if test.op not in backend:
                    continue

                logger.debug(test.op)

                task_id = evaluator.submit_task(
                    test.op, backend[test.op], test.correctness_tests, test.performance_tests
                )
                op_name = getattr(test.op, "__name__", str(test.op))
                task_to_op_name[task_id] = op_name

            # Start evaluation
            evaluator.start_evaluation()
@@ -227,11 +245,24 @@
                overall_correctness.append(correctness_score)
                overall_performance.append(performance_score)

                # Handle verbose data if present
                if result.verbose_data and result.task_id in task_to_op_name:
                    op_name = task_to_op_name[result.task_id]
                    for args_str, data in result.verbose_data.items():
                        entry = {"op_name": op_name, "args": args_str}
                        entry.update(data)
                        verbose_results.append(entry)

    mean_correctness = torch.tensor(overall_correctness).mean().item()
    geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
    print(f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}")
    print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")

    # Save verbose results if output path is specified
    if output_path and verbose_results:
        eval.save_verbose_results(verbose_results, output_path)
        print(f"Detailed results saved to: {output_path}")


def setup_llm_backend(llm_backend, llm_client, suite, max_attempts=5):
"""Setup LLM backend by generating kernels for all operations in the suite."""
2 changes: 1 addition & 1 deletion test/test_backend_evaluation.py
@@ -180,7 +180,7 @@ def test_4_eval_integration(self):
        impl = backend[test_op]
        test = Test(lambda: torch.tensor([1, 2, 3]), lambda: torch.tensor([2, 3, 4]))

        correctness, performance, _ = eval_one_op(test_op, impl, [test], [test])

        print(f" Operation: {test_op}")
        print(f" Correctness: {correctness}")