@@ -348,24 +348,27 @@ def _convert_old_args(self, args):
348
348
return aux
349
349
350
350
351
- class NonDaemonPool (multiprocessing .pool .Pool ):
352
- def Process (self , * args , ** kwds ):
353
- proc = super (NonDaemonPool , self ).Process (* args , ** kwds )
354
-
355
- class NonDaemonProcess (proc .__class__ ):
356
- """Monkey-patch process to ensure it is never daemonized"""
357
-
358
- @property
359
- def daemon (self ):
360
- return False
351
+ class NonDaemonProcess (multiprocessing .context .SpawnProcess ):
352
+ """Processes must use 'spawn' instead of 'fork' (which is the default
353
+ in Linux) in order to work CUDA [1] or Tensorflow [2].
354
+
355
+ [1] https://pytorch.org/docs/stable/notes/multiprocessing.html
356
+ #cuda-in-multiprocessing
357
+ [2] https://github.com/tensorflow/tensorflow/issues/5448
358
+ #issuecomment-258934405
359
+ """
360
+ @property
361
+ def daemon (self ):
362
+ return False
361
363
362
- @daemon .setter
363
- def daemon (self , val ):
364
- pass
364
+ @daemon .setter
365
+ def daemon (self , value ):
366
+ pass
365
367
366
- proc .__class__ = NonDaemonProcess
367
368
368
- return proc
369
+ class NonDaemonPool (multiprocessing .pool .Pool ):
370
+ # Based on https://stackoverflow.com/questions/6974695/
371
+ Process = NonDaemonProcess
369
372
370
373
371
374
class CancellablePool (object ):
@@ -375,7 +378,7 @@ def __init__(self, max_workers=None):
375
378
self ._change = asyncio .Event ()
376
379
377
380
def _new_pool (self ):
378
- return NonDaemonPool (1 )
381
+ return NonDaemonPool (1 , context = multiprocessing . get_context ( 'spawn' ) )
379
382
380
383
async def apply (self , fn , * args ):
381
384
"""
0 commit comments