18 changes: 18 additions & 0 deletions BackendBench/score.py
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import torch


def fastp(correctness, performance, p=0.8):
    assert len(correctness) == len(performance), (
        "correctness and performance must have the same length"
    )
    return (
        torch.where(torch.tensor(correctness).bool(), torch.tensor(performance) > p, 0)
        .float()
        .mean()
    )
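For readers skimming the diff, here is a small usage sketch of the new `fastp` helper (an editorial illustration; the inputs below are made-up per-op results, not benchmark output). With the default `p=0.8`, an op contributes to the score only if it passed its correctness tests and its speedup exceeds `p`.

```python
from BackendBench.score import fastp

# Hypothetical per-op results: correctness flags and geomean speedups.
correctness = [True, True, False, True]
performance = [1.30, 0.50, 2.00, 0.95]

# Ops 0 and 3 are correct with speedup > 0.8; op 1 is correct but too slow,
# op 2 is fast but incorrect, so the score is 2 / 4 = 0.5.
score = fastp(correctness, performance, p=0.8)
print(score.item())  # 0.5
```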
20 changes: 17 additions & 3 deletions BackendBench/scripts/main.py
@@ -23,6 +23,7 @@
TorchBenchTestSuite,
FactoTestSuite,
)
from BackendBench.score import fastp

logger = logging.getLogger(__name__)

@@ -209,7 +210,14 @@ def cli(
test.correctness_tests,
test.performance_tests,
)
overall_correctness.append(correctness)

overall_correctness.append(
Contributor:
You're calculating the correctness score here as an aggregate of all the tests per op. Therefore, we are calculating perf@p at the level of per-op aggregates rather than on individual tests, as KernelBench does.

Contributor Author:
I believe KernelBench's fastp is at the same level as ours, because one task in KernelBench corresponds to one operation in BackendBench.

For each task in KernelBench, they verify the correctness of the generated kernel by comparing it against reference PyTorch operators multiple times using randomized inputs. Then, they measure speedup over multiple runs. The final fastp metric is calculated at the kernel level, rather than for individual runs. This is the same as BackendBench, where we verify the correctness of an operation by running a series of correctness tests and compute speedup by averaging performance results across multiple tests.

all(
data["correctness_score"]
for data in op_test_data.values()
if "correctness_score" in data.keys()
)
)
overall_performance.append(perf)

# Convert dict to list entries with op_name
@@ -243,7 +251,11 @@ def cli(
results = evaluator.get_results()

for result in results:
correctness_score = result.correctness_score
correctness_score = all(
Contributor:
Can you also add a perf@p on the operator level in test_data? (I'd just do this in eval.py)

Contributor Author:
I'm not sure how perf@p works on the operator level. Currently, perf@p is defined as "the ratio of ops that are both correct and have a speedup greater than p." Applying this metric at the operator level may require adjusting the definition.

data["correctness_score"]
for data in op_test_data.values()
if "correctness_score" in data.keys()
)
performance_score = result.performance_score
overall_correctness.append(correctness_score)
overall_performance.append(performance_score)
@@ -256,10 +268,12 @@ def cli(
entry.update(data)
verbose_results.append(entry)

mean_correctness = torch.tensor(overall_correctness).mean().item()
mean_correctness = torch.tensor(overall_correctness).float().mean().item()
geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
fastp_score = fastp(overall_correctness, overall_performance)
print(f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}")
print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")
print(f"fastp score (rate of correct samples with a speedup greater than p): {fastp_score:.2f}")

# Save verbose results if output path is specified
if output_path and verbose_results:
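To make the per-op aggregation above concrete, here is a minimal sketch (the contents of `op_test_data` are invented; the change assumes each test entry may carry a truthy/falsy `correctness_score`). An op counts as correct only if every test that reported a `correctness_score` passed, and it is these per-op booleans that feed `fastp`.

```python
# Invented example of what one op's test data might look like.
op_test_data = {
    "test_0": {"correctness_score": True, "benchmark_time": 1.2},
    "test_1": {"correctness_score": True, "benchmark_time": 0.9},
    "test_2": {"benchmark_time": 1.1},  # no correctness entry; skipped by the filter
}

# Same aggregation as in the diff: all correctness-bearing tests must pass.
is_correct = all(
    data["correctness_score"]
    for data in op_test_data.values()
    if "correctness_score" in data.keys()
)
print(is_correct)  # True
```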
8 changes: 4 additions & 4 deletions test/test_facto_suite.py
@@ -53,20 +53,20 @@ def test_facto_suite_relu_default_correctness_not_empty(self):
assert value.numel() > 0, f"Tensor kwarg is empty for {test.op}"

# Evaluate the operation
correctness, _, _ = eval_one_op(
correctness, _, op_test_data = eval_one_op(
test.op,
backend[test.op], # AtenBackend returns the original op
test.correctness_tests,
test.performance_tests,
)
print(f"Correctness for {test.op}: {correctness}")
overall_correctness.append(correctness)
is_correct = all(data["correctness_score"] for data in op_test_data.values())
overall_correctness.append(is_correct)

# Individual test assertions
assert correctness > 0, f"Operation {test.op} failed all correctness tests"

# Calculate mean correctness
mean_correctness = torch.tensor(overall_correctness).mean().item()
mean_correctness = torch.tensor(overall_correctness).float().mean().item()

# Main assertion: correctness should be > 0.8
assert mean_correctness > 0.8, (
46 changes: 46 additions & 0 deletions test/test_score.py
@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from BackendBench.score import fastp


def fastp_kernel_bench(
    is_correct: np.ndarray, baseline_speed: np.ndarray, actual_speed: np.ndarray, n: int, p: float
) -> float:
    """
    Original fastp implementation from kernelBench
    """
    filtered_baseline_speed = np.array([x for i, x in enumerate(baseline_speed) if is_correct[i]])
    filtered_actual_speed = np.array([x for i, x in enumerate(actual_speed) if is_correct[i]])
    speed_up = filtered_baseline_speed / filtered_actual_speed
    fast_p_score = np.sum(speed_up > p)
    return fast_p_score / n if n > 0 else 0


class TestFastp:
    def get_results(self, num_tests=100):
        overall_correctness = np.random.randint(0, 2, size=num_tests)
        overall_performance = np.random.uniform(0.5, 2, size=num_tests)
        return overall_correctness, overall_performance

    def test_fastp(self):
        for num_tests in [5, 10, 50, 100]:
            for p in [0, 1, 1.5, 2]:
                overall_correctness, overall_performance = self.get_results(num_tests)

                actual_speed = np.random.randint(1, 101, size=num_tests)
                baseline_speed = actual_speed * overall_performance
                fastp_score_orig = fastp_kernel_bench(
                    overall_correctness, baseline_speed, actual_speed, num_tests, p
                )

                fastp_score = fastp(overall_correctness.tolist(), overall_performance.tolist(), p)

                assert torch.allclose(
                    fastp_score, torch.tensor(fastp_score_orig, dtype=torch.float32)
                )
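A note on why the comparison in `test_fastp` is apples-to-apples (a sketch with made-up numbers): BackendBench's `overall_performance` already stores speedups, while KernelBench's helper starts from raw timings, so constructing `baseline_speed = actual_speed * overall_performance` makes `baseline_speed / actual_speed` reproduce exactly the speedups that `fastp` sees.

```python
import numpy as np

from BackendBench.score import fastp

# Made-up data mirroring the test's setup.
is_correct = [1, 0, 1]
speedups = [1.5, 2.0, 0.5]                 # what BackendBench passes to fastp
actual_speed = np.array([2.0, 1.0, 4.0])   # hypothetical runtimes
baseline_speed = actual_speed * np.array(speedups)

# KernelBench's ratio recovers exactly the speedups fastp is given.
assert np.allclose(baseline_speed / actual_speed, speedups)

# Only op 0 is both correct and faster than p=1.0, so both formulas give 1/3.
print(fastp(is_correct, speedups, p=1.0).item())  # ~0.3333
```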
11 changes: 8 additions & 3 deletions test/test_smoke.py
@@ -33,20 +33,25 @@ def test_smoke_suite_aten_backend(self, aten_backend):
if test.op not in aten_backend:
pytest.skip(f"Operation {test.op} not in backend")

correctness, perf, _ = eval_one_op(
correctness, perf, op_test_data = eval_one_op(
test.op,
aten_backend[test.op],
test.correctness_tests,
test.performance_tests,
)

overall_correctness.append(correctness)
is_correct = all(
data["correctness_score"]
for data in op_test_data.values()
if "correctness_score" in data.keys()
)
overall_correctness.append(is_correct)
overall_performance.append(perf)

assert correctness > 0, f"Operation {test.op} failed all correctness tests"
assert perf > 0.1, f"Operation {test.op} is more than 10x slower than reference"

mean_correctness = torch.tensor(overall_correctness).mean().item()
mean_correctness = torch.tensor(overall_correctness).float().mean().item()
geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()

assert mean_correctness >= 0.8, (
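Finally, a worked example of the two summary statistics computed both here and in `main.py` (toy numbers, not real benchmark results):

```python
import torch

overall_correctness = [True, True, False, True]  # per-op pass/fail flags
overall_performance = [2.0, 0.5, 1.0, 4.0]       # per-op speedups

# 3 of 4 ops correct -> 0.75
mean_correctness = torch.tensor(overall_correctness).float().mean().item()

# Geometric mean: (2.0 * 0.5 * 1.0 * 4.0) ** (1 / 4) = 4 ** 0.25 ≈ 1.41
geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()

print(mean_correctness, geomean_perf)  # 0.75 1.414...
```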