Spawning isolated processes for each test #67
Merged · +393 −37 · Changes shown from 8 of 17 commits

Commits (all by jiannanWang):
ebeb9b8  create one process for each test
28ab6cd  fix
a7dd7ae  fix
e40a7d8  ruff
32674da  fix
a5d8a3d  ruff
320cfdb  Merge branch 'main' into jiannanWang/eval_multiprocessing
fe3c17c  add eval_multiprocessing, test_eval_multiprocessing and revert eval
5f20147  updating multiprocessing implementation
bb1389b  Merge branch 'main' into jiannanWang/eval_multiprocessing
8501273  ruff
b4bf82c  fix
3d3648e  default to disable multiprocessing
4f3c1df  refactor
0eee43e  fix
83a9a85  deselect large tensor testing in smoke test
463c3f8  fix
BackendBench/eval_multiprocessing.py (new file)
@@ -0,0 +1,155 @@
import logging
import multiprocessing as mp

import torch

import triton.testing

from BackendBench.utils import uses_cuda_stream, is_pickleable
from BackendBench.opregistry import get_operator, _extract_spec_name_from_op
from BackendBench.eval import allclose, cpu_bench

logger = logging.getLogger(__name__)


def _set_gpu_device(gpu_id):
    if gpu_id is not None and torch.cuda.is_available():
        if gpu_id < torch.cuda.device_count():
            torch.cuda.set_device(gpu_id)
            logger.debug(f"Set CUDA device to GPU {gpu_id}")
        else:
            logger.warning(f"GPU {gpu_id} not available. Using default device.")


def _run_single_correctness_test(op, impl, args, kwargs, gpu_id):
    try:
        _set_gpu_device(gpu_id)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        # Get operators from string specs
        if isinstance(op, str):
            op = get_operator(op)
        if isinstance(impl, str):
            impl = get_operator(impl)

        ref = op(*args, **kwargs)
        res = impl(*args, **kwargs)

        return allclose(ref, res)

    except Exception:
        return False
    finally:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()


def _run_single_performance_test(op, impl, args, kwargs, gpu_id):
    _set_gpu_device(gpu_id)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    # Get operators from string specs
    if isinstance(op, str):
        op = get_operator(op)
    if isinstance(impl, str):
        impl = get_operator(impl)

    bench_fn = triton.testing.do_bench if torch.cuda.is_available() else cpu_bench
    base_time = bench_fn(lambda: op(*args, **kwargs))

    try:
        allclose(op(*args, **kwargs), impl(*args, **kwargs))
    except Exception:
        # Fall back to the baseline time (speedup of 1.0) if the
        # candidate implementation fails on this input.
        test_time = base_time
        return base_time, test_time

    test_time = bench_fn(lambda: impl(*args, **kwargs))

    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return base_time, test_time


def eval_correctness_multiprocessing(op, impl, tests, num_workers):
    # Spawned workers must be able to pickle their arguments, so fall back
    # to the operator's spec name when the object itself is not pickleable.
    if not is_pickleable(op):
        op = _extract_spec_name_from_op(op)
    if not is_pickleable(impl):
        impl = _extract_spec_name_from_op(impl)

    correct, total = 0, 0

    mp.set_start_method("spawn", force=True)

    with mp.Pool(num_workers) as pool:
        while tests:
            current_batch = tests[:num_workers] if len(tests) >= num_workers else tests
            tests = tests[num_workers:] if len(tests) >= num_workers else []
            async_results = []
            for i, test in enumerate(current_batch):
                async_result = pool.apply_async(
                    _run_single_correctness_test,
                    (op, impl, test.args, test.kwargs, i % num_workers),
                )
                async_results.append(async_result)

            for async_result in async_results:
                result = async_result.get()
                if result:
                    correct += 1
                total += 1

    return correct / total if total > 0 else 0.0


def eval_performance_multiprocessing(op, impl, tests, num_workers, timeout=120):
    if not is_pickleable(op):
        op = _extract_spec_name_from_op(op)
    if not is_pickleable(impl):
        impl = _extract_spec_name_from_op(impl)

    base_times = []
    test_times = []

    mp.set_start_method("spawn", force=True)
    with mp.Pool(num_workers) as pool:
        while tests:
            current_batch = tests[:num_workers] if len(tests) >= num_workers else tests
            tests = tests[num_workers:] if len(tests) >= num_workers else []
            async_results = []
            for i, test in enumerate(current_batch):
                async_result = pool.apply_async(
                    _run_single_performance_test,
                    (op, impl, test.args, test.kwargs, i % num_workers),
                )
                async_results.append(async_result)

            for async_result in async_results:
                base_time, test_time = async_result.get(timeout)
                base_times.append(base_time)
                test_times.append(test_time)

    # Geometric mean of per-test speedups
    speedups = torch.tensor(base_times) / torch.tensor(test_times)
    return speedups.log().mean().exp()


def eval_one_op_multiprocessing(
    op, impl, correctness_tests, performance_tests, num_workers: int = None
):
    if uses_cuda_stream(impl):
        logger.warning(f"Skipping {op.__name__} because it uses CUDA stream")
        return 0, 0

    if num_workers is None:
        num_workers = 1

    return eval_correctness_multiprocessing(
        op, impl, correctness_tests, num_workers
    ), eval_performance_multiprocessing(op, impl, performance_tests, num_workers)
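For orientation, here is a minimal, hypothetical sketch of driving the new entry point, mirroring the call shapes used in the tests below. The TestCase holder and the inputs are illustrative, not an API added by this PR; the evaluators only read .args and .kwargs from each test object.

    import torch

    from BackendBench.eval_multiprocessing import eval_one_op_multiprocessing


    class TestCase:
        """Illustrative stand-in: anything exposing .args and .kwargs works."""

        def __init__(self, args, kwargs):
            self.args = args
            self.kwargs = kwargs


    if __name__ == "__main__":
        # Comparing an op against itself should give correctness 1.0 and a
        # speedup near 1.0; each test case runs in its own spawned worker.
        tests = [TestCase([torch.randn(8)], {}) for _ in range(4)]
        correctness, speedup = eval_one_op_multiprocessing(
            torch.relu, torch.relu, tests, tests, num_workers=2
        )
        print(f"correctness={correctness:.2f}, speedup={float(speedup):.2f}")

The __main__ guard is load-bearing here: under the spawn start method each worker re-imports the parent module, so unguarded top-level work would run once per worker.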
test_eval_multiprocessing.py (new file)
@@ -0,0 +1,57 @@
import pytest
import torch

try:
    import importlib.util
    from BackendBench.eval_multiprocessing import (
        eval_correctness_multiprocessing,
        eval_one_op_multiprocessing,
    )

    HAS_TRITON = importlib.util.find_spec("triton") is not None

except ImportError:
    HAS_TRITON = False

pytestmark = pytest.mark.skipif(not HAS_TRITON, reason="triton not available")


class TestEvalCorrectnessMultiprocessing:
    def test_eval_correctness_multiple_tests(self):
        op = torch.abs
        impl = torch.abs  # Same implementation

        class TestCase:
            def __init__(self, args, kwargs):
                self.args = args
                self.kwargs = kwargs

        tests = []
        for i in range(5):
            test = TestCase([torch.tensor([float(i) - 2.5])], {})
            tests.append(test)

        # One worker per GPU; at least one worker on CPU-only machines
        num_workers = max(1, torch.cuda.device_count())
        score = eval_correctness_multiprocessing(op, impl, tests, num_workers)
        assert score == 1.0


class TestEvalOneOp:
    def test_eval_one_op(self):
        op = torch.relu
        impl = torch.relu  # Same implementation

        class TestCase:
            def __init__(self, args, kwargs):
                self.args = args
                self.kwargs = kwargs

        correctness_tests = [TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) for _ in range(3)]
        performance_tests = [TestCase([torch.tensor([-1.0, 0.0, 1.0])], {}) for _ in range(2)]

        correctness, performance = eval_one_op_multiprocessing(
            op, impl, correctness_tests, performance_tests
        )

        # Should have perfect correctness since using same implementation
        assert correctness == 1.0
        # Performance should be around 1.0 (same speed)
        assert performance.item() > 0
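The performance value asserted on above is a geometric mean of per-test speedups (eval_performance_multiprocessing computes speedups.log().mean().exp()), so a symmetric win and loss cancel out rather than inflating the average. A small worked example with invented timings:

    import torch

    base_times = torch.tensor([1.0, 2.0, 4.0])  # reference op timings
    test_times = torch.tensor([0.5, 2.0, 8.0])  # candidate impl timings
    speedups = base_times / test_times          # tensor([2.0, 1.0, 0.5])

    # Geometric mean, exactly as computed in eval_performance_multiprocessing:
    print(speedups.log().mean().exp())  # tensor(1.) -- 2x win and 2x loss cancel

    # An arithmetic mean would misleadingly report a net speedup:
    print(speedups.mean())  # tensor(1.1667)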