
Commit 427930e

Improve factory interfaces
Change-Id: I012e2de0635b3d6701ac441520ae96aba6be8899
1 parent: 4619d24

8 files changed: +114 additions, −59 deletions


.gitignore

Lines changed: 1 addition & 0 deletions

@@ -81,6 +81,7 @@ coverage.xml
 *.cover
 .hypothesis/
 reports
+dask-worker-space
 
 # Translations
 *.mo

bluepyparallel/evaluator.py

Lines changed: 5 additions & 4 deletions

@@ -73,14 +73,15 @@ def evaluate(
     if func_kwargs is None:
         func_kwargs = {}
 
+    # Drop exception column if present
+    if "exception" in df.columns:
+        logger.warning("The 'exception' column is going to be replaced")
+        df = df.drop(columns=["exception"])
+
     # Shallow copy the given DataFrame to add internal rows
     to_evaluate = df.copy()
     task_ids = to_evaluate.index
 
-    if "exception" in to_evaluate.columns:
-        logger.warning("The exception column is going to be replaced")
-        to_evaluate = to_evaluate.drop(columns=["exception"])
-
     # Set default new columns
     if new_columns is None:
         new_columns = [["data", ""]]

bluepyparallel/parallel.py

Lines changed: 44 additions & 32 deletions

@@ -2,7 +2,6 @@
 import logging
 import multiprocessing
 import os
-import time
 from abc import abstractmethod
 from collections.abc import Iterator
 from functools import partial
@@ -36,14 +35,19 @@ class ParallelFactory:
     _CHUNK_SIZE = "PARALLEL_CHUNK_SIZE"
 
     # pylint: disable=unused-argument
-    def __init__(self, batch_size=None, chunk_size=None, **kwargs):
+    def __init__(self, batch_size=None, chunk_size=None):
         self.batch_size = batch_size or int(os.getenv(self._BATCH_SIZE, "0")) or None
         L.info("Using %s=%s", self._BATCH_SIZE, self.batch_size)
 
         self.chunk_size = batch_size or int(os.getenv(self._CHUNK_SIZE, "0")) or None
         L.info("Using %s=%s", self._CHUNK_SIZE, self.chunk_size)
 
-        self.nb_processes = 1
+        if not hasattr(self, "nb_processes"):
+            self.nb_processes = 1
+
+    def __del__(self):
+        """Call the shutdown method."""
+        self.shutdown()
 
     @abstractmethod
     def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
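Two points in this hunk: subclasses now set nb_processes before calling super().__init__(), so the base class only falls back to 1 when nothing was set, and the new __del__ delegates to shutdown() so resources are released when a factory is garbage-collected. A stripped-down sketch of the pattern (illustrative names, not the full classes):

    class _Base:
        def __init__(self):
            # Default only if a subclass has not set nb_processes already.
            if not hasattr(self, "nb_processes"):
                self.nb_processes = 1

        def shutdown(self):
            """Release pools/clients; concrete factories override this."""

        def __del__(self):
            self.shutdown()

    class _Sub(_Base):
        def __init__(self, processes=4):
            self.nb_processes = processes  # set first, then init the base
            super().__init__()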
@@ -62,7 +66,12 @@ def _with_batches(self, mapper, func, iterable, batch_size=None):
 
         batch_size = batch_size or self.batch_size
         if batch_size is not None:
-            iterables = np.array_split(iterable, len(iterable) // min(batch_size, len(iterable)))
+            iterables = [
+                _iterable.tolist()
+                for _iterable in np.array_split(
+                    iterable, len(iterable) // min(batch_size, len(iterable))
+                )
+            ]
         else:
             iterables = [iterable]
 
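The rewritten batching converts each NumPy chunk back into a plain list before it reaches the mapper, since np.array_split yields ndarray batches that the backend mappers do not necessarily treat like the original iterable. A quick illustration of the difference:

    import numpy as np

    items = list(range(5))
    np.array_split(items, 2)
    # -> [array([0, 1, 2]), array([3, 4])]   (ndarray batches)
    [batch.tolist() for batch in np.array_split(items, 2)]
    # -> [[0, 1, 2], [3, 4]]   (plain lists, as _with_batches now yields)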
@@ -113,17 +122,17 @@ class MultiprocessingFactory(ParallelFactory):
 
     _CHUNKSIZE = "PARALLEL_CHUNKSIZE"
 
-    def __init__(self, processes=None, **kwargs):
+    def __init__(self, batch_size=None, chunk_size=None, processes=None, **kwargs):
         """Initialize multiprocessing factory."""
 
-        super().__init__(**kwargs)
+        super().__init__(batch_size, chunk_size)
 
-        self.pool = NestedPool(processes=processes)
         self.nb_processes = processes or os.cpu_count()
+        self.pool = NestedPool(processes=self.nb_processes, **kwargs)
 
     def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
         """Get a NestedPool."""
-        self._chunksize_to_kwargs(chunk_size, kwargs)
+        self._chunksize_to_kwargs(chunk_size, kwargs, label="chunksize")
 
         def _mapper(func, iterable):
             return self._with_batches(
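With this reordering, nb_processes is fixed before the pool exists and any extra keyword arguments flow into NestedPool. A hedged usage sketch (it assumes shutdown() is safe to call explicitly, as the new __del__ does):

    from bluepyparallel import init_parallel_factory

    def double(x):
        return x * 2

    if __name__ == "__main__":  # required on spawn-based platforms
        factory = init_parallel_factory("multiprocessing", processes=2, batch_size=4)
        mapper = factory.get_mapper()
        print(sorted(mapper(double, range(8))))  # [0, 2, 4, 6, 8, 10, 12, 14]
        factory.shutdown()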
@@ -144,29 +153,25 @@ class IPyParallelFactory(ParallelFactory):
 
     _IPYTHON_PROFILE = "IPYTHON_PROFILE"
 
-    def __init__(self, **kwargs):
+    def __init__(self, batch_size=None, chunk_size=None, profile=None, **kwargs):
         """Initialize the ipyparallel factory."""
-
-        super().__init__(**kwargs)
-        self.rc = None
-        self.nb_processes = 1
-
-    def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
-        """Get an ipyparallel mapper using the profile name provided."""
-        profile = os.getenv(self._IPYTHON_PROFILE, None)
+        profile = profile or os.getenv(self._IPYTHON_PROFILE, None)
         L.debug("Using %s=%s", self._IPYTHON_PROFILE, profile)
-        self.rc = ipyparallel.Client(profile=profile)
+        self.rc = ipyparallel.Client(profile=profile, **kwargs)
         self.nb_processes = len(self.rc.ids)
-        lview = self.rc.load_balanced_view()
+        self.lview = self.rc.load_balanced_view()
+        super().__init__(batch_size, chunk_size)
 
+    def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
+        """Get an ipyparallel mapper using the profile name provided."""
         if "ordered" not in kwargs:
             kwargs["ordered"] = False
 
         self._chunksize_to_kwargs(chunk_size, kwargs)
 
         def _mapper(func, iterable):
             return self._with_batches(
-                partial(lview.imap, **kwargs), func, iterable, batch_size=batch_size
+                partial(self.lview.imap, **kwargs), func, iterable, batch_size=batch_size
             )
 
         return _mapper
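The client, engine count, and load-balanced view are now created once in __init__ (with lview kept on self), so get_mapper can be called repeatedly against the same client. A sketch, assuming an ipcluster is already running for the given profile:

    from bluepyparallel import init_parallel_factory

    # 'profile' can now be passed directly; previously only the
    # IPYTHON_PROFILE environment variable was consulted.
    factory = init_parallel_factory("ipyparallel", profile="default")
    print(factory.nb_processes)  # number of connected engines
    mapper = factory.get_mapper()  # imap with ordered=False by default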
@@ -182,27 +187,34 @@ class DaskFactory(ParallelFactory):
 
     _SCHEDULER_PATH = "PARALLEL_DASK_SCHEDULER_PATH"
 
-    def __init__(self, **kwargs):
+    def __init__(
+        self, batch_size=None, chunk_size=None, scheduler_file=None, address=None, **kwargs
+    ):
         """Initialize the dask factory."""
-        dask_scheduler_path = os.getenv(self._SCHEDULER_PATH)
+        dask_scheduler_path = scheduler_file or os.getenv(self._SCHEDULER_PATH)
+        self.interactive = True
         if dask_scheduler_path:
-            self.interactive = True
             L.info("Connecting dask_mpi with scheduler %s", dask_scheduler_path)
-            self.client = dask.distributed.Client(scheduler_file=dask_scheduler_path)
-        else:
+        if address:
+            L.info("Connecting dask_mpi with address %s", address)
+        if not dask_scheduler_path and not address:
             self.interactive = False
-            dask_mpi.initialize()
             L.info("Starting dask_mpi...")
-            self.client = dask.distributed.Client()
+            dask_mpi.initialize()
+        self.client = dask.distributed.Client(
+            address=address,
+            scheduler_file=dask_scheduler_path,
+            **kwargs,
+        )
         self.nb_processes = len(self.client.scheduler_info()["workers"])
-        super().__init__(**kwargs)
+        super().__init__(batch_size, chunk_size)
 
     def shutdown(self):
-        """Retire the workers on the scheduler."""
+        """Close the scheduler and the cluster if it was created by the factory."""
+        cluster = self.client.cluster
+        self.client.close()
         if not self.interactive:
-            time.sleep(1)
-            self.client.retire_workers()
-            self.client = None
+            cluster.close()
 
     def get_mapper(self, batch_size=None, chunk_size=None, **kwargs):
         """Get a Dask mapper."""

examples/large_computation.py

Lines changed: 3 additions & 2 deletions

@@ -4,11 +4,12 @@
 import time
 from bluepyparallel import evaluate
 from bluepyparallel import init_parallel_factory
-from data_validation_framework.util import apply_to_df
 
 
 def func(row):
     """Trivial computation"""
+    time.sleep(5)
+
     if row["data"] in [1, 3]:
         raise ValueError(f"The value {row['data']} is forbidden")
     else:
@@ -20,7 +21,7 @@ def func(row):
     batch_size = int(sys.argv[2]) if len(sys.argv) >= 3 else None
     chunk_size = int(sys.argv[3]) if len(sys.argv) >= 4 else None
     df = pd.DataFrame()
-    df["data"] = np.arange(1e6)
+    df["data"] = np.arange(200)
 
     parallel_factory = init_parallel_factory(parallel_lib, batch_size=batch_size)
     df = evaluate(
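Shrinking the dataset from 1e6 rows to 200 and adding time.sleep(5) turns the example into a visibly parallel workload that still finishes quickly. A hypothetical invocation (argument order inferred from the sys.argv handling above):

    # python examples/large_computation.py multiprocessing 10 2
    #   sys.argv[1] -> parallel_lib ("multiprocessing")
    #   sys.argv[2] -> batch_size  (10)
    #   sys.argv[3] -> chunk_size  (2)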

tests/__init__.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 39 additions & 0 deletions

@@ -1,7 +1,46 @@
 """Prepare the tests."""
+# pylint: disable=redefined-outer-name
+import copy
+import os
+
+import dask.distributed
 import pytest
 
+from bluepyparallel import init_parallel_factory
+
 
 @pytest.fixture
 def db_url(tmpdir):
     return tmpdir / "db.sql"
+
+
+@pytest.fixture(params=[None, "multiprocessing", "ipyparallel", "dask"])
+def factory_type(request):
+    return request.param
+
+
+@pytest.fixture(scope="session")
+def dask_cluster():
+    cluster = dask.distributed.LocalCluster()
+    yield cluster
+    cluster.close()
+
+
+@pytest.fixture(
+    params=[
+        {},
+        {"chunk_size": 2},
+        {"batch_size": 2},
+        {"chunk_size": 2, "batch_size": 2},
+        {"chunk_size": 999, "batch_size": 999},
+    ]
+)
+def parallel_factory(factory_type, dask_cluster, request):
+    factory_kwargs = copy.deepcopy(request.param)
+    if factory_type == "dask":
+        factory_kwargs["address"] = dask_cluster
+    elif factory_type == "ipyparallel":
+        tox_name = os.environ.get("TOX_ENV_NAME")
+        if tox_name:
+            factory_kwargs["cluster_id"] = f"bluepyparallel_{tox_name}"
+    return init_parallel_factory(factory_type, **factory_kwargs)
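The parallel_factory fixture crosses the four backends with five batch/chunk configurations, so every test consuming it runs in 20 variants; the dask variant reuses the session-scoped LocalCluster, and the ipyparallel variant targets the tox-managed cluster through cluster_id. A minimal sketch of a test using the fixture (hypothetical test; sorted because the ipyparallel mapper defaults to ordered=False):

    def _square(x):
        return x * x

    def test_square(parallel_factory):
        mapper = parallel_factory.get_mapper()
        assert sorted(mapper(_square, [1, 2, 3])) == [1, 4, 9]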

tests/test_evaluator.py

Lines changed: 7 additions & 19 deletions

@@ -61,11 +61,8 @@ class TestEvaluate:
     """Test the bluepyparallel.evaluator.evaluate function."""
 
     @pytest.mark.parametrize("with_sql", [True, False])
-    @pytest.mark.parametrize("factory_type", [None, "multiprocessing"])
-    def test_evaluate(self, input_df, new_columns, expected_df, db_url, with_sql, factory_type):
+    def test_evaluate(self, input_df, new_columns, expected_df, db_url, with_sql, parallel_factory):
         """Test evaluator on a trivial example."""
-        parallel_factory = init_parallel_factory(factory_type)
-
         result_df = evaluate(
             input_df,
             _evaluation_function,
@@ -88,7 +85,6 @@ def test_evaluate(self, input_df, new_columns, expected_df, db_url, with_sql, factory_type):
         ],
     )
     @pytest.mark.parametrize("with_sql", [True, False])
-    @pytest.mark.parametrize("factory_type", [None, "multiprocessing"])
     def test_evaluate_args_kwargs(
         self,
         input_df,
@@ -97,10 +93,9 @@ def test_evaluate_args_kwargs(
         db_url,
         func_args_kwargs,
         with_sql,
-        factory_type,
+        parallel_factory,
     ):
         """Test evaluator on a trivial example with passing args or kwargs."""
-        parallel_factory = init_parallel_factory(factory_type)
         args, kwargs = deepcopy(func_args_kwargs)
 
         result_df = evaluate(
@@ -124,11 +119,8 @@ def test_evaluate_args_kwargs(
 
         assert_frame_equal(result_df, expected_df, check_like=True)
 
-    @pytest.mark.parametrize("factory_type", [None, "multiprocessing"])
-    def test_evaluate_resume(self, input_df, new_columns, expected_df, db_url, factory_type):
+    def test_evaluate_resume(self, input_df, new_columns, expected_df, db_url, parallel_factory):
         """Test evaluator on a trivial example."""
-        parallel_factory = init_parallel_factory(factory_type)
-
         # Compute some values
         tmp_df = evaluate(
             input_df.loc[[0, 2]],
@@ -193,11 +185,10 @@ def test_evaluate_resume_bad_cols(self, input_df, new_columns, db_url):
             db_url=db_url,
         )
 
-    @pytest.mark.parametrize("factory_type", [None, "multiprocessing"])
-    def test_evaluate_overwrite_db(self, input_df, new_columns, expected_df, db_url, factory_type):
+    def test_evaluate_overwrite_db(
+        self, input_df, new_columns, expected_df, db_url, parallel_factory
+    ):
         """Test evaluator on a trivial example."""
-        parallel_factory = init_parallel_factory(factory_type)
-
         # Compute once
         previous_df = input_df.copy(deep=True)
         previous_df["name"] += "_previous"
@@ -228,7 +219,6 @@ class TestBenchmark:
     @pytest.mark.parametrize("df_size", ["small", "big"])
     @pytest.mark.parametrize("function_type", ["fast", "slow"])
    @pytest.mark.parametrize("with_sql", [True, False])
-    @pytest.mark.parametrize("factory_type", [None, "multiprocessing"])
     def test_evaluate(
         self,
         input_df,
@@ -238,12 +228,10 @@ def test_evaluate(
         df_size,
         function_type,
         with_sql,
-        factory_type,
+        parallel_factory,
         benchmark,
     ):
         """Test evaluator on a trivial example."""
-        parallel_factory = init_parallel_factory(factory_type, processes=None)
-
         if df_size == "big":
             input_df = input_df.loc[np.repeat(input_df.index.values, 50)].reset_index(drop=True)
             expected_df = expected_df.loc[np.repeat(expected_df.index.values, 50)].reset_index(

tox.ini

Lines changed: 15 additions & 2 deletions

@@ -6,7 +6,6 @@ testdeps =
     pytest-benchmark
     pytest-cov
     pytest-html
-    pytest-xdist
 
 [tox]
 envlist =
@@ -19,7 +18,21 @@ minversion = 3.1.0
 
 [testenv]
 deps = {[base]testdeps}
-commands = pytest -n 2 --basetemp={envtmpdir} --cov={envsitepackagesdir}/{[base]name} --cov-branch --no-cov-on-fail --html reports/pytest-{envname}.html --self-contained-html --benchmark-skip {posargs}
+commands_pre =
+    - ipcluster stop --cluster-id={[base]name}_{envname} --debug
+    ipcluster start -n 2 --daemonize --log-to-file --cluster-id={[base]name}_{envname} --debug
+commands =
+    pytest \
+        --basetemp={envtmpdir} \
+        --cov={[base]name} \
+        --cov-branch \
+        --no-cov-on-fail \
+        --html reports/pytest-{envname}.html \
+        --self-contained-html \
+        --benchmark-skip \
+        {posargs}
+commands_post =
+    - ipcluster stop --cluster-id={[base]name}_{envname} --debug
 
 [testenv:check-version]
 skip_install = true
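commands_pre/commands_post now manage a dedicated ipcluster per tox environment, replacing the pytest-xdist run; the leading "-" tells tox to ignore failures when stopping a cluster that is not running. The tests reach this cluster through the cluster_id built in conftest.py. A sketch of the equivalent manual connection (the py39 env name is hypothetical, and this assumes an ipyparallel version whose Client accepts cluster_id):

    import ipyparallel

    # Matches --cluster-id={[base]name}_{envname}, e.g. with TOX_ENV_NAME=py39:
    rc = ipyparallel.Client(cluster_id="bluepyparallel_py39")
    print(len(rc.ids))  # engines started by 'ipcluster start -n 2'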
