Skip to content

Commit d0ff8c6

Browse files
authored
Use a CPU round-trip to avoid CUDA OOM (#101)
1 parent 289d8c6 commit d0ff8c6

File tree

1 file changed

+60
-3
lines changed

1 file changed

+60
-3
lines changed

BackendBench/multiprocessing_eval.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class EvalTask:
4242
impl: Any
4343
correctness_tests: List[Any]
4444
performance_tests: List[Any]
45+
device: str
4546

4647

4748
@dataclass
@@ -100,8 +101,17 @@ def _worker_process(worker_id, task_queue, result_queue):
100101
if isinstance(impl, str):
101102
impl = get_operator(impl)
102103

104+
device = torch.device(task.device)
105+
106+
def test_to_device_iterator(tests, device):
107+
for test in tests:
108+
yield test_to_device(test, device)
109+
110+
correctness_tests = test_to_device_iterator(task.correctness_tests, device)
111+
performance_tests = test_to_device_iterator(task.performance_tests, device)
112+
103113
correctness_score, performance_score, test_data = eval_one_op(
104-
op, impl, task.correctness_tests, task.performance_tests
114+
op, impl, correctness_tests, performance_tests
105115
)
106116
result = EvalResult(
107117
task_id=task.task_id,
@@ -118,6 +128,8 @@ def _worker_process(worker_id, task_queue, result_queue):
118128
)
119129
logger.error(error_msg)
120130
result_queue.put(ProcessDeathSignal(worker_id, error_msg))
131+
torch.cuda.synchronize()
132+
torch.cuda.empty_cache()
121133
break
122134
result = EvalResult(
123135
task_id=task.task_id,
@@ -158,6 +170,37 @@ def _worker_process(worker_id, task_queue, result_queue):
158170
logger.info(f"Worker {worker_id} exiting")
159171

160172

173+
def args_to_device(value, device):
    """Recursively move every tensor inside *value* to *device*.

    Containers (list / tuple / dict) are rebuilt with the same structure;
    any non-tensor leaf is passed through untouched.
    """
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, dict):
        return {k: args_to_device(v, device) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        moved = [args_to_device(item, device) for item in value]
        # Preserve the container kind: lists stay lists, tuples stay tuples.
        return moved if isinstance(value, list) else tuple(moved)
    return value
184+
185+
186+
def find_device(test):
    """Return the device of the first tensor found inside *test*, or None.

    Recursively searches tensors nested in lists, tuples, and dict values.
    Returns None when no tensor is present anywhere in the structure.

    Bug fix: the previous version returned the result for the FIRST element
    of a list/dict unconditionally (``for item in …: return find_device(item)``),
    so a tensor appearing after a non-tensor leading element was never found
    and the caller silently fell back to its default device. It also did not
    descend into tuples, unlike its counterpart ``args_to_device``.
    """
    if isinstance(test, torch.Tensor):
        return test.device
    if isinstance(test, (list, tuple)):
        for item in test:
            device = find_device(item)
            if device is not None:
                return device
    elif isinstance(test, dict):
        for item in test.values():
            device = find_device(item)
            if device is not None:
                return device
    # No tensor anywhere in the structure (or an unsupported leaf type).
    return None
196+
197+
198+
def test_to_device(test, device):
    """Move a test's ``args`` and ``kwargs`` tensors to *device* in place.

    Mutates *test* and returns it, so it can be used either for its side
    effect or in a mapping pipeline.
    """
    for attr in ("args", "kwargs"):
        setattr(test, attr, args_to_device(getattr(test, attr), device))
    return test
202+
203+
161204
class MultiprocessingEvaluator:
162205
def __init__(self, num_workers: int = 1):
163206
assert num_workers <= torch.cuda.device_count(), "performance will be suboptimal"
@@ -183,12 +226,26 @@ def submit_task(self, op, impl, correctness_tests, performance_tests) -> int:
183226
if not is_pickleable(impl):
184227
impl = _extract_spec_name_from_op(impl)
185228

229+
orig_device = None
230+
cpu_correctness_tests = []
231+
for test in correctness_tests:
232+
if orig_device is None:
233+
orig_device = find_device(test)
234+
cpu_correctness_tests.append(test_to_device(test, torch.device("cpu")))
235+
if orig_device is None:
236+
orig_device = torch.device("cuda")
237+
238+
cpu_performance_tests = []
239+
for test in performance_tests:
240+
cpu_performance_tests.append(test_to_device(test, torch.device("cpu")))
241+
186242
task = EvalTask(
187243
task_id=task_id,
188244
op=op,
189245
impl=impl,
190-
correctness_tests=list(correctness_tests),
191-
performance_tests=list(performance_tests),
246+
correctness_tests=cpu_correctness_tests,
247+
performance_tests=cpu_performance_tests,
248+
device=str(orig_device),
192249
)
193250

194251
self.task_queue.put(task)

0 commit comments

Comments
 (0)