8 changes: 5 additions & 3 deletions BackendBench/eval.py
@@ -25,6 +25,8 @@
 exc: {exc}
 """
 
+FAIL_FACTOR = 1.1
+
 
 def format_exception(e, op, args, kwargs):
     op_name = getattr(op, "__name__", str(op))
@@ -61,7 +63,7 @@ def eval_correctness(op, impl, tests):
         if eval_correctness_test(op, impl, test):
             correct += 1
         total += 1
-    return correct / total
+    return correct == total, correct, total
 
 
 def cpu_bench(fn, num_runs=100):
@@ -91,7 +93,7 @@ def eval_performance(op, impl, tests):
         try:
             allclose(op(*test.args, **test.kwargs), impl(*test.args, **test.kwargs))
         except Exception:
-            test_times.append(base_times[-1])
+            test_times.append(base_times[-1] * FAIL_FACTOR)
            continue
         test_times.append(bench_fn(lambda: impl(*test.args, **test.kwargs)))
     speedups = torch.tensor(base_times) / torch.tensor(test_times)
@@ -104,7 +106,7 @@ def eval_one_op(op, impl, correctness_tests, performance_tests):
     # but that should be a separate PR.
     if uses_cuda_stream(impl):
         logger.warning(f"Skipping {op.__name__} because it uses CUDA stream")
-        return 0.0, 1.0
+        return (False, 0, 0), 1.0
     return eval_correctness(op, impl, correctness_tests), eval_performance(
         op, impl, performance_tests
     )
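
Taken together, the eval.py changes alter two contracts: eval_correctness now returns an (all_passed, correct, total) tuple instead of a pass-rate float, and eval_performance charges any test that fails the allclose check FAIL_FACTOR (1.1x) of the baseline time, so a broken implementation drags the speedup below 1.0 instead of matching the baseline exactly. A minimal sketch of both effects, with made-up timings (variable names mirror this file; the numbers are purely illustrative):

import torch

FAIL_FACTOR = 1.1

# Hypothetical timings for three performance tests of one op (seconds).
base_times = [1.0, 1.0, 1.0]                # reference aten op
test_times = [0.5, 0.8, 1.0 * FAIL_FACTOR]  # custom impl; the third test failed
                                            # allclose, so it is charged
                                            # base_times[-1] * FAIL_FACTOR
speedups = torch.tensor(base_times) / torch.tensor(test_times)
print(speedups.log().mean().exp().item())   # geomean ~1.31, vs ~1.58 if only the
                                            # two passing tests were counted

# The correctness result now unpacks as a tuple rather than comparing a float:
all_passed, correct, total = (False, 2, 3)  # shape of the new eval_correctness return
assert not all_passed and correct < total
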
2 changes: 1 addition & 1 deletion BackendBench/multiprocessing_eval.py
@@ -119,7 +119,7 @@ def _worker_process(worker_id, task_queue, result_queue):
                 break
             result = EvalResult(
                 task_id=task.task_id,
-                correctness_score=0.0,
+                correctness_score=(False, 0, 0),
                 performance_score=1.0,
                 error=error_msg,
             )
6 changes: 3 additions & 3 deletions BackendBench/scripts/main.py
@@ -198,7 +198,7 @@ def cli(
                test.correctness_tests,
                test.performance_tests,
            )
-            overall_correctness.append(correctness)
+            overall_correctness.append(correctness[0])
            overall_performance.append(perf)
 
            logger.debug(f"max memory allocated: {torch.cuda.max_memory_allocated():,}")
@@ -222,12 +222,12 @@
        results = evaluator.get_results()
 
        for result in results:
-            correctness_score = result.correctness_score
+            correctness_score = result.correctness_score[0]
            performance_score = result.performance_score
            overall_correctness.append(correctness_score)
            overall_performance.append(performance_score)
 
-    mean_correctness = torch.tensor(overall_correctness).mean().item()
+    mean_correctness = torch.tensor(overall_correctness).float().mean().item()
    geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
    print(f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}")
    print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")
4 changes: 2 additions & 2 deletions test/test_adverse_cases.py
@@ -48,7 +48,7 @@ def test_adaptive_avg_pool2d_backward_gpu(self):
        results = evaluator.get_results()
 
        assert len(results) == 1
-        assert results[0].correctness_score == 1.0
+        assert results[0].correctness_score[0]
 
 
 class TestCase:
@@ -74,7 +74,7 @@ def test_multiprocessing_evaluator(self):
 
        assert len(results) == 1
        # Should have perfect correctness since using same implementation
-        assert results[0].correctness_score == 1.0
+        assert results[0].correctness_score[0]
        # Performance should be around 1.0 (same speed)
        assert results[0].performance_score.item() > 0
 
10 changes: 6 additions & 4 deletions test/test_backend_evaluation.py
@@ -107,10 +107,10 @@ def test_2_watermarked_implementations_fail_correctness(self):
            try:
                impl = backend[op]
                test = Test(*arg_generators)
-                correctness = eval_correctness(op, impl, [test])
+                _, correct, _ = eval_correctness(op, impl, [test])
 
                total_tested += 1
-                if correctness == 0.0:
+                if correct == 0:
                    failed_count += 1
                    print(f" ✓ {str(op).split('.')[-2]}: Failed correctness (watermarked)")
                else:
@@ -183,11 +183,13 @@ def test_4_eval_integration(self):
            correctness, performance = eval_one_op(test_op, impl, [test], [test])
 
            print(f" Operation: {test_op}")
-            print(f" Correctness: {correctness}")
+            print(f" Correctness: {correctness[0]}")
            print(f" Performance: {performance}")
 
            # Watermarked implementation should fail correctness
-            self.assertEqual(correctness, 0.0, "Watermarked implementation should fail correctness")
+            self.assertEqual(
+                correctness[0], 0, "Watermarked implementation should fail correctness"
+            )
 
            print(" ✓ eval_one_op works correctly with watermarked implementation")
        else:
2 changes: 1 addition & 1 deletion test/test_eval.py
@@ -183,6 +183,6 @@ def __init__(self, args, kwargs):
    correctness, performance = eval_one_op(op, impl, correctness_tests, performance_tests)
 
    # Should have perfect correctness since using same implementation
-    assert correctness == 1.0
+    assert correctness[0] == 1
    # Performance should be around 1.0 (same speed)
    assert performance.item() > 0
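
Since correctness[0] is the all-passed boolean, the == 1 comparison holds via Python's bool/int equality. An equivalent, slightly more explicit form (a sketch, not part of the PR, with illustrative names) unpacks the tuple instead of indexing it:

all_passed, num_correct, num_total = correctness
assert all_passed
assert num_correct == num_total
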
8 changes: 4 additions & 4 deletions test/test_facto_suite.py
@@ -59,14 +59,14 @@ def test_facto_suite_relu_default_correctness_not_empty(self):
                test.correctness_tests,
                test.performance_tests,
            )
-            print(f"Correctness for {test.op}: {correctness}")
-            overall_correctness.append(correctness)
+            print(f"Correctness for {test.op}: {correctness[0]}")
+            overall_correctness.append(correctness[0])
 
            # Individual test assertions
-            assert correctness > 0, f"Operation {test.op} failed all correctness tests"
+            assert correctness[1] > 0, f"Operation {test.op} failed all correctness tests"
 
        # Calculate mean correctness
-        mean_correctness = torch.tensor(overall_correctness).mean().item()
+        mean_correctness = torch.tensor(overall_correctness).float().mean().item()
 
        # Main assertion: correctness should be > 0.8
        assert mean_correctness > 0.8, (
6 changes: 3 additions & 3 deletions test/test_smoke.py
@@ -40,13 +40,13 @@ def test_smoke_suite_aten_backend(self, aten_backend):
                test.performance_tests,
            )
 
-            overall_correctness.append(correctness)
+            overall_correctness.append(correctness[0])
            overall_performance.append(perf)
 
-            assert correctness > 0, f"Operation {test.op} failed all correctness tests"
+            assert correctness[1] > 0, f"Operation {test.op} failed all correctness tests"
            assert perf > 0.1, f"Operation {test.op} is more than 10x slower than reference"
 
-        mean_correctness = torch.tensor(overall_correctness).mean().item()
+        mean_correctness = torch.tensor(overall_correctness).float().mean().item()
        geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
 
        assert mean_correctness >= 0.8, (