Skip to content

Commit 1ea79a3

Browse files
committed
parallel in win notebook
1 parent 875988a commit 1ea79a3

File tree

4 files changed

+177
-71
lines changed

4 files changed

+177
-71
lines changed

fastcore/_nbdev.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -182,10 +182,11 @@
182182
"threaded": "03a_parallel.ipynb",
183183
"startthread": "03a_parallel.ipynb",
184184
"set_num_threads": "03a_parallel.ipynb",
185-
"check_parallel_num": "03a_parallel.ipynb",
185+
"parallelable": "03a_parallel.ipynb",
186186
"ThreadPoolExecutor": "03a_parallel.ipynb",
187187
"ProcessPoolExecutor": "03a_parallel.ipynb",
188188
"parallel": "03a_parallel.ipynb",
189+
"add_one": "03a_parallel.ipynb",
189190
"run_procs": "03a_parallel.ipynb",
190191
"parallel_gen": "03a_parallel.ipynb",
191192
"url_default_headers": "03b_net.ipynb",

fastcore/parallel.py

Lines changed: 20 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/03a_parallel.ipynb (unless otherwise specified).
22

3-
__all__ = ['threaded', 'startthread', 'set_num_threads', 'check_parallel_num', 'ThreadPoolExecutor',
4-
'ProcessPoolExecutor', 'parallel', 'run_procs', 'parallel_gen']
3+
__all__ = ['threaded', 'startthread', 'set_num_threads', 'parallelable', 'ThreadPoolExecutor', 'ProcessPoolExecutor',
4+
'parallel', 'add_one', 'run_procs', 'parallel_gen']
55

66
# Cell
77
from .imports import *
@@ -58,12 +58,13 @@ def _call(lock, pause, n, g, item):
5858
return g(item)
5959

6060
# Cell
61-
def check_parallel_num(param_name, num_workers):
62-
if sys.platform == "win32" and IN_NOTEBOOK and num_workers > 0:
61+
def parallelable(param_name, num_workers, f=None):
62+
f_in_main = f == None or sys.modules[f.__module__].__name__ == "__main__"
63+
if sys.platform == "win32" and IN_NOTEBOOK and num_workers > 0 and f_in_main:
6364
print("Due to IPython and Windows limitation, python multiprocessing isn't available now.")
64-
print(f"So `{param_name}` is changed to 0 to avoid getting stuck")
65-
num_workers = 0
66-
return num_workers
65+
print(f"So `{param_name}` has to be changed to 0 to avoid getting stuck")
66+
return False
67+
return True
6768

6869
# Cell
6970
class ThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
@@ -88,13 +89,16 @@ class ProcessPoolExecutor(concurrent.futures.ProcessPoolExecutor):
8889
"Same as Python's ProcessPoolExecutor, except can pass `max_workers==0` for serial execution"
8990
def __init__(self, max_workers=defaults.cpus, on_exc=print, pause=0, **kwargs):
9091
if max_workers is None: max_workers=defaults.cpus
91-
max_workers = check_parallel_num('max_workers', max_workers)
9292
store_attr()
9393
self.not_parallel = max_workers==0
9494
if self.not_parallel: max_workers=1
9595
super().__init__(max_workers, **kwargs)
9696

9797
def map(self, f, items, *args, timeout=None, chunksize=1, **kwargs):
98+
if not parallelable('max_workers', self.max_workers, f): self.max_workers = 0
99+
self.not_parallel = self.max_workers==0
100+
if self.not_parallel: self.max_workers=1
101+
98102
if self.not_parallel == False: self.lock = Manager().Lock()
99103
g = partial(f, *args, **kwargs)
100104
if self.not_parallel: return map(g, items)
@@ -118,6 +122,13 @@ def parallel(f, items, *args, n_workers=defaults.cpus, total=None, progress=None
118122
r = progress_bar(r, total=total, leave=False)
119123
return L(r)
120124

125+
# Cell
126+
def add_one(x, a=1):
127+
# this import is necessary for multiprocessing in notebook on windows
128+
import random
129+
time.sleep(random.random()/80)
130+
return x+a
131+
121132
# Cell
122133
def run_procs(f, f_done, args):
123134
"Call `f` for each item in `args` in parallel, yielding `f_done`"
@@ -135,7 +146,7 @@ def _done_pg(queue, items): return (queue.get() for _ in items)
135146
# Cell
136147
def parallel_gen(cls, items, n_workers=defaults.cpus, **kwargs):
137148
"Instantiate `cls` in `n_workers` procs & call each on a subset of `items` in parallel."
138-
n_workers = check_parallel_num('n_workers', n_workers)
149+
if not parallelable('n_workers', n_workers): n_workers = 0
139150
if n_workers==0:
140151
yield from enumerate(list(cls(**kwargs)(items)))
141152
return

0 commit comments

Comments
 (0)