Commit f2b685a

Update how we calculate correctness score and performance score (#107)
1 parent d0ff8c6 commit f2b685a

File tree: 5 files changed, +97 / −10 lines changed


BackendBench/eval.py

Lines changed: 11 additions & 0 deletions
@@ -205,3 +205,14 @@ def save_verbose_results(
         json.dump(results, f, indent=2)

     logger.info(f"Verbose results saved to {output_path}")
+
+
+def perf_at_p(correctness, performance, p=1.0):
+    assert len(correctness) == len(performance), (
+        "correctness and performance must have the same length"
+    )
+    return (
+        torch.where(torch.tensor(correctness).bool(), torch.tensor(performance) > p, 0)
+        .float()
+        .mean()
+    )
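
As a quick illustration of what the new helper reports, here is a minimal usage sketch. The input values are made up, and the import path assumes perf_at_p is exposed from BackendBench/eval.py as added above:

import torch

from BackendBench.eval import perf_at_p

# Per-op results: did every correctness test pass, and what speedup was measured?
correctness = [True, True, False]
performance = [1.5, 0.8, 3.0]

# An op counts only if it is correct AND its speedup exceeds p; the mean is taken over all ops.
score = perf_at_p(correctness, performance, p=1.0)
print(score)  # tensor(0.3333): only the first op is both correct and faster than p=1.0

Raising p only makes the threshold harder to meet, which matches the help text added to the --p flag below.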

BackendBench/scripts/main.py

Lines changed: 29 additions & 3 deletions
@@ -120,6 +120,16 @@ def setup_logging(log_level):
     type=int,
     help="Number of workers to use for multiprocessing, default to None to disable multiprocessing",
 )
+@click.option(
+    "--p",
+    default=1.0,
+    type=float,
+    help=(
+        "Performance score threshold for perf@p score calculation"
+        "Note: Increasing this value makes the threshold more stringent, "
+        "requiring a higher speedup to meet the performance criteria."
+    ),
+)
 def cli(
     log_level,
     suite,
@@ -134,6 +144,7 @@ def cli(
     ops_directory,
     output_path,
     num_workers,
+    p,
 ):
     setup_logging(log_level)
     if ops:
@@ -209,7 +220,14 @@ def cli(
                 test.correctness_tests,
                 test.performance_tests,
             )
-            overall_correctness.append(correctness)
+
+            overall_correctness.append(
+                all(
+                    data["correctness_score"]
+                    for data in op_test_data.values()
+                    if "correctness_score" in data.keys()
+                )
+            )
             overall_performance.append(perf)

         # Convert dict to list entries with op_name
@@ -243,7 +261,11 @@ def cli(
         results = evaluator.get_results()

         for result in results:
-            correctness_score = result.correctness_score
+            correctness_score = all(
+                data["correctness_score"]
+                for data in result.test_data.values()
+                if "correctness_score" in data.keys()
+            )
             performance_score = result.performance_score
             overall_correctness.append(correctness_score)
             overall_performance.append(performance_score)
@@ -256,10 +278,14 @@ def cli(
             entry.update(data)
             verbose_results.append(entry)

-    mean_correctness = torch.tensor(overall_correctness).mean().item()
+    mean_correctness = torch.tensor(overall_correctness).float().mean().item()
     geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
+    perf_at_p_score = eval.perf_at_p(overall_correctness, overall_performance, p)
     print(f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}")
     print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")
+    print(
+        f"perf@p score (rate of correct samples with a speedup greater than p, p={p}): {perf_at_p_score:.2f}"
+    )

     # Save verbose results if output path is specified
     if output_path and verbose_results:
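
For context, here is a sketch of how the three printed scores relate on hand-picked per-op numbers; the lists below are illustrative, while the expressions mirror the ones in the hunk above:

import torch
from BackendBench import eval  # perf_at_p as defined in BackendBench/eval.py

overall_correctness = [True, True, False, True]  # per-op: all correctness tests passed?
overall_performance = [1.4, 0.9, 2.0, 1.1]       # per-op speedup over the reference

mean_correctness = torch.tensor(overall_correctness).float().mean().item()          # 0.75
geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()          # ~1.29
perf_at_p_score = eval.perf_at_p(overall_correctness, overall_performance, p=1.0)   # 0.50

Only the first and fourth ops are both correct and faster than p=1.0, so perf@p is 2/4 even though the mean pass rate is 3/4.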

test/test_eval.py

Lines changed: 45 additions & 0 deletions
@@ -6,6 +6,7 @@

 import pytest
 import torch
+import numpy as np

 try:
     import importlib.util
@@ -17,6 +18,7 @@
     eval_one_op,
     cpu_bench,
     gpu_bench,
+    perf_at_p,
 )

 HAS_TRITON = importlib.util.find_spec("triton") is not None
@@ -219,3 +221,46 @@ def __init__(self, args, kwargs):
     assert performance.item() > 0
     # Verbose data should be populated
     assert len(test_data) > 0
+
+
+def fastp_kernel_bench(
+    is_correct: np.ndarray, baseline_speed: np.ndarray, actual_speed: np.ndarray, n: int, p: float
+) -> float:
+    """
+    Original fastp implementation from kernelBench
+    """
+    filtered_baseline_speed = np.array([x for i, x in enumerate(baseline_speed) if is_correct[i]])
+    filtered_actual_speed = np.array([x for i, x in enumerate(actual_speed) if is_correct[i]])
+    speed_up = filtered_baseline_speed / filtered_actual_speed
+    fast_p_score = np.sum(speed_up > p)
+    return fast_p_score / n if n > 0 else 0
+
+
+class TestPerfAtP:
+    def get_results(self, num_tests=100):
+        overall_correctness = np.random.randint(0, 2, size=num_tests)
+        overall_performance = np.random.uniform(0.5, 2, size=num_tests)
+        return overall_correctness, overall_performance
+
+    def test_perf_at_p(self):
+        for num_tests in [5, 10, 50, 100]:
+            for p in [0, 1, 1.5, 2]:
+                overall_correctness, overall_performance = self.get_results(num_tests)
+
+                actual_speed = np.random.randint(1, 101, size=num_tests)
+                baseline_speed = actual_speed * overall_performance
+                fastp_score_orig = fastp_kernel_bench(
+                    overall_correctness, baseline_speed, actual_speed, num_tests, p
+                )
+
+                # Note: The perf@p score calculation here differs subtly from the original fastp score in
+                # kernel bench. The original fastp score filters correct samples first, then averages.
+                # Here, perf@p averages first, then filters correct samples. Despite this difference,
+                # both methods produce equivalent results, so the test remains valid.
+                perf_at_p_score = perf_at_p(
+                    overall_correctness.tolist(), overall_performance.tolist(), p
+                )
+
+                assert torch.allclose(
+                    perf_at_p_score, torch.tensor(fastp_score_orig, dtype=torch.float32)
+                )
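
To see why the two formulations compared in this test line up, here is a small worked example (numbers invented for illustration): masking incorrect samples to zero and averaging over all n gives the same result as counting correct-and-fast samples and dividing by n.

import numpy as np

is_correct = np.array([1, 0, 1, 1], dtype=bool)
speedup = np.array([1.5, 3.0, 0.7, 2.0])
p, n = 1.0, 4

# kernelBench style: keep only correct samples, count speedups above p, divide by n.
kernelbench_score = np.sum(speedup[is_correct] > p) / n        # 2 / 4 = 0.5

# perf@p style: zero out incorrect samples, then average over all n samples.
perf_at_p_score = np.where(is_correct, speedup > p, 0).mean()  # also 0.5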

test/test_facto_suite.py

Lines changed: 4 additions & 4 deletions
@@ -53,20 +53,20 @@ def test_facto_suite_relu_default_correctness_not_empty(self):
                 assert value.numel() > 0, f"Tensor kwarg is empty for {test.op}"

             # Evaluate the operation
-            correctness, _, _ = eval_one_op(
+            correctness, _, op_test_data = eval_one_op(
                 test.op,
                 backend[test.op],  # AtenBackend returns the original op
                 test.correctness_tests,
                 test.performance_tests,
             )
-            print(f"Correctness for {test.op}: {correctness}")
-            overall_correctness.append(correctness)
+            is_correct = all(data["correctness_score"] for data in op_test_data.values())
+            overall_correctness.append(is_correct)

             # Individual test assertions
             assert correctness > 0, f"Operation {test.op} failed all correctness tests"

         # Calculate mean correctness
-        mean_correctness = torch.tensor(overall_correctness).mean().item()
+        mean_correctness = torch.tensor(overall_correctness).float().mean().item()

         # Main assertion: correctness should be > 0.8
         assert mean_correctness > 0.8, (

test/test_smoke.py

Lines changed: 8 additions & 3 deletions
@@ -33,20 +33,25 @@ def test_smoke_suite_aten_backend(self, aten_backend):
             if test.op not in aten_backend:
                 pytest.skip(f"Operation {test.op} not in backend")

-            correctness, perf, _ = eval_one_op(
+            correctness, perf, op_test_data = eval_one_op(
                 test.op,
                 aten_backend[test.op],
                 test.correctness_tests,
                 test.performance_tests,
             )

-            overall_correctness.append(correctness)
+            is_correct = all(
+                data["correctness_score"]
+                for data in op_test_data.values()
+                if "correctness_score" in data.keys()
+            )
+            overall_correctness.append(is_correct)
             overall_performance.append(perf)

             assert correctness > 0, f"Operation {test.op} failed all correctness tests"
             assert perf > 0.1, f"Operation {test.op} is more than 10x slower than reference"

-        mean_correctness = torch.tensor(overall_correctness).mean().item()
+        mean_correctness = torch.tensor(overall_correctness).float().mean().item()
         geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()

         assert mean_correctness >= 0.8, (
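
The repeated all(...) reduction in these diffs assumes the third value returned by eval_one_op maps each test to a dict that may carry a "correctness_score" entry; a toy sketch of that reduction, with a made-up dictionary:

# Hypothetical shape of op_test_data for a single op (keys and values invented).
op_test_data = {
    "test_0": {"correctness_score": True, "absolute_error": 0.0},
    "test_1": {"correctness_score": True, "absolute_error": 1e-7},
    "test_2": {"speedup": 1.3},  # performance-only entry, skipped by the filter
}

is_correct = all(
    data["correctness_score"]
    for data in op_test_data.values()
    if "correctness_score" in data.keys()
)
# True only when every correctness test recorded for the op passed.

This is also why overall_correctness now holds booleans and needs the .float() cast before .mean().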
