Skip to content

Commit eda54a7

Browse files
committed
Feat: Shuffle rows for speedup
1 parent d3d55a3 commit eda54a7

File tree

4 files changed

+31
-2
lines changed

4 files changed

+31
-2
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ max-locals=20
2222
# Maximum number of return / yield for function / method body
2323
max-returns=6
2424
# Maximum number of branch for function / method body
25-
max-branches=15
25+
max-branches=16
2626
# Maximum number of statements in function / method body
2727
max-statements=50
2828
# Maximum number of parents for a class (see R0901).

bluepyparallel/evaluator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Module to evaluate generic functions on rows of dataframe."""
2+
23
import logging
34
import sys
45
import traceback
@@ -147,6 +148,7 @@ def evaluate(
147148
db_url=None,
148149
func_args=None,
149150
func_kwargs=None,
151+
shuffle_rows=True,
150152
**mapper_kwargs,
151153
):
152154
"""Evaluate and save results in a sqlite database on the fly and return dataframe.
@@ -168,6 +170,7 @@ def evaluate(
168170
communication with the SQL database.
169171
func_args (list): the arguments to pass to the evaluation_function.
170172
func_kwargs (dict): the keyword arguments to pass to the evaluation_function.
173+
shuffle_rows (bool): if :obj:`True`, it will shuffle the rows before computing the results.
171174
**mapper_kwargs: the keyword arguments are passed to the get_mapper() method of the
172175
:class:`ParallelFactory` instance.
173176
@@ -192,6 +195,10 @@ def evaluate(
192195

193196
# Shallow copy the given DataFrame to add internal rows
194197
to_evaluate = df.copy()
198+
199+
if shuffle_rows:
200+
to_evaluate = to_evaluate.sample(frac=1)
201+
195202
task_ids = to_evaluate.index
196203

197204
# Set default new columns
@@ -249,4 +256,7 @@ def evaluate(
249256
)
250257
to_evaluate.loc[res_df.index, res_df.columns] = res_df
251258

259+
if shuffle_rows:
260+
return to_evaluate.loc[df.index]
261+
252262
return to_evaluate

tests/test_evaluator.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,24 @@ def test_evaluate(self, input_df, new_columns, expected_df, db_url, with_sql, pa
9797

9898
assert_frame_equal(result_df, expected_df, check_like=True)
9999

100+
@pytest.mark.parametrize("with_sql", [True, False])
101+
def test_evaluate_no_shuffle(
102+
self, input_df, new_columns, expected_df, db_url, with_sql, parallel_factory
103+
):
104+
"""Test evaluator on a trivial example."""
105+
result_df = evaluate(
106+
input_df,
107+
_evaluation_function,
108+
new_columns,
109+
parallel_factory=parallel_factory,
110+
db_url=db_url if with_sql else None,
111+
shuffle_rows=False,
112+
)
113+
if not with_sql:
114+
remove_sql_cols(expected_df)
115+
116+
assert_frame_equal(result_df, expected_df, check_like=True)
117+
100118
def test_evaluate_no_factory(self, input_df, new_columns, expected_df):
101119
"""Test evaluator with no given factory."""
102120
result_df = evaluate(
@@ -180,6 +198,7 @@ def test_evaluate_keyboard_interrupt(self, input_df, expected_df):
180198
result_df = evaluate(
181199
input_df,
182200
_interrupting_function,
201+
shuffle_rows=False,
183202
)
184203
remove_sql_cols(expected_df)
185204

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ commands_pre =
9595
commands_post =
9696

9797
[testenv:format]
98-
basepython = python3.8
98+
basepython = python3
9999
skip_install = true
100100
deps =
101101
codespell

0 commit comments

Comments
 (0)