Commit d4a6767
Improve DB inserts for dask_dataframe factory
Change-Id: I7d696a45df6818d14a77dc4ae0d0db0fa844f88b
1 parent 30e3fd0 commit d4a6767

6 files changed: +58 −19 lines changed

CHANGELOG.rst

Lines changed: 11 additions & 1 deletion

@@ -1,10 +1,20 @@
 Changelog
 =========
 
+Version 0.0.6
+-------------
+
+- Improve DB inserts for dask.dataframe factory
+
+Version 0.0.5
+-------------
+
+- Add support for dask.dataframe
+
 Version 0.0.4
 -------------
 
-- Added DaskDataframe factory
+- Update doc, README and author
 
 Version 0.0.3
 -------------

bluepyparallel/database.py

Lines changed: 22 additions & 0 deletions

@@ -4,6 +4,7 @@
 import pandas as pd
 from sqlalchemy import MetaData
 from sqlalchemy import Table
+from sqlalchemy import bindparam
 from sqlalchemy import create_engine
 from sqlalchemy import insert
 from sqlalchemy import schema
@@ -15,6 +16,7 @@
 
 try:  # pragma: no cover
     import psycopg2
+    import psycopg2.extras
 
     with_psycopg2 = True
 except ImportError:
@@ -126,3 +128,23 @@ def write(self, row_id, result=None, exception=None, **input_values):
 
         query = insert(self.table).values(dict(**{self.index_col: row_id}, **vals, **input_values))
         self.connection.execute(query)
+
+    def write_batch(self, columns, data):
+        """Write entries from a list of lists into the table."""
+        if not data:  # pragma: no cover
+            return
+        assert len(columns) + 1 == len(
+            data[0]
+        ), "The columns list must have one less entry than each data element"
+        cursor = self.connection.connection.cursor()
+        cols = {col: bindparam(col) for col in [self.index_col] + columns}
+        # pylint: disable=no-value-for-parameter
+        compiled = self.table.insert().values(**cols).compile(dialect=self.engine.dialect)
+
+        if hasattr(cursor, "mogrify") and with_psycopg2:  # pragma: no cover
+            psycopg2.extras.execute_values(cursor, str(compiled), data)
+        else:
+            cursor.executemany(str(compiled), data)
+
+        self.connection.connection.commit()
+        self.connection.connection.close()
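
For context on what the new write_batch does: it compiles a single INSERT statement once, then sends every row in one call, via psycopg2's execute_values fast path when the cursor supports it, or plain DBAPI executemany otherwise. Below is a minimal sketch of the same idea against sqlite3; the table and column names (demo.db, results, value, exception) are illustrative assumptions, not taken from the commit:

# Batched-insert sketch mirroring write_batch's fallback branch.
# sqlite3 cursors have no "mogrify", so executemany is used; under
# psycopg2, execute_values would push the whole batch in one statement.
import sqlite3

conn = sqlite3.connect("demo.db")
conn.execute("CREATE TABLE IF NOT EXISTS results (id INTEGER, value REAL, exception TEXT)")

columns = ["value", "exception"]
# One leading index value per row, matching write_batch's assert:
# len(columns) + 1 == len(data[0])
data = [(0, 1.5, None), (1, 2.5, None)]

cur = conn.cursor()
cur.executemany("INSERT INTO results (id, value, exception) VALUES (?, ?, ?)", data)
conn.commit()
conn.close()

The win over the previous per-row db.write calls is one round-trip per batch instead of one INSERT per row.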

bluepyparallel/evaluator.py

Lines changed: 20 additions & 16 deletions

@@ -42,7 +42,15 @@ def _try_evaluation_df(task, evaluation_function, func_args, func_kwargs):
 
 
 def _evaluate_dataframe(
-    to_evaluate, df, evaluation_function, func_args, func_kwargs, new_columns, mapper, task_ids, db
+    to_evaluate,
+    input_cols,
+    evaluation_function,
+    func_args,
+    func_kwargs,
+    new_columns,
+    mapper,
+    task_ids,
+    db,
 ):
     """Internal evalution function for dask.dataframe."""
     # Setup the function to apply to the data
@@ -57,25 +65,21 @@ def _evaluate_dataframe(
     res = []
     try:
         # Compute and collect the results
-        for batch in mapper(eval_func, to_evaluate.loc[task_ids, df.columns], meta=meta):
+        for batch in mapper(eval_func, to_evaluate.loc[task_ids, input_cols], meta=meta):
             res.append(batch)
 
             if db is not None:
-                # pylint: disable=cell-var-from-loop
-                batch_complete = to_evaluate[df.columns].join(batch, how="right")
-                batch_cols = [col for col in batch_complete.columns if col != "exception"]
-                batch_complete.apply(
-                    lambda row: db.write(row.name, row[batch_cols].to_dict(), row["exception"]),
-                    axis=1,
-                )
+                batch_complete = to_evaluate[input_cols].join(batch, how="right")
+                data = batch_complete.to_records().tolist()
+                db.write_batch(batch_complete.columns.tolist(), data)
     except (KeyboardInterrupt, SystemExit) as ex:  # pragma: no cover
         # To save dataframe even if program is killed
         logger.warning("Stopping mapper loop. Reason: %r", ex)
     return pd.concat(res)
 
 
 def _evaluate_basic(
-    to_evaluate, df, evaluation_function, func_args, func_kwargs, mapper, task_ids, db
+    to_evaluate, input_cols, evaluation_function, func_args, func_kwargs, mapper, task_ids, db
 ):
 
     res = []
@@ -88,7 +92,7 @@ def _evaluate_basic(
     )
 
     # Split the data into rows
-    arg_list = list(to_evaluate.loc[task_ids, df.columns].to_dict("index").items())
+    arg_list = list(to_evaluate.loc[task_ids, input_cols].to_dict("index").items())
 
     try:
         # Compute and collect the results
@@ -98,7 +102,7 @@ def _evaluate_basic(
             # Save the results into the DB
             if db is not None:
                 db.write(
-                    task_id, result, exception, **to_evaluate.loc[task_id, df.columns].to_dict()
+                    task_id, result, exception, **to_evaluate.loc[task_id, input_cols].to_dict()
                 )
     except (KeyboardInterrupt, SystemExit) as ex:
         # To save dataframe even if program is killed
@@ -132,7 +136,7 @@ def _prepare_db(db_url, to_evaluate, df, resume, task_ids):
     logger.info("Create SQL database")
     db.create(to_evaluate)
 
-    return db, db.get_url()
+    return db, db.get_url(), task_ids
 
 
 def evaluate(
@@ -209,7 +213,7 @@ def evaluate(
         logger.info("Not using SQL backend to save iterations")
         db = None
     else:
-        db, db_url = _prepare_db(db_url, to_evaluate, df, resume, task_ids)
+        db, db_url, task_ids = _prepare_db(db_url, to_evaluate, df, resume, task_ids)
 
     # Log the number of tasks to run
     if len(task_ids) > 0:
@@ -224,7 +228,7 @@ def evaluate(
     if isinstance(parallel_factory, DaskDataFrameFactory):
         res_df = _evaluate_dataframe(
             to_evaluate,
-            df,
+            df.columns,
             evaluation_function,
             func_args,
             func_kwargs,
@@ -236,7 +240,7 @@ def evaluate(
     else:
         res_df = _evaluate_basic(
             to_evaluate,
-            df,
+            df.columns,
            evaluation_function,
            func_args,
            func_kwargs,
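
The payload handed to write_batch is built with to_records(), which prepends the index to every row. A small sketch with a toy DataFrame, using illustrative column names rather than anything from the commit:

# Sketch of the batch payload built in _evaluate_dataframe.
import pandas as pd

batch_complete = pd.DataFrame(
    {"value": [1.5, 2.5], "exception": [None, None]},
    index=pd.Index([0, 1], name="index"),
)

# to_records() includes the index, so each tuple has one more entry
# than there are columns -- the exact shape write_batch asserts on.
data = batch_complete.to_records().tolist()
print(data)                             # [(0, 1.5, None), (1, 2.5, None)]
print(batch_complete.columns.tolist())  # ['value', 'exception']

This replaces the old DataFrame.apply loop, which issued one db.write call per row.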

bluepyparallel/parallel.py

Lines changed: 3 additions & 1 deletion

@@ -90,7 +90,9 @@ def _with_batches(self, mapper, func, iterable, batch_size=None):
         else:
             iterables = [iterable]
 
-        for _iterable in iterables:
+        for i, _iterable in enumerate(iterables):
+            if len(iterables) > 1:
+                L.info("Computing batch %s / %s", i + 1, len(iterables))
             yield from mapper(func, _iterable)
 
     def _chunksize_to_kwargs(self, chunk_size, kwargs, label="chunk_size"):
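
The parallel.py change only adds a progress log per batch. A self-contained sketch of the loop follows; the slicing into batches is an assumption (the diff shows only the else branch and the loop), and the logger setup is a stand-in for the module's own:

# Standalone sketch of the batch loop with the new progress logging.
import logging

logging.basicConfig(level=logging.INFO)
L = logging.getLogger(__name__)

def with_batches(mapper, func, iterable, batch_size=None):
    if batch_size:
        # Assumed splitting; the real method's batching code sits above the hunk.
        iterables = [iterable[i:i + batch_size] for i in range(0, len(iterable), batch_size)]
    else:
        iterables = [iterable]
    for i, _iterable in enumerate(iterables):
        if len(iterables) > 1:
            L.info("Computing batch %s / %s", i + 1, len(iterables))
        yield from mapper(func, _iterable)

print(list(with_batches(map, lambda x: x * 2, list(range(5)), batch_size=2)))
# Logs "Computing batch 1 / 3" etc. and prints [0, 2, 4, 6, 8]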

bluepyparallel/version.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 """Package version"""
 # pragma: no cover
-VERSION = "0.0.5"
+VERSION = "0.0.6.dev0"

setup.py

Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@
 ]
 
 doc_reqs = [
+    "sphinx<4",
     "sphinx-bluebrain-theme",
 ]
