11# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03a_parallel.ipynb (unless otherwise specified).
22
3- __all__ = ['threaded' , 'startthread' , 'set_num_threads' , 'check_parallel_num ' , 'ThreadPoolExecutor' ,
4- 'ProcessPoolExecutor ' , 'parallel ' , 'run_procs' , 'parallel_gen' ]
3+ __all__ = ['threaded' , 'startthread' , 'set_num_threads' , 'parallelable ' , 'ThreadPoolExecutor' , 'ProcessPoolExecutor ' ,
4+ 'parallel ' , 'add_one ' , 'run_procs' , 'parallel_gen' ]
55
66# Cell
77from .imports import *
@@ -58,12 +58,13 @@ def _call(lock, pause, n, g, item):
5858 return g (item )
5959
6060# Cell
61- def check_parallel_num (param_name , num_workers ):
62- if sys .platform == "win32" and IN_NOTEBOOK and num_workers > 0 :
61+ def parallelable (param_name , num_workers , f = None ):
62+ f_in_main = f == None or sys .modules [f .__module__ ].__name__ == "__main__"
63+ if sys .platform == "win32" and IN_NOTEBOOK and num_workers > 0 and f_in_main :
6364 print ("Due to IPython and Windows limitation, python multiprocessing isn't available now." )
64- print (f"So `{ param_name } ` is changed to 0 to avoid getting stuck" )
65- num_workers = 0
66- return num_workers
65+ print (f"So `{ param_name } ` has to be changed to 0 to avoid getting stuck" )
66+ return False
67+ return True
6768
6869# Cell
6970class ThreadPoolExecutor (concurrent .futures .ThreadPoolExecutor ):
@@ -88,13 +89,16 @@ class ProcessPoolExecutor(concurrent.futures.ProcessPoolExecutor):
8889 "Same as Python's ProcessPoolExecutor, except can pass `max_workers==0` for serial execution"
8990 def __init__ (self , max_workers = defaults .cpus , on_exc = print , pause = 0 , ** kwargs ):
9091 if max_workers is None : max_workers = defaults .cpus
91- max_workers = check_parallel_num ('max_workers' , max_workers )
9292 store_attr ()
9393 self .not_parallel = max_workers == 0
9494 if self .not_parallel : max_workers = 1
9595 super ().__init__ (max_workers , ** kwargs )
9696
9797 def map (self , f , items , * args , timeout = None , chunksize = 1 , ** kwargs ):
98+ if not parallelable ('max_workers' , self .max_workers , f ): self .max_workers = 0
99+ self .not_parallel = self .max_workers == 0
100+ if self .not_parallel : self .max_workers = 1
101+
98102 if self .not_parallel == False : self .lock = Manager ().Lock ()
99103 g = partial (f , * args , ** kwargs )
100104 if self .not_parallel : return map (g , items )
@@ -118,6 +122,13 @@ def parallel(f, items, *args, n_workers=defaults.cpus, total=None, progress=None
118122 r = progress_bar (r , total = total , leave = False )
119123 return L (r )
120124
125+ # Cell
126+ def add_one (x , a = 1 ):
127+ # this import is necessary for multiprocessing in notebook on windows
128+ import random
129+ time .sleep (random .random ()/ 80 )
130+ return x + a
131+
121132# Cell
122133def run_procs (f , f_done , args ):
123134 "Call `f` for each item in `args` in parallel, yielding `f_done`"
@@ -135,7 +146,7 @@ def _done_pg(queue, items): return (queue.get() for _ in items)
135146# Cell
136147def parallel_gen (cls , items , n_workers = defaults .cpus , ** kwargs ):
137148 "Instantiate `cls` in `n_workers` procs & call each on a subset of `items` in parallel."
138- n_workers = check_parallel_num ('n_workers' , n_workers )
149+ if not parallelable ('n_workers' , n_workers ): n_workers = 0
139150 if n_workers == 0 :
140151 yield from enumerate (list (cls (** kwargs )(items )))
141152 return
0 commit comments