Skip to content

Spawning isolated processes for each test #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Aug 19, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions BackendBench/eval_multiprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import logging
import multiprocessing as mp

import torch

import triton.testing

from BackendBench.utils import uses_cuda_stream, is_pickleable
from BackendBench.opregistry import get_operator, _extract_spec_name_from_op
from BackendBench.eval import allclose, cpu_bench

logger = logging.getLogger(__name__)


def _set_gpu_device(gpu_id):
if gpu_id is not None and torch.cuda.is_available():
if gpu_id < torch.cuda.device_count():
torch.cuda.set_device(gpu_id)
logger.debug(f"Set CUDA device to GPU {gpu_id}")
else:
logger.warning(f"GPU {gpu_id} not available. Using default device.")


def _run_single_correctness_test(op, impl, args, kwargs, gpu_id):
try:
_set_gpu_device(gpu_id)

if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()

# Get operators from string specs
if isinstance(op, str):
op = get_operator(op)
if isinstance(impl, str):
impl = get_operator(impl)

ref = op(*args, **kwargs)
res = impl(*args, **kwargs)

return allclose(ref, res)

except Exception:
return False
finally:
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()


def _run_single_performance_test(op, impl, args, kwargs, gpu_id):
_set_gpu_device(gpu_id)

if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()

# Get operators from string specs
if isinstance(op, str):
op = get_operator(op)
if isinstance(impl, str):
impl = get_operator(impl)

bench_fn = triton.testing.do_bench if torch.cuda.is_available() else cpu_bench
base_time = bench_fn(lambda: op(*args, **kwargs))

try:
allclose(op(*args, **kwargs), impl(*args, **kwargs))
except Exception:
test_time = base_time
return base_time, test_time

test_time = bench_fn(lambda: impl(*args, **kwargs))

if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.empty_cache()

return base_time, test_time


def eval_correctness_multiprocessing(op, impl, tests, num_workers):
if not is_pickleable(op):
op = _extract_spec_name_from_op(op)
if not is_pickleable(impl):
impl = _extract_spec_name_from_op(impl)

correct, total = 0, 0

mp.set_start_method("spawn", force=True)
with mp.Pool(num_workers) as pool:
while tests:
current_batch = tests[:num_workers] if len(tests) >= num_workers else tests
tests = tests[num_workers:] if len(tests) >= num_workers else []
async_results = []
for i, test in enumerate(current_batch):
async_result = pool.apply_async(
_run_single_correctness_test,
(op, impl, test.args, test.kwargs, i % num_workers),
)
async_results.append(async_result)

for async_result in async_results:
result = async_result.get()
if result:
correct += 1
total += 1

return correct / total if total > 0 else 0.0


def eval_performance_multiprocessing(op, impl, tests, num_workers, timeout=120):
if not is_pickleable(op):
op = _extract_spec_name_from_op(op)
if not is_pickleable(impl):
impl = _extract_spec_name_from_op(impl)

base_times = []
test_times = []

mp.set_start_method("spawn", force=True)
with mp.Pool(num_workers) as pool:
while tests:
current_batch = tests[:num_workers] if len(tests) >= num_workers else tests
tests = tests[num_workers:] if len(tests) >= num_workers else []
async_results = []
for i, test in enumerate(current_batch):
async_result = pool.apply_async(
_run_single_performance_test,
(op, impl, test.args, test.kwargs, i % num_workers),
)
async_results.append(async_result)

for i, async_result in enumerate(async_results):
base_time, test_time = async_result.get(timeout)
base_times.append(base_time)
test_times.append(test_time)

speedups = torch.tensor(base_times) / torch.tensor(test_times)
return speedups.log().mean().exp()


def eval_one_op_multiprocessing(
op, impl, correctness_tests, performance_tests, num_workers: int = None
):
if uses_cuda_stream(impl):
logger.warning(f"Skipping {op.__name__} because it uses CUDA stream")
return 0, 0

