diff --git a/BackendBench/eval.py b/BackendBench/eval.py
index fea83d37..d0a5cde7 100644
--- a/BackendBench/eval.py
+++ b/BackendBench/eval.py
@@ -205,3 +205,14 @@ def save_verbose_results(
         json.dump(results, f, indent=2)
 
     logger.info(f"Verbose results saved to {output_path}")
+
+
+def perf_at_p(correctness, performance, p=1.0):
+    assert len(correctness) == len(performance), (
+        "correctness and performance must have the same length"
+    )
+    return (
+        torch.where(torch.tensor(correctness).bool(), torch.tensor(performance) > p, 0)
+        .float()
+        .mean()
+    )
diff --git a/BackendBench/scripts/main.py b/BackendBench/scripts/main.py
index 957499df..7be5777e 100644
--- a/BackendBench/scripts/main.py
+++ b/BackendBench/scripts/main.py
@@ -120,6 +120,16 @@ def setup_logging(log_level):
     type=int,
     help="Number of workers to use for multiprocessing, default to None to disable multiprocessing",
 )
+@click.option(
+    "--p",
+    default=1.0,
+    type=float,
+    help=(
+        "Performance score threshold for perf@p score calculation. "
+        "Note: Increasing this value makes the threshold more stringent, "
+        "requiring a higher speedup to meet the performance criteria."
+    ),
+)
 def cli(
     log_level,
     suite,
@@ -134,6 +144,7 @@ def cli(
     ops_directory,
     output_path,
     num_workers,
+    p,
 ):
     setup_logging(log_level)
     if ops:
@@ -209,7 +220,14 @@ def cli(
                 test.correctness_tests,
                 test.performance_tests,
             )
-            overall_correctness.append(correctness)
+
+            overall_correctness.append(
+                all(
+                    data["correctness_score"]
+                    for data in op_test_data.values()
+                    if "correctness_score" in data.keys()
+                )
+            )
             overall_performance.append(perf)
 
             # Convert dict to list entries with op_name
@@ -243,7 +261,11 @@ def cli(
         results = evaluator.get_results()
 
         for result in results:
-            correctness_score = result.correctness_score
+            correctness_score = all(
+                data["correctness_score"]
+                for data in result.test_data.values()
+                if "correctness_score" in data.keys()
+            )
             performance_score = result.performance_score
             overall_correctness.append(correctness_score)
             overall_performance.append(performance_score)
@@ -256,10 +278,14 @@ def cli(
                 entry.update(data)
                 verbose_results.append(entry)
 
-    mean_correctness = torch.tensor(overall_correctness).mean().item()
+    mean_correctness = torch.tensor(overall_correctness).float().mean().item()
     geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
+    perf_at_p_score = eval.perf_at_p(overall_correctness, overall_performance, p)
     print(f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}")
     print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")
+    print(
+        f"perf@p score (rate of correct samples with a speedup greater than p, p={p}): {perf_at_p_score:.2f}"
+    )
 
     # Save verbose results if output path is specified
     if output_path and verbose_results:
diff --git a/test/test_eval.py b/test/test_eval.py
index 30a2f9c5..27b960dd 100644
--- a/test/test_eval.py
+++ b/test/test_eval.py
@@ -6,6 +6,7 @@
 
 import pytest
 import torch
+import numpy as np
 
 try:
     import importlib.util
@@ -17,6 +18,7 @@
     eval_one_op,
     cpu_bench,
     gpu_bench,
+    perf_at_p,
 )
 
 HAS_TRITON = importlib.util.find_spec("triton") is not None
@@ -219,3 +221,46 @@ def __init__(self, args, kwargs):
         assert performance.item() > 0
         # Verbose data should be populated
         assert len(test_data) > 0
+
+
+def fastp_kernel_bench(
+    is_correct: np.ndarray, baseline_speed: np.ndarray, actual_speed: np.ndarray, n: int, p: float
+) -> float:
+    """
+    Original fastp implementation from KernelBench.
+    """
+    filtered_baseline_speed = np.array([x for i, x in enumerate(baseline_speed) if is_correct[i]])
+    filtered_actual_speed = np.array([x for i, x in enumerate(actual_speed) if is_correct[i]])
+    speed_up = filtered_baseline_speed / filtered_actual_speed
+    fast_p_score = np.sum(speed_up > p)
+    return fast_p_score / n if n > 0 else 0
+
+
+class TestPerfAtP:
+    def get_results(self, num_tests=100):
+        overall_correctness = np.random.randint(0, 2, size=num_tests)
+        overall_performance = np.random.uniform(0.5, 2, size=num_tests)
+        return overall_correctness, overall_performance
+
+    def test_perf_at_p(self):
+        for num_tests in [5, 10, 50, 100]:
+            for p in [0, 1, 1.5, 2]:
+                overall_correctness, overall_performance = self.get_results(num_tests)
+
+                actual_speed = np.random.randint(1, 101, size=num_tests)
+                baseline_speed = actual_speed * overall_performance
+                fastp_score_orig = fastp_kernel_bench(
+                    overall_correctness, baseline_speed, actual_speed, num_tests, p
+                )
+
+                # Note: This perf@p calculation differs subtly from the original fastp score
+                # in KernelBench: fastp filters out incorrect samples and then counts speedups
+                # above p, while perf_at_p masks incorrect samples and averages over all samples.
+                # Both divide the correct-and-fast count by the total n, so the test remains valid.
+                perf_at_p_score = perf_at_p(
+                    overall_correctness.tolist(), overall_performance.tolist(), p
+                )
+
+                assert torch.allclose(
+                    perf_at_p_score, torch.tensor(fastp_score_orig, dtype=torch.float32)
+                )
diff --git a/test/test_facto_suite.py b/test/test_facto_suite.py
index 8ca2a347..fb5aee6f 100644
--- a/test/test_facto_suite.py
+++ b/test/test_facto_suite.py
@@ -53,20 +53,20 @@ def test_facto_suite_relu_default_correctness_not_empty(self):
                     assert value.numel() > 0, f"Tensor kwarg is empty for {test.op}"
 
             # Evaluate the operation
-            correctness, _, _ = eval_one_op(
+            correctness, _, op_test_data = eval_one_op(
                 test.op,
                 backend[test.op],  # AtenBackend returns the original op
                 test.correctness_tests,
                 test.performance_tests,
             )
 
-            print(f"Correctness for {test.op}: {correctness}")
-            overall_correctness.append(correctness)
+            is_correct = all(data["correctness_score"] for data in op_test_data.values())
+            overall_correctness.append(is_correct)
 
             # Individual test assertions
             assert correctness > 0, f"Operation {test.op} failed all correctness tests"
 
         # Calculate mean correctness
-        mean_correctness = torch.tensor(overall_correctness).mean().item()
+        mean_correctness = torch.tensor(overall_correctness).float().mean().item()
 
         # Main assertion: correctness should be > 0.8
         assert mean_correctness > 0.8, (
diff --git a/test/test_smoke.py b/test/test_smoke.py
index 23000c36..de4ab363 100644
--- a/test/test_smoke.py
+++ b/test/test_smoke.py
@@ -33,20 +33,25 @@ def test_smoke_suite_aten_backend(self, aten_backend):
             if test.op not in aten_backend:
                 pytest.skip(f"Operation {test.op} not in backend")
 
-            correctness, perf, _ = eval_one_op(
+            correctness, perf, op_test_data = eval_one_op(
                 test.op,
                 aten_backend[test.op],
                 test.correctness_tests,
                 test.performance_tests,
            )
 
-            overall_correctness.append(correctness)
+            is_correct = all(
+                data["correctness_score"]
+                for data in op_test_data.values()
+                if "correctness_score" in data.keys()
+            )
+            overall_correctness.append(is_correct)
             overall_performance.append(perf)
 
             assert correctness > 0, f"Operation {test.op} failed all correctness tests"
             assert perf > 0.1, f"Operation {test.op} is more than 10x slower than reference"
 
-        mean_correctness = torch.tensor(overall_correctness).mean().item()
+        mean_correctness = torch.tensor(overall_correctness).float().mean().item()
         geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
 
         assert mean_correctness >= 0.8, (
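
As a quick illustration of the new metric (a minimal sketch, not part of the patch above: it assumes BackendBench is importable and that perf_at_p is exposed from BackendBench/eval.py as added here, and the result values are made up), perf@p is the fraction of all operators that are both correct and achieve a speedup strictly greater than p:

from BackendBench.eval import perf_at_p

# Hypothetical per-operator results: correctness flags and speedups over the reference op.
correctness = [True, True, False, True]
performance = [1.8, 0.9, 3.0, 1.2]

# Two of the four ops are correct with a speedup > 1.0, so perf@1.0 should be 0.50.
score = perf_at_p(correctness, performance, p=1.0)  # returns a scalar tensor
print(f"perf@1.0 = {score.item():.2f}")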