Skip to content

Commit e2b5edd

Browse files
ignacioalvarolopez
authored and committed
use spawn for training pool
1 parent 64f501f commit e2b5edd

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

deepaas/model/v2/wrapper.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -348,24 +348,27 @@ def _convert_old_args(self, args):
348348
return aux
349349

350350

351-
class NonDaemonPool(multiprocessing.pool.Pool):
352-
def Process(self, *args, **kwds):
353-
proc = super(NonDaemonPool, self).Process(*args, **kwds)
354-
355-
class NonDaemonProcess(proc.__class__):
356-
"""Monkey-patch process to ensure it is never daemonized"""
357-
358-
@property
359-
def daemon(self):
360-
return False
351+
class NonDaemonProcess(multiprocessing.context.SpawnProcess):
352+
"""Processes must use 'spawn' instead of 'fork' (which is the default
353+
in Linux) in order to work with CUDA [1] or TensorFlow [2].
354+
355+
[1] https://pytorch.org/docs/stable/notes/multiprocessing.html
356+
#cuda-in-multiprocessing
357+
[2] https://github.com/tensorflow/tensorflow/issues/5448
358+
#issuecomment-258934405
359+
"""
360+
@property
361+
def daemon(self):
362+
return False
361363

362-
@daemon.setter
363-
def daemon(self, val):
364-
pass
364+
@daemon.setter
365+
def daemon(self, value):
366+
pass
365367

366-
proc.__class__ = NonDaemonProcess
367368

368-
return proc
369+
class NonDaemonPool(multiprocessing.pool.Pool):
370+
# Based on https://stackoverflow.com/questions/6974695/
371+
Process = NonDaemonProcess
369372

370373

371374
class CancellablePool(object):
@@ -375,7 +378,7 @@ def __init__(self, max_workers=None):
375378
self._change = asyncio.Event()
376379

377380
def _new_pool(self):
378-
return NonDaemonPool(1)
381+
return NonDaemonPool(1, context=multiprocessing.get_context('spawn'))
379382

380383
async def apply(self, fn, *args):
381384
"""

0 commit comments

Comments
 (0)