diff --git a/BackendBench/eval.py b/BackendBench/eval.py
index 2693cca..e0c3fac 100644
--- a/BackendBench/eval.py
+++ b/BackendBench/eval.py
@@ -4,14 +4,21 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+from collections import defaultdict
+import json
 import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 
 try:
-    import triton.testing
+    if torch.cuda.is_available():
+        import triton.testing
 
-    TRITON_AVAILABLE = True
+        TRITON_AVAILABLE = True
+    else:
+        TRITON_AVAILABLE = False
 except ImportError:
     TRITON_AVAILABLE = False
 
@@ -42,23 +49,87 @@ def allclose(a, b):
     return a == b
 
 
-def eval_correctness_test(op, impl, test):
-    """Evaluate impl of op against test."""
+def compute_errors(ref, res) -> Tuple[Optional[float], Optional[float]]:
+    """Compute absolute and relative errors between reference and result tensors.
+
+    Returns:
+        Tuple of (absolute_error, relative_error), or (None, None) if the inputs are not tensors or lists/tuples of tensors
+    """
+    if isinstance(ref, torch.Tensor) and isinstance(res, torch.Tensor):
+        if ref.shape != res.shape:
+            return None, None
+
+        # Convert to float for error calculation
+        ref_float = ref.float()
+        res_float = res.float()
+
+        # Absolute error
+        abs_error = (ref_float - res_float).abs().mean().item()
+
+        # Relative error (avoid division by zero)
+        ref_abs = ref_float.abs()
+        rel_error = ((ref_float - res_float).abs() / (ref_abs + 1e-10)).mean().item()
+
+        return abs_error, rel_error
+    elif isinstance(ref, (list, tuple)) and isinstance(res, (list, tuple)):
+        if len(ref) != len(res):
+            return None, None
+
+        # For lists/tuples, compute the errors element-wise and
+        # return the mean of those per-element means
+        mean_abs_error = 0.0
+        mean_rel_error = 0.0
+
+        for r, s in zip(ref, res):
+            abs_err, rel_err = compute_errors(r, s)
+            mean_abs_error += abs_err
+            mean_rel_error += rel_err
+
+        return mean_abs_error / len(ref), mean_rel_error / len(ref)
+    else:
+        return None, None
+
+
+def eval_correctness_test(
+    op, impl, test
+) -> Tuple[bool, Optional[str], Optional[float], Optional[float]]:
+    """Evaluate impl of op against test.
+
+    Returns:
+        Tuple of (is_correct, error_message, absolute_error, relative_error)
+    """
     args, kwargs = test.args, test.kwargs
     ref = op(*args, **kwargs)
     try:
         res = impl(*args, **kwargs)
-        return allclose(ref, res)
+        is_correct = allclose(ref, res)
+
+        # Compute errors even if test passes (for verbose mode)
+        abs_error, rel_error = compute_errors(ref, res)
+
+        return is_correct, None, abs_error, rel_error
     except Exception as e:
-        logger.warning(format_exception(e, op, args, kwargs))
-        return False
+        error_msg = format_exception(e, op, args, kwargs)
+        logger.warning(error_msg)
+        return False, str(e), None, None
 
 
-def eval_correctness(op, impl, tests):
+def eval_correctness(op, impl, tests, verbose_data: defaultdict):
+    """Evaluate correctness of impl against tests."""
     correct, total = 0, 0
     for test in tests:
-        logging.debug(f"Testing {op.__name__} with args {serialize_args(test.args, test.kwargs)}")
-        if eval_correctness_test(op, impl, test):
+        args_str = serialize_args(test.args, test.kwargs)
+        logging.debug(f"Testing {op.__name__} with args {args_str}")
+        is_correct, error_msg, abs_error, rel_error = eval_correctness_test(op, impl, test)
+
+        verbose_data[args_str] = {
+            "correctness_score": 1 if is_correct else 0,
+            "correctness_errors": error_msg or "",
+            "absolute_error": str(abs_error) if abs_error is not None else "",
+            "relative_error": str(rel_error) if rel_error is not None else "",
+        }
+
+        if is_correct:
             correct += 1
         total += 1
     return correct / total
@@ -77,34 +148,71 @@ def cpu_bench(fn, num_runs=100):
     return (time.perf_counter() - start) / num_runs
 
 
-def eval_performance(op, impl, tests):
+def eval_performance(op, impl, tests, verbose_data: defaultdict):
+    """Evaluate performance of impl against tests."""
     bench_fn = (
         triton.testing.do_bench if TRITON_AVAILABLE and torch.cuda.is_available() else cpu_bench
     )
     base_times = []
     test_times = []
+
     for test in tests:
-        logging.debug(
-            f"Benchmarking {op.__name__} with args {serialize_args(test.args, test.kwargs)}"
-        )
-        base_times.append(bench_fn(lambda: op(*test.args, **test.kwargs)))
+        args_str = serialize_args(test.args, test.kwargs)
+        logging.debug(f"Benchmarking {op.__name__} with args {args_str}")
+        base_time = bench_fn(lambda: op(*test.args, **test.kwargs))
+        base_times.append(base_time)
+
         try:
-            allclose(op(*test.args, **test.kwargs), impl(*test.args, **test.kwargs))
+            ref = op(*test.args, **test.kwargs)
+            res = impl(*test.args, **test.kwargs)
+            if not allclose(ref, res):
+                raise ValueError(f"Reference and result tensors are not close: {ref} vs {res}")
+            test_time = bench_fn(lambda: impl(*test.args, **test.kwargs))
         except Exception:
-            test_times.append(base_times[-1])
-            continue
-        test_times.append(bench_fn(lambda: impl(*test.args, **test.kwargs)))
+            test_time = -1
+
+        test_times.append(test_time)
+        verbose_data[args_str]["benchmark_time"] = str(test_time)
+        speedup = base_time / test_time if test_time > 0 else float("inf")
+        verbose_data[args_str]["speedup"] = str(speedup)
+
     speedups = torch.tensor(base_times) / torch.tensor(test_times)
     return speedups.log().mean().exp()
 
 
 def eval_one_op(op, impl, correctness_tests, performance_tests):
-    """Evaluate impl of op against correctness_tests and performance_tests."""
-    # TODO: We should have proper error reporting instead of just saying this is 0,
-    # but that should be a separate PR.
+    """Evaluate impl of op against correctness_tests and performance_tests.
+
+    Returns:
+        Tuple of (correctness_score, performance_score, verbose_data)
+    """
+    verbose_data = defaultdict(dict)
+
     if uses_cuda_stream(impl):
         logger.warning(f"Skipping {op.__name__} because it uses CUDA stream")
-        return 0, 0
-    return eval_correctness(op, impl, correctness_tests), eval_performance(
-        op, impl, performance_tests
-    )
+        for test in correctness_tests + performance_tests:
+            args_str = serialize_args(test.args, test.kwargs)
+            verbose_data[args_str] = {
+                "correctness_score": 0,
+                "benchmark_time": "",
+                "speedup": "",
+                "correctness_errors": "Skipped: uses CUDA stream",
+                "absolute_error": "",
+                "relative_error": "",
+            }
+        return 0, 0, verbose_data
+
+    correctness_score = eval_correctness(op, impl, correctness_tests, verbose_data)
+    performance_score = eval_performance(op, impl, performance_tests, verbose_data)
+    verbose_data = dict(verbose_data)
+    return correctness_score, performance_score, verbose_data
+
+
+def save_verbose_results(
+    results: List[Dict[str, Any]], output_path: str = "backendbench_verbose_results.json"
+):
+    """Save verbose results to a JSON file."""
+    with open(Path(output_path), "w") as f:
+        json.dump(results, f, indent=2)
+
+    logger.info(f"Verbose results saved to {output_path}")
diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py
index 209c643..cf425e9 100644
--- a/BackendBench/scripts/main.py
+++ b/BackendBench/scripts/main.py
@@ -106,6 +106,12 @@ def setup_logging(log_level):
     help="Path to directory containing generated kernels",
 )
 @click.option(
+    "--output-path",
+    default=None,
+    type=str,
+    help="Path for JSON output file with detailed results (if not specified, no JSON output)",
+)
+@click.option(
     "--num-workers",
     default=None,
     type=int,
@@ -123,6 +127,7 @@ def cli(
     kernel_agent_max_rounds,
     torchbench_data_path,
     ops_directory,
+    output_path,
     num_workers,
 ):
     setup_logging(log_level)
@@ -184,6 +189,7 @@ def cli(
 
     overall_correctness = []
     overall_performance = []
+    verbose_results = []
 
     if num_workers is None:
         for test in suite:
@@ -192,7 +198,7 @@ def cli(
 
             logger.debug(test.op)
 
-            correctness, perf = eval.eval_one_op(
+            correctness, perf, op_verbose_data = eval.eval_one_op(
                 test.op,
                 backend[test.op],
                 test.correctness_tests,
@@ -201,6 +207,13 @@ def cli(
             overall_correctness.append(correctness)
             overall_performance.append(perf)
 
+            # Convert dict to list entries with op_name
+            op_name = getattr(test.op, "__name__", str(test.op))
+            for args_str, data in op_verbose_data.items():
+                entry = {"op_name": op_name, "args": args_str}
+                entry.update(data)
+                verbose_results.append(entry)
+
             logger.debug(f"max memory allocated: {torch.cuda.max_memory_allocated():,}")
     else:
         with multiprocessing_eval.MultiprocessingEvaluator(num_workers) as evaluator:
@@ -232,6 +245,11 @@ def cli(
     print(f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}")
     print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")
 
+    # Save verbose results if output path is specified
+    if output_path and verbose_results:
+        eval.save_verbose_results(verbose_results, output_path)
+        print(f"Detailed results saved to: {output_path}")
+
 
 def setup_llm_backend(llm_backend, llm_client, suite, max_attempts=5):
     """Setup LLM backend by generating kernels for all operations in the suite."""
diff --git a/test/test_eval.py b/test/test_eval.py
index 78a58b8..f4088db 100644
--- a/test/test_eval.py
+++ b/test/test_eval.py
@@ -84,8 +84,8 @@ def __init__(self, args, kwargs):
 
         test = TestCase([torch.tensor([-1.0, 0.0, 1.0])], {})
 
-        result = eval_correctness_test(op, impl, test)
-        assert result is True
+        is_correct, error_msg, abs_error, rel_error = eval_correctness_test(op, impl, test)
+        assert is_correct is True
 
     def test_eval_correctness_test_fail(self):
         # Use different operations that produce different results
@@ -101,8 +101,8 @@ def __init__(self, args, kwargs):
 
         test = TestCase([torch.tensor([1.0, 2.0, 3.0])], {})
 
-        result = eval_correctness_test(op, impl, test)
-        assert result is False
+        is_correct, error_msg, abs_error, rel_error = eval_correctness_test(op, impl, test)
+        assert is_correct is False
 
     def test_eval_correctness_test_exception(self):
         op = torch.relu
@@ -118,8 +118,11 @@ def __init__(self, args, kwargs):
         test = TestCase([torch.tensor([1.0])], {})
 
         # Just test that it returns False on exception
-        result = eval_correctness_test(op, impl_with_error, test)
-        assert result is False
+        is_correct, error_msg, abs_error, rel_error = eval_correctness_test(
+            op, impl_with_error, test
+        )
+        assert is_correct is False
+        assert error_msg is not None  # Should have an error message
 
     def test_eval_correctness_multiple_tests(self):
         op = torch.abs
@@ -135,8 +138,10 @@ def __init__(self, args, kwargs):
             test = TestCase([torch.tensor([float(i) - 2.5])], {})
             tests.append(test)
 
-        score = eval_correctness(op, impl, tests)
+        verbose_data = {}
+        score = eval_correctness(op, impl, tests, verbose_data)
         assert score == 1.0
+        assert len(verbose_data) == len(tests)  # Should have data for each test
 
 
 class TestEvalPerformance:
@@ -180,9 +185,13 @@ def __init__(self, args, kwargs):
         correctness_tests = [TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) for _ in range(3)]
         performance_tests = [TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) for _ in range(2)]
 
-        correctness, performance = eval_one_op(op, impl, correctness_tests, performance_tests)
+        correctness, performance, verbose_data = eval_one_op(
+            op, impl, correctness_tests, performance_tests
+        )
 
         # Should have perfect correctness since using same implementation
         assert correctness == 1.0
         # Performance should be around 1.0 (same speed)
         assert performance.item() > 0
+        # Verbose data should be populated
+        assert len(verbose_data) > 0
diff --git a/test/test_facto_suite.py b/test/test_facto_suite.py
index 7aaa49d..c8a7948 100644
--- a/test/test_facto_suite.py
+++ b/test/test_facto_suite.py
@@ -53,7 +53,7 @@ def test_facto_suite_relu_default_correctness_not_empty(self):
                     assert value.numel() > 0, f"Tensor kwarg is empty for {test.op}"
 
             # Evaluate the operation
-            correctness, _ = eval_one_op(
+            correctness, _, _ = eval_one_op(
                 test.op,
                 backend[test.op],  # AtenBackend returns the original op
                 test.correctness_tests,
diff --git a/test/test_smoke.py b/test/test_smoke.py
index 20a1ad8..23000c3 100644
--- a/test/test_smoke.py
+++ b/test/test_smoke.py
@@ -33,7 +33,7 @@ def test_smoke_suite_aten_backend(self, aten_backend):
             if test.op not in aten_backend:
                 pytest.skip(f"Operation {test.op} not in backend")
 
-            correctness, perf = eval_one_op(
+            correctness, perf, _ = eval_one_op(
                 test.op,
                 aten_backend[test.op],
                 test.correctness_tests,
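
For reference, a minimal sketch of how the JSON written by `save_verbose_results` could be consumed downstream. It assumes only the record fields introduced in this diff (`op_name`, `args`, `correctness_score`, `correctness_errors`, `absolute_error`, `relative_error`, plus `benchmark_time` and `speedup` on performance entries); the filename is just an example, and all fields except `correctness_score` are serialized as strings.

```python
import json
from collections import defaultdict

# Load the verbose results written when --output-path is passed (example filename).
with open("backendbench_verbose_results.json") as f:
    records = json.load(f)

# Aggregate per-operator pass rates and collect any measured speedups.
passes, totals, speedups = defaultdict(int), defaultdict(int), defaultdict(list)
for rec in records:
    op = rec["op_name"]
    totals[op] += 1
    passes[op] += rec["correctness_score"]
    # speedup is stored as a string and only present on performance-test entries
    if rec.get("speedup") not in (None, "", "inf"):
        speedups[op].append(float(rec["speedup"]))

for op in sorted(totals):
    print(f"{op}: pass rate {passes[op] / totals[op]:.2f}, {len(speedups[op])} timed runs")
```

This keeps per-test details available for debugging alongside the aggregate correctness and geomean-speedup scores that `main.py` already prints.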