
Commit 79ed562

Spawning isolated processes for each test (#67)
1 parent 4acfe6c commit 79ed562

File tree: 4 files changed (+393, -37 lines)

BackendBench/multiprocessing_eval.py

Lines changed: 298 additions & 0 deletions
@@ -0,0 +1,298 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


# This module implements multiprocessing evaluation for BackendBench.
# It is used to recover from CUDA errors.
# Example usage:
#
# with multiprocessing_eval.MultiprocessingEvaluator(num_workers) as evaluator:
#     for test in suite:
#         evaluator.submit_task(
#             test.op, backend[test.op], test.correctness_tests, test.performance_tests
#         )
#     evaluator.start_evaluation()
#     results = evaluator.get_results()

import logging
from dataclasses import dataclass
import multiprocessing as mp
import time
import queue
import traceback
from typing import Any, List, Optional

import torch

from BackendBench.eval import eval_one_op
from BackendBench.opregistry import get_operator, _extract_spec_name_from_op

logger = logging.getLogger(__name__)


@dataclass
class EvalTask:
    """Task for multiprocessing evaluation."""

    task_id: int
    op: Any
    impl: Any
    correctness_tests: List[Any]
    performance_tests: List[Any]


@dataclass
class EvalResult:
    """Result from multiprocessing evaluation."""

    task_id: int
    correctness_score: float
    performance_score: float
    error: Optional[str] = None


@dataclass
class ProcessDeathSignal:
    """Signal indicating a process has died."""

    worker_id: int
    error_msg: str


def is_pickleable(obj):
    import pickle
    import io

    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        return False


def _worker_process(worker_id, task_queue, result_queue):
    try:
        torch.cuda.set_device(worker_id)
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

        while True:
            try:
                task = task_queue.get(block=False)

                if task is None:
                    logger.info(f"Worker {worker_id} received shutdown signal")
                    break

                # Process the task
                logger.debug(f"Worker {worker_id} processing task {task.task_id}")

                try:
                    op = task.op
                    if isinstance(op, str):
                        op = get_operator(op)
                    impl = task.impl
                    if isinstance(impl, str):
                        impl = get_operator(impl)

                    correctness_score, performance_score = eval_one_op(
                        op, impl, task.correctness_tests, task.performance_tests
                    )
                    result = EvalResult(
                        task_id=task.task_id,
                        correctness_score=correctness_score,
                        performance_score=performance_score,
                    )
                except Exception as e:
                    error_msg = f"Error in eval_one_op: {str(e)}\n{traceback.format_exc()}"
                    logger.warning(f"Worker {worker_id} task {task.task_id} failed: {error_msg}")
                    if "cuda" in str(e).lower():  # CUDA error
                        error_msg = (
                            f"Worker {worker_id} CUDA error: {str(e)}\n{traceback.format_exc()}"
                        )
                        logger.error(error_msg)
                        result_queue.put(ProcessDeathSignal(worker_id, error_msg))
                        break
                    result = EvalResult(
                        task_id=task.task_id,
                        correctness_score=0.0,
                        performance_score=1.0,
                        error=error_msg,
                    )

                # Put result in result queue
                result_queue.put(result)

            except queue.Empty:
                time.sleep(0.1)
                continue
            except Exception as e:
                # Unexpected error in worker loop
                error_msg = f"Worker {worker_id} loop error: {str(e)}\n{traceback.format_exc()}"
                logger.error(error_msg)
                result_queue.put(ProcessDeathSignal(worker_id, error_msg))
                break

    except Exception as e:
        error_msg = f"Worker {worker_id} fatal error: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        result_queue.put(ProcessDeathSignal(worker_id, error_msg))
    finally:
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    logger.info(f"Worker {worker_id} exiting")


class MultiprocessingEvaluator:
    def __init__(self, num_workers: int = 1):
        assert num_workers <= torch.cuda.device_count(), "performance will be suboptimal"

        self.mp_context = mp.get_context("spawn")
        self.num_workers = num_workers
        self.task_queue = self.mp_context.Queue()
        self.result_queue = self.mp_context.Queue()
        self.workers = {}
        self.next_task_id = 0
        self.next_worker_id = 0
        self.total_tasks = 0
        self.completed_tasks = 0

        logger.info(f"Initialized MultiprocessingEvaluator with {num_workers} workers")

    def submit_task(self, op, impl, correctness_tests, performance_tests) -> int:
        task_id = self.next_task_id
        self.next_task_id += 1

        if not is_pickleable(op):
            op = _extract_spec_name_from_op(op)
        if not is_pickleable(impl):
            impl = _extract_spec_name_from_op(impl)

        task = EvalTask(
            task_id=task_id,
            op=op,
            impl=impl,
            correctness_tests=list(correctness_tests),
            performance_tests=list(performance_tests),
        )

        self.task_queue.put(task)
        self.total_tasks += 1

        logger.debug(f"Submitted task {task_id} for {getattr(op, '__name__', str(op))}")
        return task_id

    def _start_worker(self, worker_id):
        process = self.mp_context.Process(
            target=_worker_process,
            args=(worker_id, self.task_queue, self.result_queue),
            daemon=True,
        )
        process.start()
        self.workers[worker_id] = process

        logger.info(f"Started worker {worker_id} (PID: {process.pid}, GPU: {worker_id})")

    def _restart_worker(self, worker_id):
        """Restart a dead worker process."""
        # Clean up old process
        if worker_id in self.workers:
            old_process = self.workers[worker_id]
            if old_process.is_alive():
                old_process.terminate()
                old_process.join(timeout=5)
            del self.workers[worker_id]

        # Start new process with the same worker_id
        process = self.mp_context.Process(
            target=_worker_process,
            args=(worker_id, self.task_queue, self.result_queue),
            daemon=True,
        )
        process.start()
        self.workers[worker_id] = process

        logger.warning(f"Restarted worker {worker_id} (PID: {process.pid}, GPU: {worker_id})")

    def start_evaluation(self) -> None:
        """Start all worker processes to begin evaluation."""
        logger.info("Starting multiprocessing evaluation...")

        # Start all workers
        for i in range(self.num_workers):
            self._start_worker(i)

    def get_results(self):
        results = []

        while self.completed_tasks < self.total_tasks:
            try:
                # Get result from queue
                result = self.result_queue.get(block=False)
                logger.info(f"Result obtained: {result}")

                if isinstance(result, ProcessDeathSignal):
                    self.completed_tasks += 1
                    # Worker died, restart it
                    logger.error(f"Worker {result.worker_id} died: {result.error_msg}")
                    self._restart_worker(result.worker_id)
                    continue

                if isinstance(result, EvalResult):
                    results.append(result)
                    self.completed_tasks += 1

                    if result.error:
                        logger.warning(
                            f"Task {result.task_id} completed with error: {result.error}"
                        )
                    else:
                        logger.debug(f"Task {result.task_id} completed successfully")
            except queue.Empty:
                time.sleep(0.1)
                continue

            except Exception as e:
                logger.error(f"Error getting results: {e}\n{traceback.format_exc()}")
                break

        # Sort results by task_id to maintain order
        results.sort(key=lambda r: r.task_id)

        logger.info(f"Collected {len(results)} results out of {self.total_tasks} tasks")
        return results

    def shutdown(self) -> None:
        """Shutdown all worker processes."""
        logger.info("Shutting down multiprocessing evaluator...")

        for _ in range(self.num_workers):
            self.task_queue.put(None)

        # Wait for workers to finish
        for worker_id, process in list(self.workers.items()):
            try:
                process.join(timeout=5)
                if process.is_alive():
                    logger.warning(f"Force terminating worker {worker_id}")
                    process.terminate()
                    process.join(timeout=2)
            except Exception as e:
                logger.error(f"Error shutting down worker {worker_id}: {e}")

        torch.cuda.synchronize()
        torch.cuda.empty_cache()

        self.workers.clear()
        logger.info("Multiprocessing evaluator shutdown complete")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()
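
For readers skimming the diff, the sketch below distills the recovery pattern this module is built around into a standalone script that uses only the Python standard library. It is an illustration, not BackendBench code; the names (_worker, tasks, results) are made up for the example. A crashed worker reports a death signal on the result queue and is restarted by the parent, so the remaining tasks still get processed. The real evaluator does the same with ProcessDeathSignal and _restart_worker, except each worker is also pinned to its own GPU and the failed task is simply counted as completed rather than retried.

    # Minimal standalone sketch of the spawn-and-recover pattern (no BackendBench, no CUDA).
    import multiprocessing as mp
    import queue
    import time


    def _worker(worker_id, tasks, results):
        while True:
            try:
                task = tasks.get(block=False)
            except queue.Empty:
                time.sleep(0.1)
                continue
            if task is None:  # shutdown signal, like the None sentinel above
                break
            try:
                if task == "bad":
                    raise RuntimeError("simulated CUDA error")  # stand-in for a device fault
                results.put(("ok", task))
            except Exception as exc:
                results.put(("died", f"worker {worker_id}: {exc}"))
                break  # exit so the parent can restart this worker


    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        tasks, results = ctx.Queue(), ctx.Queue()
        for t in ("good-1", "bad", "good-2", None):
            tasks.put(t)

        proc = ctx.Process(target=_worker, args=(0, tasks, results), daemon=True)
        proc.start()

        collected = []
        while len(collected) < 3:  # expect 2 ok results and 1 death signal
            kind, payload = results.get()
            collected.append((kind, payload))
            if kind == "died":  # restart on the same queues, like _restart_worker
                proc.join()
                proc = ctx.Process(target=_worker, args=(0, tasks, results), daemon=True)
                proc.start()

        proc.join()  # the restarted worker exits after consuming the trailing None
        print(collected)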

BackendBench/scripts/main.py

Lines changed: 46 additions & 13 deletions
@@ -11,6 +11,7 @@
 
 import BackendBench.backends as backends
 import BackendBench.eval as eval
+import BackendBench.multiprocessing_eval as multiprocessing_eval
 import click
 import torch
 
@@ -104,6 +105,12 @@ def setup_logging(log_level):
     type=str,
     help="Path to directory containing generated kernels",
 )
+@click.option(
+    "--num-workers",
+    default=None,
+    type=int,
+    help="Number of workers to use for multiprocessing; defaults to None (multiprocessing disabled)",
+)
 def cli(
     log_level,
     suite,
@@ -116,6 +123,7 @@ def cli(
     kernel_agent_max_rounds,
     torchbench_data_path,
     ops_directory,
+    num_workers,
 ):
     setup_logging(log_level)
     if ops:
@@ -177,22 +185,47 @@ def cli(
     overall_correctness = []
     overall_performance = []
 
-    for test in suite:
-        if test.op not in backend:
-            continue
+    if num_workers is None:
+        for test in suite:
+            if test.op not in backend:
+                continue
 
-        logger.debug(test.op)
+            logger.debug(test.op)
 
-        correctness, perf = eval.eval_one_op(
-            test.op,
-            backend[test.op],
-            test.correctness_tests,
-            test.performance_tests,
-        )
-        overall_correctness.append(correctness)
-        overall_performance.append(perf)
+            correctness, perf = eval.eval_one_op(
+                test.op,
+                backend[test.op],
+                test.correctness_tests,
+                test.performance_tests,
+            )
+            overall_correctness.append(correctness)
+            overall_performance.append(perf)
+
+        logger.debug(f"max memory allocated: {torch.cuda.max_memory_allocated():,}")
+    else:
+        with multiprocessing_eval.MultiprocessingEvaluator(num_workers) as evaluator:
+            # Submit all tasks
+            for test in suite:
+                if test.op not in backend:
+                    continue
+
+                logger.debug(test.op)
+
+                evaluator.submit_task(
+                    test.op, backend[test.op], test.correctness_tests, test.performance_tests
+                )
+
+            # Start evaluation
+            evaluator.start_evaluation()
+
+            # Get results
+            results = evaluator.get_results()
 
-    logger.debug(f"max memory allocated: {torch.cuda.max_memory_allocated():,}")
+            for result in results:
+                correctness_score = result.correctness_score
+                performance_score = result.performance_score
+                overall_correctness.append(correctness_score)
+                overall_performance.append(performance_score)
 
     mean_correctness = torch.tensor(overall_correctness).mean().item()
     geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
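
The two summary metrics computed at the end are a plain mean for correctness and a log-mean-exp geometric mean for performance. A small self-contained illustration with made-up scores (not real BackendBench results):

    # Aggregation as in main.py, applied to hypothetical per-op scores.
    import torch

    overall_correctness = [1.0, 0.0, 1.0]  # 1.0 = op passed its correctness tests
    overall_performance = [2.0, 0.5, 1.0]  # per-op speedup ratios (1.0 = neutral)

    mean_correctness = torch.tensor(overall_correctness).mean().item()
    geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()

    print(f"correctness: {mean_correctness:.2f}")  # 0.67
    print(f"perf (geomean): {geomean_perf:.2f}")   # 1.00

The geometric mean is the natural aggregate for speedup ratios: a 2x win on one op and a 2x loss on another cancel out to 1.0, whereas an arithmetic mean would report 1.25x.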
