Commit 2a3e1f0

adrien-berchet authored and arnaudon committed

Reduce SQL I/Os and can now pass args to the evaluation function

Change-Id: I825dd4d0920c8fcd0fef123420991d57ad49e899

1 parent 8762ada, commit 2a3e1f0

File tree: 7 files changed (+349, -116 lines)

bluepyparallel/evaluator.py

Lines changed: 84 additions, 61 deletions

@@ -3,7 +3,6 @@
 import sqlite3
 import sys
 import traceback
-from collections import defaultdict
 from functools import partial
 from pathlib import Path
 
@@ -15,47 +14,48 @@
 logger = logging.getLogger(__name__)
 
 
-def _try_evaluation(task, evaluation_function=None):
+def _try_evaluation(task, evaluation_function, db_filename, func_args, func_kwargs):
     """Encapsulate the evaluation function into a try/except and isolate to record exceptions."""
     task_id, task_args = task
     try:
-        result = evaluation_function(task_args)
-        exception = ""
+        result = evaluation_function(task_args, *func_args, **func_kwargs)
+        exception = None
     except Exception:  # pylint: disable=broad-except
         result = None
         exception = "".join(traceback.format_exception(*sys.exc_info()))
         logger.exception("Exception for ID=%s: %s", task_id, exception)
+
+    # Save the results into the DB
+    if db_filename is not None:
+        _write_to_sql(db_filename, task_id, result, exception)
     return task_id, result, exception
 
 
-def _create_database(df, new_columns, db_filename="db.sql"):
+def _create_database(df, db_filename="db.sql"):
     """Create a sqlite database from dataframe."""
-    df["exception"] = None
-    for new_column in new_columns:
-        df[new_column[0]] = new_column[1]
-        df["to_run_" + new_column[0]] = 1
     with sqlite3.connect(str(db_filename)) as db:
-        df.to_sql("df", db, if_exists="replace", index_label="index")
-    return df
+        df.to_sql("df", db, if_exists="replace", index_label="df_index")
 
 
 def _load_database_to_dataframe(db_filename="db.sql"):
     """Load an SQL database and construct the dataframe."""
     with sqlite3.connect(str(db_filename)) as db:
-        out = pd.read_sql("SELECT * FROM df", db, index_col="index")
-        return out
+        return pd.read_sql("SELECT * FROM df", db, index_col="df_index")
 
 
-def _write_to_sql(db_filename, task_id, results, new_columns, exception):
+def _write_to_sql(db_filename, task_id, results, exception):
     """Write row data to SQL."""
     with sqlite3.connect(str(db_filename)) as db:
-        for new_column in new_columns:
-            res = results[new_column[0]] if results is not None else None
-            db.execute(
-                "UPDATE df SET " + new_column[0] + "=?, "
-                "exception=?, to_run_" + new_column[0] + "=? WHERE `index`=?",
-                (res, exception, 0, task_id),
-            )
+        if results is not None:
+            keys, vals = zip(*results.items())
+            query_keys = ", ".join([f"{k}=?" for k in keys])
+        else:
+            query_keys = "exception=?"
+            vals = [exception]
+        db.execute(
+            "UPDATE df SET " + query_keys + " WHERE df_index=?",
+            list(vals) + [task_id],
+        )
 
 
 def evaluate(
@@ -65,88 +65,111 @@ def evaluate(
     resume=False,
     parallel_factory=None,
     db_filename=None,
+    func_args=None,
+    func_kwargs=None,
 ):
     """Evaluate and save results in a sqlite database on the fly and return dataframe.
 
     Args:
-        df (DataFrame): each row contains information for the computation
+        df (DataFrame): each row contains information for the computation.
         evaluation_function (function): function used to evaluate each row,
             should have a single argument as list-like containing values of the rows of df,
-            and return a dict with keys corresponding to the names in new_columns
+            and return a dict with keys corresponding to the names in new_columns.
         new_columns (list): list of names of new column and empty value to save evaluation results,
-            i.e.: [['result', 0.0], ['valid', False]]
+            i.e.: [['result', 0.0], ['valid', False]].
         resume (bool): if True, it will use only compute the empty rows of the database,
-            if False, it will ecrase or generate the database
-        parallel_factory (ParallelFactory): parallel factory instance
+            if False, it will ecrase or generate the database.
+        parallel_factory (ParallelFactory): parallel factory instance.
         db_filename (str): if a file path is given, SQL backend will be enabled and will use this
            path for the SQLite database. Should not be used when evaluations are numerous and
           fast, in order to avoid the overhead of communication with SQL database.
+        func_args (list): the arguments to pass to the evaluation_function.
+        func_kwargs (dict): the keyword arguments to pass to the evaluation_function.
 
     Return:
-        pandas.DataFrame: dataframe with new columns containing computed results
+        pandas.DataFrame: dataframe with new columns containing the computed results.
     """
+    # Initialize the parallel factory
     if isinstance(parallel_factory, str) or parallel_factory is None:
         parallel_factory = init_parallel_factory(parallel_factory)
 
-    task_ids = df.index
+    # Set default args
+    if func_args is None:
+        func_args = []
+
+    # Set default kwargs
+    if func_kwargs is None:
+        func_kwargs = {}
 
+    # Shallow copy the given DataFrame to add internal rows
+    to_evaluate = df.copy()
+    task_ids = to_evaluate.index
+
+    # Set default new columns
     if new_columns is None:
         new_columns = [["data", ""]]
 
+    # Setup internal and new columns
+    to_evaluate["exception"] = None
+    for new_column in new_columns:
+        to_evaluate[new_column[0]] = new_column[1]
+
+    # Create the database if required and get the task ids to run
     if db_filename is None:
         logger.info("Not using SQL backend to save iterations")
-        to_evaluate = df
     elif resume:
         logger.info("Load data from SQL database")
         if Path(db_filename).exists():
-            to_evaluate = _load_database_to_dataframe(db_filename=db_filename)
-            task_ids = task_ids.intersection(to_evaluate.index)
+            previous_results = _load_database_to_dataframe(db_filename=db_filename)
+            previous_idx = previous_results.index
+            bad_cols = [
+                col
+                for col in df.columns
+                if not to_evaluate.loc[previous_idx, col].equals(previous_results[col])
+            ]
+            if bad_cols:
+                raise ValueError(
+                    f"The following columns have different values from the DataBase: {bad_cols}"
+                )
+            to_evaluate.loc[previous_results.index] = previous_results.loc[previous_results.index]
+            task_ids = task_ids.difference(previous_results.index)
         else:
-            to_evaluate = _create_database(df, new_columns, db_filename=db_filename)
-
-        # Find tasks to run
-        should_run = (
-            to_evaluate.loc[task_ids, ["to_run_" + col[0] for col in new_columns]] == 1
-        ).any(axis=1)
-        task_ids = should_run.loc[should_run].index
+            _create_database(to_evaluate, db_filename=db_filename)
     else:
         logger.info("Create SQL database")
-        to_evaluate = _create_database(df, new_columns, db_filename=db_filename)
+        _create_database(to_evaluate, db_filename=db_filename)
 
+    # Log the number of tasks to run
     if len(task_ids) > 0:
         logger.info("%s rows to compute.", str(len(task_ids)))
     else:
         logger.warning("WARNING: No row to compute, something may be wrong")
-        return _load_database_to_dataframe(db_filename)
+        return to_evaluate
 
+    # Get the factory mapper
     mapper = parallel_factory.get_mapper()
 
-    eval_func = partial(_try_evaluation, evaluation_function=evaluation_function)
-    arg_list = to_evaluate.to_dict("index").items()
+    # Setup the function to apply to the data
+    eval_func = partial(
+        _try_evaluation,
+        evaluation_function=evaluation_function,
+        db_filename=db_filename,
+        func_args=func_args,
+        func_kwargs=func_kwargs,
+    )
 
-    if db_filename is None:
-        _results = defaultdict(dict)
+    # Split the data into rows
+    arg_list = list(to_evaluate.loc[task_ids].to_dict("index").items())
 
     try:
        for task_id, results, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)):
-            if db_filename is None:
-                for new_column, _ in new_columns:
-                    _results[new_column][task_id] = (
-                        results[new_column] if results is not None else None
-                    )
-            else:
-                _write_to_sql(
-                    db_filename,
-                    task_id,
-                    results,
-                    new_columns,
-                    exception,
-                )
+            # Save the results into the DataFrame
+            if results is not None:
+                to_evaluate.loc[task_id, results.keys()] = list(results.values())
+            elif exception is not None:
+                to_evaluate.loc[task_id, "exception"] = exception
     except (KeyboardInterrupt, SystemExit) as ex:
         # To save dataframe even if program is killed
         logger.warning("Stopping mapper loop. Reason: %r", ex)
 
-    if db_filename is None:
-        to_evaluate = pd.concat([to_evaluate, pd.DataFrame(_results)], axis=1)
-        return to_evaluate
-    return _load_database_to_dataframe(db_filename)
+    return to_evaluate
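
The main API change above is that evaluate now accepts func_args and func_kwargs, forwarded to the evaluation function as evaluation_function(row, *func_args, **func_kwargs), and that each finished row is written to the SQLite database inside _try_evaluation instead of in the main loop. A minimal sketch of a call against the new signature; the scaled_sum function, the toy data and the toy_results.sql path are illustrative, not part of the library, and the default factory is assumed to handle parallel_factory=None as the code above suggests:

import pandas as pd

from bluepyparallel import evaluate


def scaled_sum(row, scale, offset=0.0):
    """Toy evaluation function combining the row with the extra arguments."""
    return {"result": scale * row["data"] + offset}


df = pd.DataFrame({"data": [1.0, 2.0, 3.0]})

result_df = evaluate(
    df,
    scaled_sum,
    new_columns=[["result", 0.0]],  # name and default value of the output column
    func_args=[10.0],               # forwarded as scaled_sum(row, 10.0, ...)
    func_kwargs={"offset": 1.0},    # forwarded as scaled_sum(row, ..., offset=1.0)
    db_filename="toy_results.sql",  # optional: also checkpoints each finished row in SQLite
)
print(result_df["result"].tolist())  # expected: [11.0, 21.0, 31.0]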

bluepyparallel/parallel.py

Lines changed: 16 additions, 10 deletions

@@ -6,6 +6,7 @@
 from abc import abstractmethod
 from collections.abc import Iterator
 from functools import partial
+from multiprocessing.pool import Pool
 
 import numpy as np
 
@@ -33,8 +34,9 @@ class ParallelFactory:
 
     _BATCH_SIZE = "PARALLEL_BATCH_SIZE"
 
-    def __init__(self):
-        self.batch_size = int(os.getenv(self._BATCH_SIZE, "0")) or None
+    def __init__(self, *args, batch_size=None, **kwargs):  # pylint: disable=unused-argument
+        self.batch_size = batch_size or int(os.getenv(self._BATCH_SIZE, "0")) or None
+        self.nb_processes = 1
         L.info("Using %s=%s", self._BATCH_SIZE, self.batch_size)
 
     @abstractmethod
@@ -64,7 +66,7 @@ def _set_daemon(self, value):
     daemon = property(_get_daemon, _set_daemon)
 
 
-class NestedPool(multiprocessing.pool.Pool):  # pylint: disable=abstract-method
+class NestedPool(Pool):  # pylint: disable=abstract-method
     """Class that represents a MultiProcessing nested pool"""
 
     Process = NoDaemonProcess
@@ -78,7 +80,7 @@ def _with_batches(mapper, func, iterable, batch_size=None):
     if isinstance(iterable, Iterator):
         iterable = list(iterable)
     if batch_size is not None:
-        iterables = np.array_split(iterable, len(iterable) // batch_size)
+        iterables = np.array_split(iterable, len(iterable) // min(batch_size, len(iterable)))
     else:
         iterables = [iterable]
 
@@ -99,11 +101,12 @@ class MultiprocessingFactory(ParallelFactory):
 
     _CHUNKSIZE = "PARALLEL_CHUNKSIZE"
 
-    def __init__(self):
+    def __init__(self, *args, processes=None, **kwargs):
         """Initialize multiprocessing factory."""
 
         super().__init__()
-        self.pool = NestedPool()
+        self.pool = NestedPool(*args, **kwargs)
+        self.nb_processes = processes or os.cpu_count()
 
     def get_mapper(self):
         """Get a NestedPool."""
@@ -123,17 +126,19 @@ class IPyParallelFactory(ParallelFactory):
 
     _IPYTHON_PROFILE = "IPYTHON_PROFILE"
 
-    def __init__(self):
+    def __init__(self, *args, **kwargs):
         """Initialize the ipyparallel factory."""
 
         super().__init__()
         self.rc = None
+        self.nb_processes = 1
 
     def get_mapper(self):
         """Get an ipyparallel mapper using the profile name provided."""
         profile = os.getenv(self._IPYTHON_PROFILE, "DEFAULT_IPYTHON_PROFILE")
         L.debug("Using %s=%s", self._IPYTHON_PROFILE, profile)
         self.rc = ipyparallel.Client(profile=profile)
+        self.nb_processes = len(self.rc.ids)
         lview = self.rc.load_balanced_view()
 
         def _mapper(func, iterable):
@@ -154,7 +159,7 @@ class DaskFactory(ParallelFactory):
 
     _SCHEDULER_PATH = "PARALLEL_DASK_SCHEDULER_PATH"
 
-    def __init__(self):
+    def __init__(self, *args, **kwargs):
         """Initialize the dask factory."""
         dask_scheduler_path = os.getenv(self._SCHEDULER_PATH)
         if dask_scheduler_path:
@@ -166,6 +171,7 @@ def __init__(self):
             dask_mpi.initialize()
             L.info("Starting dask_mpi...")
             self.client = dask.distributed.Client()
+        self.nb_processes = len(self.client.scheduler_info()["workers"])
         super().__init__()
 
     def shutdown(self):
@@ -189,7 +195,7 @@ def _dask_mapper(func, iterable):
         return _mapper
 
 
-def init_parallel_factory(parallel_lib):
+def init_parallel_factory(parallel_lib, *args, **kwargs):
     """Return the desired instance of the parallel factory.
 
     The main factories are:
@@ -209,7 +215,7 @@ def init_parallel_factory(parallel_lib):
        parallel_factories["ipyparallel"] = IPyParallelFactory
 
     try:
-        parallel_factory = parallel_factories[parallel_lib]()
+        parallel_factory = parallel_factories[parallel_lib](*args, **kwargs)
     except KeyError:
        L.critical(
            "The %s factory is not available, maybe the required libraries are not properly "

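Because every factory constructor now takes *args/**kwargs and init_parallel_factory forwards them, backend options can be set when the factory is created, and each factory exposes an nb_processes attribute. A small sketch, assuming the multiprocessing backend is registered under the "multiprocessing" key; the worker count used here is purely illustrative:

from bluepyparallel import init_parallel_factory

if __name__ == "__main__":
    # `processes` is consumed by MultiprocessingFactory and stored in `nb_processes`;
    # remaining positional/keyword arguments are forwarded to the underlying NestedPool.
    factory = init_parallel_factory("multiprocessing", processes=4)
    print(factory.nb_processes)  # 4
    print(factory.batch_size)    # None unless PARALLEL_BATCH_SIZE is set

    # The mapper has the same (func, iterable) interface that evaluate() relies on.
    mapper = factory.get_mapper()
    print(list(mapper(abs, [-1, -2, 3])))  # [1, 2, 3]

    factory.shutdown()
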
examples/large_computation.py

Lines changed: 27 additions, 0 deletions

@@ -0,0 +1,27 @@
+import pandas as pd
+import sys
+import numpy as np
+import time
+from bluepyparallel import evaluate
+from bluepyparallel import init_parallel_factory
+
+
+def func(row):
+    """Trivial computation"""
+
+    time.sleep(1)
+
+    return {"out": row["data"] + 10}
+
+
+if __name__ == "__main__":
+    parallel_lib = sys.argv[1]
+    import bglibpy
+
+    parallel_factory = init_parallel_factory(parallel_lib)
+    print("using ", parallel_lib)
+    df = pd.DataFrame()
+    df["data"] = np.arange(1e6)
+    print(df)
+    df = evaluate(df, func, new_columns=[["out", 0]], parallel_factory=parallel_factory)
+    parallel_factory.shutdown()
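
The example above runs a slow per-row computation over a million rows with the backend chosen on the command line. Long runs like this are where the SQLite backend pays off: the resume path added in this commit loads the existing database, checks that the input columns still match (raising a ValueError otherwise), and only schedules rows whose index is not already stored. A hedged sketch of a drop-in variant of the evaluate call inside the main block above; the large_computation.sql filename is illustrative:

    # Same computation, but checkpointed in SQLite: rows whose index already exists in
    # the database are reused, and only indices missing from it are evaluated.
    df = evaluate(
        df,
        func,
        new_columns=[["out", 0]],
        parallel_factory=parallel_factory,
        db_filename="large_computation.sql",
        resume=True,
    )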

examples/run_large_dask.sh

Lines changed: 24 additions, 0 deletions

@@ -0,0 +1,24 @@
+#!/bin/bash -l
+#SBATCH --nodes=1           # Number of nodes
+#SBATCH --time=00:10:00     # Time limit
+#SBATCH --partition=prod
+#SBATCH --constraint=cpu
+#SBATCH --mem=0
+#SBATCH --cpus-per-task=1
+#SBATCH --account=proj82    # your project number
+#SBATCH --job-name=test_bpp
+set -e
+
+
+module purge
+module load unstable py-mpi4py
+module load unstable py-dask-mpi
+module load unstable py-bglibpy
+module load unstable neurodamus-neocortex
+
+deactivate
+. venv/bin/activate
+
+unset PMI_RANK
+
+srun python large_computation.py dask