if num_workers is None:
num_workers = 1

return eval_correctness_multiprocessing(
op, impl, correctness_tests, num_workers
), eval_performance_multiprocessing(op, impl, correctness_tests, num_workers)
12 changes: 12 additions & 0 deletions BackendBench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,15 @@ def deserialize_args(inps):
for key in dtype_abbrs_parsing:
inps = inps.replace(f"'{key}'", key)
return eval(inps.strip().strip("'").strip('"'), global_vals)


def is_pickleable(obj):
import pickle
import io

try:
with io.BytesIO() as stream:
pickle.dump(obj, stream)
return True
except Exception:
return False
11 changes: 6 additions & 5 deletions test/test_adverse_cases.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import pytest
from BackendBench.torchbench_suite import TorchBenchOpTest
from BackendBench.eval import eval_one_op
from BackendBench.eval_multiprocessing import eval_one_op_multiprocessing
import BackendBench.backends as backends
import torch


class TestAdaptiveAvgPool2dBackward:
# todo: @jiannanWang unskip this test
@pytest.mark.skip(reason="Not ready for testing yet as it'd brick the gpu")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU")
def test_adaptive_avg_pool2d_backward_gpu(self):
"""Test on GPU with eval_one_op."""
op_test_should_error = TorchBenchOpTest(
Expand All @@ -25,23 +24,25 @@ def test_adaptive_avg_pool2d_backward_gpu(self):
# run test that should brick the gpu due to an illegal memory access
backend = backends.AtenBackend()
with pytest.raises(RuntimeError):
_, _ = eval_one_op(
_, _ = eval_one_op_multiprocessing(
op_test_should_error.op,
backend[op_test_should_error.op],
list(op_test_should_error.correctness_tests),
list(op_test_should_error.performance_tests),
torch.cuda.device_count(),
)

# add these in case code changes in eval_one_op. There shouldn't be any errors here
torch.cuda.synchronize()
torch.cuda.empty_cache()

# tests that a simple op works afterwards to make sure we recover after an illegal memory access
correctness, _ = eval_one_op(
correctness, _ = eval_one_op_multiprocessing(
op_test_should_succeed.op,
backend[op_test_should_succeed.op],
list(op_test_should_succeed.correctness_tests),
list(op_test_should_succeed.performance_tests),
torch.cuda.device_count(),
)

assert correctness == 1.0
Expand Down
57 changes: 57 additions & 0 deletions test/test_eval_multiprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
import torch

try:
import importlib.util
from BackendBench.eval_multiprocessing import (
eval_correctness_multiprocessing,
eval_one_op_multiprocessing,
)

HAS_TRITON = importlib.util.find_spec("triton") is not None
except ImportError:
HAS_TRITON = False

pytestmark = pytest.mark.skipif(not HAS_TRITON, reason="triton not available")


class TestEvalCorrectnessMultiprocessing:
def test_eval_correctness_multiple_tests(self):
op = torch.abs
impl = torch.abs # Same implementation

class TestCase:
def __init__(self, args, kwargs):
self.args = args
self.kwargs = kwargs

tests = []
for i in range(5):
test = TestCase([torch.tensor([float(i) - 2.5])], {})
tests.append(test)

score = eval_correctness_multiprocessing(op, impl, tests, torch.cuda.device_count())
assert score == 1.0


class TestEvalOneOp:
def test_eval_one_op(self):
op = torch.relu
impl = torch.relu # Same implementation

class TestCase:
def __init__(self, args, kwargs):
self.args = args
self.kwargs = kwargs

correctness_tests = [TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) for _ in range(3)]
performance_tests = [TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) for _ in range(2)]

correctness, performance = eval_one_op_multiprocessing(
op, impl, correctness_tests, performance_tests
)

# Should have perfect correctness since using same implementation
assert correctness == 1.0
# Performance should be around 1.0 (same speed)
assert performance.item() > 0