Add verbose mode #92

Draft: wants to merge 5 commits into base: main
Changes from all commits
160 changes: 134 additions & 26 deletions BackendBench/eval.py
@@ -4,14 +4,21 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch

try:
    if torch.cuda.is_available():
        import triton.testing

        TRITON_AVAILABLE = True
    else:
        TRITON_AVAILABLE = False
except ImportError:
    TRITON_AVAILABLE = False

@@ -44,23 +51,87 @@ def allclose(a, b):
    return a == b


def compute_errors(ref, res) -> Tuple[Optional[float], Optional[float]]:
    """Compute absolute and relative errors between reference and result tensors.

    Returns:
        Tuple of (absolute_error, relative_error), or (None, None) if the inputs
        are not tensors or lists/tuples of tensors.
    """
    if isinstance(ref, torch.Tensor) and isinstance(res, torch.Tensor):
        if ref.shape != res.shape:
            return None, None

        # Convert to float for error calculation
        ref_float = ref.float()
        res_float = res.float()

        # Absolute error
        abs_error = (ref_float - res_float).abs().mean().item()

        # Relative error (avoid division by zero)
        ref_abs = ref_float.abs()
        rel_error = ((ref_float - res_float).abs() / (ref_abs + 1e-10)).mean().item()

        return abs_error, rel_error
    elif isinstance(ref, (list, tuple)) and isinstance(res, (list, tuple)):
        if len(ref) != len(res):
            return None, None

        # For lists/tuples, compute the error of each element pair and
        # return the mean of those per-element means.
        mean_abs_error = 0.0
        mean_rel_error = 0.0

        for r, s in zip(ref, res):
            abs_err, rel_err = compute_errors(r, s)
            # Guard against non-tensor elements, which yield (None, None)
            if abs_err is None or rel_err is None:
                return None, None
            mean_abs_error += abs_err
            mean_rel_error += rel_err

        return mean_abs_error / len(ref), mean_rel_error / len(ref)
    else:
        return None, None
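# Illustrative example (values chosen here, not taken from the PR): for
# ref = torch.tensor([1.0, 2.0]) and res = torch.tensor([1.1, 2.2]),
# compute_errors returns approximately (0.15, 0.10): the mean absolute
# difference is (0.1 + 0.2) / 2 and the mean relative difference is
# (0.1 / 1.0 + 0.2 / 2.0) / 2, up to the 1e-10 epsilon.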


def eval_correctness_test(
    op, impl, test
) -> Tuple[bool, Optional[str], Optional[float], Optional[float]]:
    """Evaluate impl of op against test.

    Returns:
        Tuple of (is_correct, error_message, absolute_error, relative_error)
    """
    args, kwargs = test.args, test.kwargs
    ref = op(*args, **kwargs)
    try:
        res = impl(*args, **kwargs)
        is_correct = allclose(ref, res)

        # Compute errors even if the test passes (for verbose mode)
        abs_error, rel_error = compute_errors(ref, res)

        return is_correct, None, abs_error, rel_error
    except Exception as e:
        error_msg = format_exception(e, op, args, kwargs)
        logger.warning(error_msg)
        return False, str(e), None, None


def eval_correctness(op, impl, tests, verbose_data: defaultdict):
    """Evaluate correctness of impl against tests."""
    correct, total = 0, 0
    for test in tests:
        args_str = serialize_args(test.args, test.kwargs)
        logging.debug(f"Testing {op.__name__} with args {args_str}")
        is_correct, error_msg, abs_error, rel_error = eval_correctness_test(op, impl, test)

        verbose_data[args_str] = {
            "correctness_score": 1 if is_correct else 0,
            "correctness_errors": error_msg or "",
            "absolute_error": str(abs_error) if abs_error is not None else "",
            "relative_error": str(rel_error) if rel_error is not None else "",
        }

        if is_correct:
            correct += 1
        total += 1
    return correct / total
@@ -79,34 +150,71 @@ def cpu_bench(fn, num_runs=100):
    return (time.perf_counter() - start) / num_runs


def eval_performance(op, impl, tests, verbose_data: defaultdict):
    """Evaluate performance of impl against tests."""
    bench_fn = (
        triton.testing.do_bench if TRITON_AVAILABLE and torch.cuda.is_available() else cpu_bench
    )
    base_times = []
    test_times = []

    for test in tests:
        args_str = serialize_args(test.args, test.kwargs)
        logging.debug(f"Benchmarking {op.__name__} with args {args_str}")
        base_time = bench_fn(lambda: op(*test.args, **test.kwargs))
        base_times.append(base_time)

        try:
            ref = op(*test.args, **test.kwargs)
            res = impl(*test.args, **test.kwargs)
            if not allclose(ref, res):
                raise ValueError(f"Reference and result tensors are not close: {ref} vs {res}")
            test_time = bench_fn(lambda: impl(*test.args, **test.kwargs))
        except Exception:
            test_time = -1

        test_times.append(test_time)
        verbose_data[args_str]["benchmark_time"] = str(test_time)
        speedup = base_time / test_time if test_time > 0 else float("inf")
        verbose_data[args_str]["speedup"] = str(speedup)

    speedups = torch.tensor(base_times) / torch.tensor(test_times)
    return speedups.log().mean().exp()
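# Illustrative note: the geometric mean keeps speedups and slowdowns symmetric.
# For example, per-test speedups of [2.0, 0.5] give
# exp(mean(log([2.0, 0.5]))) = 1.0, whereas an arithmetic mean would report 1.25.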


def eval_one_op(op, impl, correctness_tests, performance_tests):
"""Evaluate impl of op against correctness_tests and performance_tests."""
# TODO: We should have proper error reporting instead of just saying this is 0,
# but that should be a separate PR.
"""Evaluate impl of op against correctness_tests and performance_tests.

Returns:
Tuple of (correctness_score, performance_score, verbose_data)
"""
verbose_data = defaultdict(dict)

if uses_cuda_stream(impl):
logger.warning(f"Skipping {op.__name__} because it uses CUDA stream")
return 0, 0
return eval_correctness(op, impl, correctness_tests), eval_performance(
op, impl, performance_tests
)
for test in correctness_tests + performance_tests:
args_str = serialize_args(test.args, test.kwargs)
verbose_data[args_str] = {
"correctness_score": 0,
"benchmark_time": "",
"speedup": "",
"correctness_errors": "Skipped: uses CUDA stream",
"absolute_error": "",
"relative_error": "",
}
return 0, 0, verbose_data

correctness_score = eval_correctness(op, impl, correctness_tests, verbose_data)
performance_score = eval_performance(op, impl, performance_tests, verbose_data)
verbose_data = dict(verbose_data)
return correctness_score, performance_score, verbose_data


def save_verbose_results(
    results: List[Dict[str, Any]], output_path: str = "backendbench_verbose_results.json"
):
    """Save verbose results to a JSON file."""
    with open(Path(output_path), "w") as f:
        json.dump(results, f, indent=2)

    logger.info(f"Verbose results saved to {output_path}")
22 changes: 21 additions & 1 deletion BackendBench/scripts/main.py
@@ -103,6 +103,12 @@ def setup_logging(log_level):
    type=str,
    help="Path to directory containing generated kernels",
)
@click.option(
    "--output-path",
    default=None,
    type=str,
    help="Path for JSON output file with detailed results (if not specified, no JSON output)",
)
def cli(
    log_level,
    suite,
@@ -115,6 +121,7 @@ def cli(
    kernel_agent_max_rounds,
    torchbench_data_path,
    ops_directory,
    output_path,
):
    setup_logging(log_level)
    if ops:
@@ -175,14 +182,15 @@ def cli(

    overall_correctness = []
    overall_performance = []
    verbose_results = []

    for test in suite:
        if test.op not in backend:
            continue

        logger.debug(test.op)

        correctness, perf, op_verbose_data = eval.eval_one_op(
            test.op,
            backend[test.op],
            test.correctness_tests,
@@ -191,13 +199,25 @@
        overall_correctness.append(correctness)
        overall_performance.append(perf)

        # Convert dict to list entries with op_name
        op_name = getattr(test.op, "__name__", str(test.op))
        for args_str, data in op_verbose_data.items():
            entry = {"op_name": op_name, "args": args_str}
            entry.update(data)
            verbose_results.append(entry)

    logger.debug(f"max memory allocated: {torch.cuda.max_memory_allocated():,}")

    mean_correctness = torch.tensor(overall_correctness).mean().item()
    geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
    print(f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}")
    print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")

    # Save verbose results if output path is specified
    if output_path and verbose_results:
        eval.save_verbose_results(verbose_results, output_path)
        print(f"Detailed results saved to: {output_path}")


def setup_llm_backend(llm_backend, llm_client, suite, max_attempts=5):
"""Setup LLM backend by generating kernels for all operations in the suite."""
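With this change, pointing the CLI at a JSON file enables the detailed per-test report. A hypothetical invocation (only --output-path comes from this diff; the entry point and any other flags you pass are assumed from the surrounding repo) might look like:

    python BackendBench/scripts/main.py --output-path verbose_results.json

The aggregate correctness and performance scores still print to stdout as before, while the per-test entries land in verbose_results.json.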
4 changes: 2 additions & 2 deletions test/test_adverse_cases.py
@@ -31,7 +31,7 @@ def test_adaptive_avg_pool2d_backward_gpu(self):
        # run test that should brick the gpu due to an illegal memory access
        backend = backends.AtenBackend()
        with pytest.raises(RuntimeError):
            _, _, _ = eval_one_op(
                op_test_should_error.op,
                backend[op_test_should_error.op],
                list(op_test_should_error.correctness_tests),
@@ -43,7 +43,7 @@ def test_adaptive_avg_pool2d_backward_gpu(self):
        torch.cuda.empty_cache()

        # tests that a simple op works afterwards to make sure we recover after an illegal memory access
        correctness, _, _ = eval_one_op(
            op_test_should_succeed.op,
            backend[op_test_should_succeed.op],
            list(op_test_should_succeed.correctness_tests),
25 changes: 17 additions & 8 deletions test/test_eval.py
@@ -84,8 +84,8 @@ def __init__(self, args, kwargs):

        test = TestCase([torch.tensor([-1.0, 0.0, 1.0])], {})

        is_correct, error_msg, abs_error, rel_error = eval_correctness_test(op, impl, test)
        assert is_correct is True

    def test_eval_correctness_test_fail(self):
        # Use different operations that produce different results
@@ -101,8 +101,8 @@ def __init__(self, args, kwargs):

        test = TestCase([torch.tensor([1.0, 2.0, 3.0])], {})

        is_correct, error_msg, abs_error, rel_error = eval_correctness_test(op, impl, test)
        assert is_correct is False

    def test_eval_correctness_test_exception(self):
        op = torch.relu
@@ -118,8 +118,11 @@ def __init__(self, args, kwargs):
        test = TestCase([torch.tensor([1.0])], {})

        # Just test that it returns False on exception
        is_correct, error_msg, abs_error, rel_error = eval_correctness_test(
            op, impl_with_error, test
        )
        assert is_correct is False
        assert error_msg is not None  # Should have an error message

    def test_eval_correctness_multiple_tests(self):
        op = torch.abs
@@ -135,8 +138,10 @@ def __init__(self, args, kwargs):
            test = TestCase([torch.tensor([float(i) - 2.5])], {})
            tests.append(test)

        verbose_data = {}
        score = eval_correctness(op, impl, tests, verbose_data)
        assert score == 1.0
        assert len(verbose_data) == len(tests)  # Should have data for each test


class TestEvalPerformance:
@@ -180,9 +185,13 @@ def __init__(self, args, kwargs):
        correctness_tests = [TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) for _ in range(3)]
        performance_tests = [TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) for _ in range(2)]

        correctness, performance, verbose_data = eval_one_op(
            op, impl, correctness_tests, performance_tests
        )

        # Should have perfect correctness since using same implementation
        assert correctness == 1.0
        # Performance should be around 1.0 (same speed)
        assert performance.item() > 0
        # Verbose data should be populated
        assert len(verbose_data) > 0
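Along the same lines, the new compute_errors helper could get a small unit test of its own. A sketch (not part of this PR, assuming compute_errors is importable from BackendBench.eval like the functions above) might be:

    import pytest
    import torch

    from BackendBench.eval import compute_errors


    class TestComputeErrors:
        def test_compute_errors_tensors(self):
            # Mean absolute diff is (0.1 + 0.2) / 2; mean relative diff is
            # (0.1 / 1.0 + 0.2 / 2.0) / 2.
            ref = torch.tensor([1.0, 2.0])
            res = torch.tensor([1.1, 2.2])
            abs_err, rel_err = compute_errors(ref, res)
            assert abs_err == pytest.approx(0.15, abs=1e-3)
            assert rel_err == pytest.approx(0.10, abs=1e-3)

        def test_compute_errors_shape_mismatch(self):
            # Mismatched shapes should yield (None, None) rather than raising.
            abs_err, rel_err = compute_errors(torch.zeros(2), torch.zeros(3))
            assert abs_err is None and rel_err is None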
2 changes: 1 addition & 1 deletion test/test_facto_suite.py
@@ -54,7 +54,7 @@ def test_facto_suite_relu_default_correctness_not_empty(self):
assert value.numel() > 0, f"Tensor kwarg is empty for {test.op}"

            # Evaluate the operation
            correctness, _, _ = eval_one_op(
                test.op,
                backend[test.op],  # AtenBackend returns the original op
                test.correctness_tests,
2 changes: 1 addition & 1 deletion test/test_smoke.py
@@ -33,7 +33,7 @@ def test_smoke_suite_aten_backend(self, aten_backend):
            if test.op not in aten_backend:
                pytest.skip(f"Operation {test.op} not in backend")

            correctness, perf, _ = eval_one_op(
                test.op,
                aten_backend[test.op],
                test.correctness_tests,
Expand Down