
Commit df4b447

arnaudon authored and adrien-berchet committed
dask_dataframe
Change-Id: Ib26384396ff8a803ab69b99804b7dfc4307eb0be
1 parent 38e45d2 commit df4b447

File tree

10 files changed: +295 −107 lines changed


bluepyparallel/evaluator.py

Lines changed: 138 additions & 54 deletions
@@ -8,6 +8,7 @@
 from tqdm import tqdm
 
 from bluepyparallel.database import DataBase
+from bluepyparallel.parallel import DaskDataFrameFactory
 from bluepyparallel.parallel import init_parallel_factory
 
 logger = logging.getLogger(__name__)
@@ -16,16 +17,124 @@
 def _try_evaluation(task, evaluation_function, func_args, func_kwargs):
     """Encapsulate the evaluation function into a try/except and isolate to record exceptions."""
     task_id, task_args = task
+
     try:
         result = evaluation_function(task_args, *func_args, **func_kwargs)
         exception = None
     except Exception:  # pylint: disable=broad-except
         result = {}
         exception = "".join(traceback.format_exception(*sys.exc_info()))
         logger.exception("Exception for ID=%s: %s", task_id, exception)
+
     return task_id, result, exception
 
 
+def _try_evaluation_df(task, evaluation_function, func_args, func_kwargs):
+    """Evaluate one row of a dask.dataframe and return the result as a pandas Series."""
+    task_id, result, exception = _try_evaluation(
+        (task.name, task),
+        evaluation_function,
+        func_args,
+        func_kwargs,
+    )
+    res_cols = list(result.keys())
+    result["exception"] = exception
+    return pd.Series(result, name=task_id, dtype="object", index=["exception"] + res_cols)
+
+
+def _evaluate_dataframe(
+    to_evaluate, df, evaluation_function, func_args, func_kwargs, new_columns, mapper, task_ids, db
+):
+    """Internal evaluation function for dask.dataframe."""
+    # Setup the function to apply to the data
+    eval_func = partial(
+        _try_evaluation_df,
+        evaluation_function=evaluation_function,
+        func_args=func_args,
+        func_kwargs=func_kwargs,
+    )
+    meta = pd.DataFrame({col[0]: pd.Series(dtype="object") for col in new_columns})
+
+    res = []
+    try:
+        # Compute and collect the results
+        for batch in mapper(eval_func, to_evaluate.loc[task_ids, df.columns], meta=meta):
+            res.append(batch)
+
+            if db is not None:
+                # pylint: disable=cell-var-from-loop
+                batch_complete = to_evaluate[df.columns].join(batch, how="right")
+                batch_cols = [col for col in batch_complete.columns if col != "exception"]
+                batch_complete.apply(
+                    lambda row: db.write(row.name, row[batch_cols].to_dict(), row["exception"]),
+                    axis=1,
+                )
+    except (KeyboardInterrupt, SystemExit) as ex:  # pragma: no cover
+        # To save the dataframe even if the program is killed
+        logger.warning("Stopping mapper loop. Reason: %r", ex)
+    return pd.concat(res)
+
+
+def _evaluate_basic(
+    to_evaluate, df, evaluation_function, func_args, func_kwargs, mapper, task_ids, db
+):
+    """Internal evaluation function for all other parallel factories."""
+    res = []
+    # Setup the function to apply to the data
+    eval_func = partial(
+        _try_evaluation,
+        evaluation_function=evaluation_function,
+        func_args=func_args,
+        func_kwargs=func_kwargs,
+    )
+
+    # Split the data into rows
+    arg_list = list(to_evaluate.loc[task_ids, df.columns].to_dict("index").items())
+
+    try:
+        # Compute and collect the results
+        for task_id, result, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)):
+            res.append(dict({"df_index": task_id, "exception": exception}, **result))
+
+            # Save the results into the DB
+            if db is not None:
+                db.write(
+                    task_id, result, exception, **to_evaluate.loc[task_id, df.columns].to_dict()
+                )
+    except (KeyboardInterrupt, SystemExit) as ex:
+        # To save the dataframe even if the program is killed
+        logger.warning("Stopping mapper loop. Reason: %r", ex)
+
+    # Gather the results into the output DataFrame
+    return pd.DataFrame(res).set_index("df_index")
+
+
+def _prepare_db(db_url, to_evaluate, df, resume, task_ids):
+    """Prepare the DB and return it with its URL and the task ids left to evaluate."""
+    db = DataBase(db_url)
+
+    if resume and db.exists("df"):
+        logger.info("Load data from SQL database")
+        db.reflect("df")
+        previous_results = db.load()
+        previous_idx = previous_results.index
+        bad_cols = [
+            col
+            for col in df.columns
+            if not to_evaluate.loc[previous_idx, col].equals(previous_results[col])
+        ]
+        if bad_cols:
+            raise ValueError(
+                f"The following columns have different values from the DataBase: {bad_cols}"
+            )
+        to_evaluate.loc[previous_results.index] = previous_results.loc[previous_results.index]
+        task_ids = task_ids.difference(previous_results.index)
+    else:
+        logger.info("Create SQL database")
+        db.create(to_evaluate)
+
+    return db, db.get_url(), task_ids
+
+
 def evaluate(
     df,
     evaluation_function,
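
Note on the new dask.dataframe path above: the factory's mapper applies `_try_evaluation_df` row by row, so it is expected to behave like a row-wise apply whose `meta` argument pre-declares the output columns (which is why `new_columns` becomes mandatory for `DaskDataFrameFactory`, see the next hunk). A minimal sketch of that per-row contract in plain pandas; the `add_cols` evaluation function is made up for illustration, and the real mapper comes from `DaskDataFrameFactory`, which is not shown in this diff:

    import pandas as pd

    def add_cols(row):
        # Hypothetical evaluation function: one row in, dict of new columns out.
        return {"total": row["a"] + row["b"]}

    def try_row(task):
        # Mirrors _try_evaluation_df: return one pd.Series per row, named after
        # the row index, with "exception" first so failed rows still fit the meta.
        try:
            result, exception = add_cols(task), None
        except Exception as exc:  # pylint: disable=broad-except
            result, exception = {}, repr(exc)
        res_cols = list(result.keys())
        result["exception"] = exception
        return pd.Series(result, name=task.name, dtype="object", index=["exception"] + res_cols)

    df = pd.DataFrame({"a": [1, 2], "b": [10, 20]})
    print(df.apply(try_row, axis=1))  # columns: exception, total
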
@@ -84,10 +193,14 @@ def evaluate(
 
     # Set default new columns
     if new_columns is None:
+        if isinstance(parallel_factory, DaskDataFrameFactory):
+            raise ValueError("The new columns must be provided when using 'DaskDataFrameFactory'")
         new_columns = []
 
     # Setup internal and new columns
-    to_evaluate["exception"] = None
+    if any(col[0] == "exception" for col in new_columns):
+        raise ValueError("The 'exception' column can not be one of the new columns")
+    new_columns = [["exception", None]] + new_columns  # Don't use append to keep the input as is.
     for new_column in new_columns:
         to_evaluate[new_column[0]] = new_column[1]
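
The list concatenation in the added lines is deliberate: `new_columns.append(...)` would mutate the list the caller passed in, while `[...] + new_columns` rebinds a fresh list, as the inline comment says. A small standalone illustration (the column names are made up):

    user_columns = [["score", 0.0]]

    # append mutates the caller's list in place:
    cols = user_columns
    cols.append(["exception", None])
    print(user_columns)  # [['score', 0.0], ['exception', None]] -- the input changed

    # concatenation builds a new list and keeps the input as is:
    user_columns = [["score", 0.0]]
    cols = [["exception", None]] + user_columns
    print(user_columns)  # [['score', 0.0]] -- unchanged
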

@@ -96,29 +209,7 @@ def evaluate(
         logger.info("Not using SQL backend to save iterations")
         db = None
     else:
-        db = DataBase(db_url)
-
-        if resume and db.exists("df"):
-            logger.info("Load data from SQL database")
-            db.reflect("df")
-            previous_results = db.load()
-            previous_idx = previous_results.index
-            bad_cols = [
-                col
-                for col in df.columns
-                if not to_evaluate.loc[previous_idx, col].equals(previous_results[col])
-            ]
-            if bad_cols:
-                raise ValueError(
-                    f"The following columns have different values from the DataBase: {bad_cols}"
-                )
-            to_evaluate.loc[previous_results.index] = previous_results.loc[previous_results.index]
-            task_ids = task_ids.difference(previous_results.index)
-        else:
-            logger.info("Create SQL database")
-            db.create(to_evaluate)
-
-        db_url = db.get_url()
+        db, db_url, task_ids = _prepare_db(db_url, to_evaluate, df, resume, task_ids)
 
     # Log the number of tasks to run
     if len(task_ids) > 0:
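
With a `db_url`, `_prepare_db` now owns the resume logic: it either creates the table or reloads previous results, checks that the input columns still match the DataBase, and drops the already-computed ids from `task_ids`. A hedged usage sketch of the resulting behaviour; the sqlite path, the `score` column, and `evaluate`'s keyword arguments are assumptions inferred from the names used in this diff:

    import pandas as pd

    from bluepyparallel.evaluator import evaluate

    def evaluation_function(row):
        # Hypothetical user function: one row in, dict of new column values out.
        return {"score": row["value"] * 2}

    df = pd.DataFrame({"value": [1.0, 2.0, 3.0]})

    # First run: each finished row is written to the SQL backend as it completes.
    evaluate(df, evaluation_function, new_columns=[["score", None]],
             db_url="sqlite:///eval.db")  # hypothetical path

    # After an interruption, resume=True reloads the finished rows from eval.db
    # and only evaluates the task ids that are not in the DataBase yet.
    evaluate(df, evaluation_function, new_columns=[["score", None]],
             db_url="sqlite:///eval.db", resume=True)
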
@@ -130,36 +221,29 @@ def evaluate(
     # Get the factory mapper
     mapper = parallel_factory.get_mapper(**mapper_kwargs)
 
-    # Setup the function to apply to the data
-    eval_func = partial(
-        _try_evaluation,
-        evaluation_function=evaluation_function,
-        func_args=func_args,
-        func_kwargs=func_kwargs,
-    )
-
-    # Split the data into rows
-    arg_list = list(to_evaluate.loc[task_ids, df.columns].to_dict("index").items())
-
-    res = []
-    try:
-        # Compute and collect the results
-        for task_id, result, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)):
-            res.append(dict({"df_index": task_id, "exception": exception}, **result))
-
-            # Save the results into the DB
-            if db is not None:
-                db.write(
-                    task_id, result, exception, **to_evaluate.loc[task_id, df.columns].to_dict()
-                )
-
-    except (KeyboardInterrupt, SystemExit) as ex:
-        # To save dataframe even if program is killed
-        logger.warning("Stopping mapper loop. Reason: %r", ex)
-
-    # Gather the results to the output DataFrame
-    res_df = pd.DataFrame(res)
-    res_df.set_index("df_index", inplace=True)
+    if isinstance(parallel_factory, DaskDataFrameFactory):
+        res_df = _evaluate_dataframe(
+            to_evaluate,
+            df,
+            evaluation_function,
+            func_args,
+            func_kwargs,
+            new_columns,
+            mapper,
+            task_ids,
+            db,
+        )
+    else:
+        res_df = _evaluate_basic(
+            to_evaluate,
+            df,
+            evaluation_function,
+            func_args,
+            func_kwargs,
+            mapper,
+            task_ids,
+            db,
+        )
     to_evaluate.loc[res_df.index, res_df.columns] = res_df
 
     return to_evaluate
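
After this hunk, `evaluate` dispatches on the factory type: a `DaskDataFrameFactory` goes through `_evaluate_dataframe` and therefore needs `new_columns` up front to build the dask meta, while every other factory goes through `_evaluate_basic`. A hedged sketch of the two call shapes, reusing `df` and `evaluation_function` from the sketch above; the factory constructor arguments and the "multiprocessing" factory name are assumptions, not shown in this diff:

    from bluepyparallel.parallel import DaskDataFrameFactory, init_parallel_factory

    # dask.dataframe path: new_columns is mandatory, otherwise evaluate raises ValueError.
    factory = DaskDataFrameFactory()  # constructor arguments assumed
    res = evaluate(df, evaluation_function,
                   new_columns=[["score", None]], parallel_factory=factory)

    # Basic path: rows go through the factory's mapper with a tqdm progress bar,
    # and new_columns may be omitted (it defaults to []).
    factory = init_parallel_factory("multiprocessing")  # factory name assumed
    res = evaluate(df, evaluation_function, parallel_factory=factory)
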
