
Commit 4619d24

Improve performance
Change-Id: I5ab455421da440fc36c38611c920fd13bbe6a6c2
1 parent ccdca61 commit 4619d24

File tree

5 files changed (+104, -61 lines)

bluepyparallel/evaluator.py

Lines changed: 22 additions & 10 deletions
@@ -4,6 +4,7 @@
 import traceback
 from functools import partial
 
+import pandas as pd
 from tqdm import tqdm
 
 from bluepyparallel.database import DataBase
@@ -19,7 +20,7 @@ def _try_evaluation(task, evaluation_function, func_args, func_kwargs):
         result = evaluation_function(task_args, *func_args, **func_kwargs)
         exception = None
     except Exception:  # pylint: disable=broad-except
-        result = None
+        result = {}
         exception = "".join(traceback.format_exception(*sys.exc_info()))
         logger.exception("Exception for ID=%s: %s", task_id, exception)
     return task_id, result, exception
@@ -34,6 +35,7 @@ def evaluate(
     db_url=None,
     func_args=None,
     func_kwargs=None,
+    **mapper_kwargs,
 ):
     """Evaluate and save results in a sqlite database on the fly and return dataframe.
 
@@ -54,14 +56,15 @@
            SQL database.
        func_args (list): the arguments to pass to the evaluation_function.
        func_kwargs (dict): the keyword arguments to pass to the evaluation_function.
+       **mapper_kwargs: the keyword arguments are passed to the get_mapper() method of the
+           :class:`ParallelFactory` instance.
 
    Return:
        pandas.DataFrame: dataframe with new columns containing the computed results.
    """
    # Initialize the parallel factory
    if isinstance(parallel_factory, str) or parallel_factory is None:
        parallel_factory = init_parallel_factory(parallel_factory)
-
    # Set default args
    if func_args is None:
        func_args = []
@@ -74,6 +77,10 @@
    to_evaluate = df.copy()
    task_ids = to_evaluate.index
 
+   if "exception" in to_evaluate.columns:
+       logger.warning("The exception column is going to be replaced")
+       to_evaluate = to_evaluate.drop(columns=["exception"])
+
    # Set default new columns
    if new_columns is None:
        new_columns = [["data", ""]]
@@ -120,7 +127,7 @@
        return to_evaluate
 
    # Get the factory mapper
-   mapper = parallel_factory.get_mapper()
+   mapper = parallel_factory.get_mapper(**mapper_kwargs)
 
    # Setup the function to apply to the data
    eval_func = partial(
@@ -134,18 +141,23 @@
    arg_list = list(to_evaluate.loc[task_ids, df.columns].to_dict("index").items())
 
    try:
-       for task_id, results, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)):
+       res = []
+
+       # Collect the results
+       for task_id, result, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)):
+           res.append(dict({"df_index": task_id, "exception": exception}, **result))
+
            # Save the results into the DB
            if db is not None:
                db.write(
-                   task_id, results, exception, **to_evaluate.loc[task_id, df.columns].to_dict()
+                   task_id, result, exception, **to_evaluate.loc[task_id, df.columns].to_dict()
                )
 
-           # Save the results into the DataFrame
-           if results is not None:
-               to_evaluate.loc[task_id, results.keys()] = list(results.values())
-           elif exception is not None:
-               to_evaluate.loc[task_id, "exception"] = exception
+       # Gather the results to the output DataFrame
+       res_df = pd.DataFrame(res)
+       res_df.set_index("df_index", inplace=True)
+       to_evaluate.loc[res_df.index, res_df.columns] = res_df
+
    except (KeyboardInterrupt, SystemExit) as ex:
        # To save dataframe even if program is killed
        logger.warning("Stopping mapper loop. Reason: %r", ex)

bluepyparallel/parallel.py

Lines changed: 60 additions & 36 deletions
@@ -33,19 +33,47 @@ class ParallelFactory:
     """Abstract class that should be subclassed to provide parallel functions."""
 
     _BATCH_SIZE = "PARALLEL_BATCH_SIZE"
+    _CHUNK_SIZE = "PARALLEL_CHUNK_SIZE"
 
-    def __init__(self, *args, batch_size=None, **kwargs):  # pylint: disable=unused-argument
+    # pylint: disable=unused-argument
+    def __init__(self, batch_size=None, chunk_size=None, **kwargs):
         self.batch_size = batch_size or int(os.getenv(self._BATCH_SIZE, "0")) or None
-        self.nb_processes = 1
         L.info("Using %s=%s", self._BATCH_SIZE, self.batch_size)
 
+        self.chunk_size = batch_size or int(os.getenv(self._CHUNK_SIZE, "0")) or None
+        L.info("Using %s=%s", self._CHUNK_SIZE, self.chunk_size)
+
+        self.nb_processes = 1
+
     @abstractmethod
-    def get_mapper(self):
+    def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
         """Return a mapper function that can be used to execute functions in parallel."""
 
     def shutdown(self):
         """Can be used to cleanup."""
 
+    def _with_batches(self, mapper, func, iterable, batch_size=None):
+        """Wrapper on mapper function creating batches of iterable to give to mapper.
+
+        The batch_size is an int corresponding to the number of evaluation in each batch/
+        """
+        if isinstance(iterable, Iterator):
+            iterable = list(iterable)
+
+        batch_size = batch_size or self.batch_size
+        if batch_size is not None:
+            iterables = np.array_split(iterable, len(iterable) // min(batch_size, len(iterable)))
+        else:
+            iterables = [iterable]
+
+        for _iterable in iterables:
+            yield from mapper(func, _iterable)
+
+    def _chunksize_to_kwargs(self, chunk_size, kwargs, label="chunk_size"):
+        chunk_size = chunk_size or self.chunk_size
+        if chunk_size is not None:
+            kwargs[label] = chunk_size
+
 
 class NoDaemonProcess(multiprocessing.Process):
     """Class that represents a non-daemon process"""
@@ -72,26 +100,10 @@ class NestedPool(Pool):  # pylint: disable=abstract-method
     Process = NoDaemonProcess
 
 
-def _with_batches(mapper, func, iterable, batch_size=None):
-    """Wrapper on mapper function creating batches of iterable to give to mapper.
-
-    The batch_size is an int corresponding to the number of evaluation in each batch/
-    """
-    if isinstance(iterable, Iterator):
-        iterable = list(iterable)
-    if batch_size is not None:
-        iterables = np.array_split(iterable, len(iterable) // min(batch_size, len(iterable)))
-    else:
-        iterables = [iterable]
-
-    for _iterable in iterables:
-        yield from mapper(func, _iterable)
-
-
 class SerialFactory(ParallelFactory):
     """Factory that do not work in parallel."""
 
-    def get_mapper(self):
+    def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
         """Get a map."""
         return map
 
@@ -101,18 +113,24 @@ class MultiprocessingFactory(ParallelFactory):
 
     _CHUNKSIZE = "PARALLEL_CHUNKSIZE"
 
-    def __init__(self, *args, processes=None, **kwargs):
+    def __init__(self, processes=None, **kwargs):
         """Initialize multiprocessing factory."""
 
-        super().__init__()
-        self.pool = NestedPool(*args, **kwargs)
+        super().__init__(**kwargs)
+
+        self.pool = NestedPool(processes=processes)
         self.nb_processes = processes or os.cpu_count()
 
-    def get_mapper(self):
+    def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
         """Get a NestedPool."""
+        self._chunksize_to_kwargs(chunk_size, kwargs)
 
         def _mapper(func, iterable):
-            return _with_batches(self.pool.imap_unordered, func, iterable, self.batch_size)
+            return self._with_batches(
+                partial(self.pool.imap_unordered, **kwargs),
+                func,
+                iterable,
+            )
 
         return _mapper
 
@@ -126,24 +144,29 @@ class IPyParallelFactory(ParallelFactory):
 
     _IPYTHON_PROFILE = "IPYTHON_PROFILE"
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, **kwargs):
         """Initialize the ipyparallel factory."""
 
-        super().__init__()
+        super().__init__(**kwargs)
         self.rc = None
         self.nb_processes = 1
 
-    def get_mapper(self):
+    def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
         """Get an ipyparallel mapper using the profile name provided."""
-        profile = os.getenv(self._IPYTHON_PROFILE, "DEFAULT_IPYTHON_PROFILE")
+        profile = os.getenv(self._IPYTHON_PROFILE, None)
         L.debug("Using %s=%s", self._IPYTHON_PROFILE, profile)
         self.rc = ipyparallel.Client(profile=profile)
        self.nb_processes = len(self.rc.ids)
        lview = self.rc.load_balanced_view()
 
+       if "ordered" not in kwargs:
+           kwargs["ordered"] = False
+
+       self._chunksize_to_kwargs(chunk_size, kwargs)
+
        def _mapper(func, iterable):
-           return _with_batches(
-               partial(lview.imap, ordered=False), func, iterable, self.batch_size
+           return self._with_batches(
+               partial(lview.imap, **kwargs), func, iterable, batch_size=batch_size
            )
 
        return _mapper
@@ -159,7 +182,7 @@ class DaskFactory(ParallelFactory):
 
    _SCHEDULER_PATH = "PARALLEL_DASK_SCHEDULER_PATH"
 
-   def __init__(self, *args, **kwargs):
+   def __init__(self, **kwargs):
        """Initialize the dask factory."""
        dask_scheduler_path = os.getenv(self._SCHEDULER_PATH)
        if dask_scheduler_path:
@@ -172,7 +195,7 @@ def __init__(self, *args, **kwargs):
            L.info("Starting dask_mpi...")
            self.client = dask.distributed.Client()
        self.nb_processes = len(self.client.scheduler_info()["workers"])
-       super().__init__()
+       super().__init__(**kwargs)
 
    def shutdown(self):
        """Retire the workers on the scheduler."""
@@ -181,16 +204,17 @@ def shutdown(self):
        self.client.retire_workers()
        self.client = None
 
-   def get_mapper(self):
+   def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
        """Get a Dask mapper."""
+       self._chunksize_to_kwargs(chunk_size, kwargs, label="batch_size")
 
        def _mapper(func, iterable):
            def _dask_mapper(func, iterable):
-               futures = self.client.map(func, iterable)
+               futures = self.client.map(func, iterable, **kwargs)
                for _future, result in dask.distributed.as_completed(futures, with_results=True):
                    yield result
 
-           return _with_batches(_dask_mapper, func, iterable, self.batch_size)
+           return self._with_batches(_dask_mapper, func, iterable, batch_size=batch_size)
 
        return _mapper
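The _with_batches() helper that moved onto ParallelFactory above splits the task list into roughly equal batches before handing each one to the underlying mapper. A standalone sketch of that splitting logic, assuming numpy is available (the batched_map name is illustrative and not part of the library):

import numpy as np


def batched_map(mapper, func, iterable, batch_size=None):
    # Same idea as ParallelFactory._with_batches(): split the items into
    # len(items) // batch_size groups and map over each group in turn.
    items = list(iterable)
    if batch_size is not None:
        groups = np.array_split(items, len(items) // min(batch_size, len(items)))
    else:
        groups = [items]
    for group in groups:
        yield from mapper(func, group)


# With 10 items and batch_size=4, np.array_split creates 2 groups of 5 items.
print([len(group) for group in np.array_split(range(10), 10 // 4)])
print(list(batched_map(map, lambda x: int(x) * 2, range(10), batch_size=4)))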

examples/large_computation.py

Lines changed: 20 additions & 11 deletions
@@ -4,24 +4,33 @@
 import time
 from bluepyparallel import evaluate
 from bluepyparallel import init_parallel_factory
+from data_validation_framework.util import apply_to_df
 
 
 def func(row):
     """Trivial computation"""
-
-    time.sleep(1)
-
-    return {"out": row["data"] + 10}
+    if row["data"] in [1, 3]:
+        raise ValueError(f"The value {row['data']} is forbidden")
+    else:
+        return {"out": row["data"] + 10}
 
 
 if __name__ == "__main__":
-    parallel_lib = sys.argv[1]
-    import bglibpy
-
-    parallel_factory = init_parallel_factory(parallel_lib)
-    print("using ", parallel_lib)
+    parallel_lib = sys.argv[1] or None
+    batch_size = int(sys.argv[2]) if len(sys.argv) >= 3 else None
+    chunk_size = int(sys.argv[3]) if len(sys.argv) >= 4 else None
     df = pd.DataFrame()
     df["data"] = np.arange(1e6)
-    print(df)
-    df = evaluate(df, func, new_columns=[["out", 0]], parallel_factory=parallel_factory)
+
+    parallel_factory = init_parallel_factory(parallel_lib, batch_size=batch_size)
+    df = evaluate(
+        df,
+        func,
+        new_columns=[["out", 0]],
+        parallel_factory=parallel_factory,
+        chunksize=chunk_size,
+    )
     parallel_factory.shutdown()
+    print(df)
+    print(df.loc[1, "exception"])
+    print(df.loc[3, "exception"])
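With the evaluator change above, every row of the returned DataFrame now carries an exception column: None for rows that succeeded and a formatted traceback string for rows that raised (here the forbidden values 1 and 3). An illustrative follow-up, not part of the committed example, showing how the failed rows can be inspected in bulk:

# Rows whose evaluation raised have a non-null traceback in "exception".
failed = df[df["exception"].notnull()]
print(f"{len(failed)} rows failed")
print(failed[["data", "exception"]])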

examples/run_large_dask.sh

Lines changed: 2 additions & 3 deletions
@@ -16,9 +16,8 @@ module load unstable py-dask-mpi
 module load unstable py-bglibpy
 module load unstable neurodamus-neocortex
 
-deactivate
-. venv/bin/activate
+. ~/base/bin/activate
 
 unset PMI_RANK
 
-srun python large_computation.py dask
+srun python large_computation.py dask 100000 1000
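Note: per the argument parsing in the updated examples/large_computation.py, the two new positional arguments are read as batch_size (sys.argv[2]) and chunk_size (sys.argv[3]), so this run uses batches of 100000 tasks and a chunk size of 1000.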

tox.ini

Lines changed: 0 additions & 1 deletion
@@ -29,7 +29,6 @@ deps = bbp-nse-ci
 commands = do_release.py -p . check-version
 
 [testenv:lint]
-basepython=python3.6
 deps =
     {[base]testdeps}
     pycodestyle
