Skip to content

Commit 790656b

Browse files
Add wrapper to pass args and kwargs to the function and improve test coverage
Change-Id: I18b548b75b20eb8dc963d23d1dcdc575faeb903a
1 parent 626484c commit 790656b

File tree

11 files changed

+427
-109
lines changed

11 files changed

+427
-109
lines changed

bluepyparallel/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
"""BluePyParallel functions."""
2-
from .evaluator import evaluate # noqa
3-
from .parallel import init_parallel_factory # noqa
2+
from bluepyparallel.evaluator import evaluate # noqa
3+
from bluepyparallel.parallel import init_parallel_factory # noqa
4+
from bluepyparallel.version import VERSION as __version__ # noqa

bluepyparallel/database.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Module"""
1+
"""Module used to provide a simple API to a database in which the results are stored."""
22
import re
33

44
import pandas as pd
@@ -9,12 +9,28 @@
99
from sqlalchemy import schema
1010
from sqlalchemy import select
1111
from sqlalchemy.engine.reflection import Inspector
12+
from sqlalchemy.exc import OperationalError
1213
from sqlalchemy_utils import create_database
1314
from sqlalchemy_utils import database_exists
1415

16+
try: # pragma: no cover
17+
import psycopg2
18+
19+
with_psycopg2 = True
20+
except ImportError:
21+
with_psycopg2 = False
22+
1523

1624
class DataBase:
17-
"""A database API using SQLAlchemy."""
25+
"""A simple API to manage the database in which the results are inserted using SQLAlchemy.
26+
27+
Args:
28+
url (str): The URL of the database following the RFC-1738 format (
29+
https://docs.sqlalchemy.org/en/14/core/engines.html#database-urls)
30+
create (bool): If set to True, the database will be automatically created by the
31+
constructor.
32+
args and kwargs: They will be passed to the :func:`sqlalchemy.create_engine` function.
33+
"""
1834

1935
index_col = "df_index"
2036
_url_pattern = r"[a-zA-Z0-9_\-\+]+://.*"
@@ -25,25 +41,36 @@ def __init__(self, url, *args, create=False, **kwargs):
2541

2642
self.engine = create_engine(url, *args, **kwargs)
2743

28-
if create and not database_exists(self.engine.url):
44+
if create and not self.db_exists():
2945
create_database(self.engine.url)
3046

31-
self.connection = self.engine.connect()
47+
self._connection = None
3248
self.metadata = None
3349
self.table = None
3450

3551
def __del__(self):
52+
"""Close the connection and the engine to the database."""
3653
self.connection.close()
54+
self.engine.dispose()
55+
56+
@property
57+
def connection(self):
58+
"""Get a connection to the database."""
59+
if self._connection is None:
60+
self._connection = self.engine.connect()
61+
return self._connection
3762

3863
def get_url(self):
64+
"""Get the URL of the database."""
3965
return self.engine.url
4066

4167
def create(self, df, table_name=None, schema_name=None):
68+
"""Create a table in the database in which the results will be written."""
4269
if table_name is None:
4370
table_name = "df"
4471
if schema_name is not None and schema_name not in self.connection.dialect.get_schema_names(
4572
self.connection
46-
):
73+
): # pragma: no cover
4774
self.connection.execute(schema.CreateSchema(schema_name))
4875
new_df = df.loc[[]]
4976
new_df.to_sql(
@@ -55,11 +82,25 @@ def create(self, df, table_name=None, schema_name=None):
5582
)
5683
self.reflect(table_name, schema_name)
5784

85+
def db_exists(self):
86+
"""Check that the server and the database exist."""
87+
if with_psycopg2: # pragma: no cover
88+
exceptions = (OperationalError, psycopg2.OperationalError)
89+
else:
90+
exceptions = (OperationalError,)
91+
92+
try:
93+
return database_exists(self.engine.url)
94+
except exceptions: # pragma: no cover
95+
return False
96+
5897
def exists(self, table_name, schema_name=None):
98+
"""Test that the table exists in the database."""
5999
inspector = Inspector.from_engine(self.engine)
60100
return table_name in inspector.get_table_names(schema=schema_name)
61101

62102
def reflect(self, table_name, schema_name=None):
103+
"""Reflect the table from the database."""
63104
self.metadata = MetaData()
64105
self.table = Table(
65106
table_name,
@@ -70,10 +111,12 @@ def reflect(self, table_name, schema_name=None):
70111
)
71112

72113
def load(self):
114+
"""Load the table data from the database."""
73115
query = select([self.table])
74116
return pd.read_sql(query, self.connection, index_col=self.index_col)
75117

76118
def write(self, row_id, result=None, exception=None, **input_values):
119+
"""Write a result entry or an exception into the table."""
77120
if result is not None:
78121
vals = result
79122
elif exception is not None:

bluepyparallel/evaluator.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def evaluate(
8484

8585
# Set default new columns
8686
if new_columns is None:
87-
new_columns = [["data", ""]]
87+
new_columns = []
8888

8989
# Setup internal and new columns
9090
to_evaluate["exception"] = None
@@ -141,10 +141,9 @@ def evaluate(
141141
# Split the data into rows
142142
arg_list = list(to_evaluate.loc[task_ids, df.columns].to_dict("index").items())
143143

144+
res = []
144145
try:
145-
res = []
146-
147-
# Collect the results
146+
# Compute and collect the results
148147
for task_id, result, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)):
149148
res.append(dict({"df_index": task_id, "exception": exception}, **result))
150149

@@ -154,13 +153,13 @@ def evaluate(
154153
task_id, result, exception, **to_evaluate.loc[task_id, df.columns].to_dict()
155154
)
156155

157-
# Gather the results to the output DataFrame
158-
res_df = pd.DataFrame(res)
159-
res_df.set_index("df_index", inplace=True)
160-
to_evaluate.loc[res_df.index, res_df.columns] = res_df
161-
162156
except (KeyboardInterrupt, SystemExit) as ex:
163157
# To save dataframe even if program is killed
164158
logger.warning("Stopping mapper loop. Reason: %r", ex)
165159

160+
# Gather the results to the output DataFrame
161+
res_df = pd.DataFrame(res)
162+
res_df.set_index("df_index", inplace=True)
163+
to_evaluate.loc[res_df.index, res_df.columns] = res_df
164+
166165
return to_evaluate

bluepyparallel/parallel.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,25 @@
1414
import dask_mpi
1515

1616
dask_available = True
17-
except ImportError:
17+
except ImportError: # pragma: no cover
1818
dask_available = False
1919

2020
try:
2121
import ipyparallel
2222

2323
ipyparallel_available = True
24-
except ImportError:
24+
except ImportError: # pragma: no cover
2525
ipyparallel_available = False
2626

2727

2828
L = logging.getLogger(__name__)
2929

3030

31+
def _func_wrapper(data, func, func_args, func_kwargs):
32+
"""Function wrapper used to pass args and kwargs."""
33+
return func(data, *func_args, **func_kwargs)
34+
35+
3136
class ParallelFactory:
3237
"""Abstract class that should be subclassed to provide parallel functions."""
3338

@@ -56,6 +61,10 @@ def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
5661
def shutdown(self):
5762
"""Can be used to cleanup."""
5863

64+
def mappable_func(self, func, *args, **kwargs):
65+
"""Can be used to add args and kwargs to a function before calling the mapper."""
66+
return partial(_func_wrapper, func=func, func_args=args, func_kwargs=kwargs)
67+
5968
def _with_batches(self, mapper, func, iterable, batch_size=None):
6069
"""Wrapper on mapper function creating batches of iterable to give to mapper.
6170
@@ -95,7 +104,7 @@ def __init__(self, group=None, target=None, name=None, args=(), kwargs={}):
95104

96105
def _get_daemon(self): # pylint: disable=no-self-use
97106
"""Get daemon flag"""
98-
return False
107+
return False # pragma: no cover
99108

100109
def _set_daemon(self, value):
101110
"""Set daemon flag"""
@@ -114,7 +123,12 @@ class SerialFactory(ParallelFactory):
114123

115124
def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
116125
"""Get a map."""
117-
return map
126+
127+
def _mapper(func, iterable, *func_args, **func_kwargs):
128+
func = self.mappable_func(func, *func_args, **func_kwargs)
129+
return self._with_batches(map, func, iterable)
130+
131+
return _mapper
118132

119133

120134
class MultiprocessingFactory(ParallelFactory):
@@ -134,7 +148,8 @@ def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
134148
"""Get a NestedPool."""
135149
self._chunksize_to_kwargs(chunk_size, kwargs, label="chunksize")
136150

137-
def _mapper(func, iterable):
151+
def _mapper(func, iterable, *func_args, **func_kwargs):
152+
func = self.mappable_func(func, *func_args, **func_kwargs)
138153
return self._with_batches(
139154
partial(self.pool.imap_unordered, **kwargs),
140155
func,
@@ -164,12 +179,13 @@ def __init__(self, batch_size=None, chunk_size=None, profile=None, **kwargs):
164179

165180
def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
166181
"""Get an ipyparallel mapper using the profile name provided."""
167-
if "ordered" not in kwargs:
182+
if "ordered" not in kwargs: # pragma: no cover
168183
kwargs["ordered"] = False
169184

170185
self._chunksize_to_kwargs(chunk_size, kwargs)
171186

172-
def _mapper(func, iterable):
187+
def _mapper(func, iterable, *func_args, **func_kwargs):
188+
func = self.mappable_func(func, *func_args, **func_kwargs)
173189
return self._with_batches(
174190
partial(self.lview.imap, **kwargs), func, iterable, batch_size=batch_size
175191
)
@@ -178,7 +194,7 @@ def _mapper(func, iterable):
178194

179195
def shutdown(self):
180196
"""Remove zmq."""
181-
if self.rc is not None:
197+
if self.rc is not None: # pragma: no cover
182198
self.rc.close()
183199

184200

@@ -193,11 +209,11 @@ def __init__(
193209
"""Initialize the dask factory."""
194210
dask_scheduler_path = scheduler_file or os.getenv(self._SCHEDULER_PATH)
195211
self.interactive = True
196-
if dask_scheduler_path:
212+
if dask_scheduler_path: # pragma: no cover
197213
L.info("Connecting dask_mpi with scheduler %s", dask_scheduler_path)
198-
if address:
214+
if address: # pragma: no cover
199215
L.info("Connecting dask_mpi with address %s", address)
200-
if not dask_scheduler_path and not address:
216+
if not dask_scheduler_path and not address: # pragma: no cover
201217
self.interactive = False
202218
dask_mpi.initialize()
203219
L.info("Starting dask_mpi...")
@@ -213,19 +229,20 @@ def shutdown(self):
213229
"""Close the scheduler and the cluster if it was created by the factory."""
214230
cluster = self.client.cluster
215231
self.client.close()
216-
if not self.interactive:
232+
if not self.interactive: # pragma: no cover
217233
cluster.close()
218234

219235
def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
220236
"""Get a Dask mapper."""
221237
self._chunksize_to_kwargs(chunk_size, kwargs, label="batch_size")
222238

223-
def _mapper(func, iterable):
224-
def _dask_mapper(func, iterable):
239+
def _mapper(func, iterable, *func_args, **func_kwargs):
240+
def _dask_mapper(func, iterable, **kwargs):
225241
futures = self.client.map(func, iterable, **kwargs)
226242
for _future, result in dask.distributed.as_completed(futures, with_results=True):
227243
yield result
228244

245+
func = self.mappable_func(func, *func_args, **func_kwargs)
229246
return self._with_batches(_dask_mapper, func, iterable, batch_size=batch_size)
230247

231248
return _mapper
@@ -245,9 +262,9 @@ def init_parallel_factory(parallel_lib, *args, **kwargs):
245262
None: SerialFactory,
246263
"multiprocessing": MultiprocessingFactory,
247264
}
248-
if dask_available:
265+
if dask_available: # pragma: no cover
249266
parallel_factories["dask"] = DaskFactory
250-
if ipyparallel_available:
267+
if ipyparallel_available: # pragma: no cover
251268
parallel_factories["ipyparallel"] = IPyParallelFactory
252269

253270
try:

bluepyparallel/version.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
"""Package version """
2-
3-
VERSION = "0.0.3.dev0"
4-
__version__ = VERSION
1+
"""Package version"""
2+
# pragma: no cover
3+
VERSION = "0.0.3.dev1"

doc/source/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ API Documentation
77
bluepyparallel
88
bluepyparallel.parallel
99
bluepyparallel.evaluator
10+
bluepyparallel.database

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"sphinx-bluebrain-theme",
2727
]
2828

29-
VERSION = imp.load_source("", "bluepyparallel/version.py").__version__
29+
VERSION = imp.load_source("", "bluepyparallel/version.py").VERSION
3030

3131
setup(
3232
name="BluePyParallel",

0 commit comments

Comments
 (0)