 from tqdm import tqdm
 
 from bluepyparallel.database import DataBase
+from bluepyparallel.parallel import DaskDataFrameFactory
 from bluepyparallel.parallel import init_parallel_factory
 
 logger = logging.getLogger(__name__)
@@ -16,16 +17,124 @@
 def _try_evaluation(task, evaluation_function, func_args, func_kwargs):
     """Encapsulate the evaluation function into a try/except and isolate to record exceptions."""
     task_id, task_args = task
+
     try:
         result = evaluation_function(task_args, *func_args, **func_kwargs)
         exception = None
     except Exception:  # pylint: disable=broad-except
         result = {}
         exception = "".join(traceback.format_exception(*sys.exc_info()))
         logger.exception("Exception for ID=%s: %s", task_id, exception)
+
     return task_id, result, exception
 
 
+def _try_evaluation_df(task, evaluation_function, func_args, func_kwargs):
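+    """Wrap _try_evaluation for a DataFrame row and return the result as a pandas Series."""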
+    task_id, result, exception = _try_evaluation(
+        (task.name, task),
+        evaluation_function,
+        func_args,
+        func_kwargs,
+    )
+    res_cols = list(result.keys())
+    result["exception"] = exception
+    return pd.Series(result, name=task_id, dtype="object", index=["exception"] + res_cols)
+
+
+def _evaluate_dataframe(
+    to_evaluate, df, evaluation_function, func_args, func_kwargs, new_columns, mapper, task_ids, db
+):
47+ """Internal evalution function for dask.dataframe."""
+    # Setup the function to apply to the data
+    eval_func = partial(
+        _try_evaluation_df,
+        evaluation_function=evaluation_function,
+        func_args=func_args,
+        func_kwargs=func_kwargs,
+    )
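+    # The "meta" DataFrame gives dask the expected output schema (column names and dtypes)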
+    meta = pd.DataFrame({col[0]: pd.Series(dtype="object") for col in new_columns})
+
+    res = []
+    try:
+        # Compute and collect the results
+        for batch in mapper(eval_func, to_evaluate.loc[task_ids, df.columns], meta=meta):
+            res.append(batch)
+
+            if db is not None:
+                # pylint: disable=cell-var-from-loop
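+                # Re-attach the input columns to this batch of results before writing to the DB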
+                batch_complete = to_evaluate[df.columns].join(batch, how="right")
+                batch_cols = [col for col in batch_complete.columns if col != "exception"]
+                batch_complete.apply(
+                    lambda row: db.write(row.name, row[batch_cols].to_dict(), row["exception"]),
+                    axis=1,
+                )
+    except (KeyboardInterrupt, SystemExit) as ex:  # pragma: no cover
+        # To save dataframe even if program is killed
+        logger.warning("Stopping mapper loop. Reason: %r", ex)
+    return pd.concat(res)
+
+
+def _evaluate_basic(
+    to_evaluate, df, evaluation_function, func_args, func_kwargs, mapper, task_ids, db
+):
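+    """Internal evaluation function for generic row-based mappers."""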
+
+    res = []
+    # Setup the function to apply to the data
+    eval_func = partial(
+        _try_evaluation,
+        evaluation_function=evaluation_function,
+        func_args=func_args,
+        func_kwargs=func_kwargs,
+    )
+
+    # Split the data into rows
+    arg_list = list(to_evaluate.loc[task_ids, df.columns].to_dict("index").items())
+
+    try:
+        # Compute and collect the results
+        for task_id, result, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)):
+            res.append(dict({"df_index": task_id, "exception": exception}, **result))
+
+            # Save the results into the DB
+            if db is not None:
+                db.write(
+                    task_id, result, exception, **to_evaluate.loc[task_id, df.columns].to_dict()
+                )
+    except (KeyboardInterrupt, SystemExit) as ex:
+        # To save dataframe even if program is killed
+        logger.warning("Stopping mapper loop. Reason: %r", ex)
+
+    # Gather the results to the output DataFrame
+    return pd.DataFrame(res).set_index("df_index")
+
+
+def _prepare_db(db_url, to_evaluate, df, resume, task_ids):
112+ """ "Prepare db."""
+    db = DataBase(db_url)
+
+    if resume and db.exists("df"):
+        logger.info("Load data from SQL database")
+        db.reflect("df")
+        previous_results = db.load()
+        previous_idx = previous_results.index
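+        # Check that the input columns match the ones already stored in the DB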
+        bad_cols = [
+            col
+            for col in df.columns
+            if not to_evaluate.loc[previous_idx, col].equals(previous_results[col])
+        ]
+        if bad_cols:
+            raise ValueError(
+                f"The following columns have different values from the DataBase: {bad_cols}"
+            )
+        to_evaluate.loc[previous_results.index] = previous_results.loc[previous_results.index]
+        task_ids = task_ids.difference(previous_results.index)
+    else:
+        logger.info("Create SQL database")
+        db.create(to_evaluate)
+
+    # Return the reduced task_ids so already-computed rows are skipped when resuming
+    return db, db.get_url(), task_ids
+
+
 def evaluate(
     df,
     evaluation_function,
@@ -84,10 +193,14 @@ def evaluate(
 
     # Set default new columns
     if new_columns is None:
+        if isinstance(parallel_factory, DaskDataFrameFactory):
+            raise ValueError("The new columns must be provided when using 'DaskDataFrameFactory'")
         new_columns = []
 
     # Setup internal and new columns
-    to_evaluate["exception"] = None
+    if any(col[0] == "exception" for col in new_columns):
+        raise ValueError("The 'exception' column can not be one of the new columns")
+    new_columns = [["exception", None]] + new_columns  # Don't use append to keep the input as is.
     for new_column in new_columns:
         to_evaluate[new_column[0]] = new_column[1]
 
@@ -96,29 +209,7 @@ def evaluate(
         logger.info("Not using SQL backend to save iterations")
         db = None
     else:
-        db = DataBase(db_url)
-
-        if resume and db.exists("df"):
-            logger.info("Load data from SQL database")
-            db.reflect("df")
-            previous_results = db.load()
-            previous_idx = previous_results.index
-            bad_cols = [
-                col
-                for col in df.columns
-                if not to_evaluate.loc[previous_idx, col].equals(previous_results[col])
-            ]
-            if bad_cols:
-                raise ValueError(
-                    f"The following columns have different values from the DataBase: {bad_cols}"
-                )
-            to_evaluate.loc[previous_results.index] = previous_results.loc[previous_results.index]
-            task_ids = task_ids.difference(previous_results.index)
-        else:
-            logger.info("Create SQL database")
-            db.create(to_evaluate)
-
-        db_url = db.get_url()
+        db, db_url, task_ids = _prepare_db(db_url, to_evaluate, df, resume, task_ids)
 
     # Log the number of tasks to run
     if len(task_ids) > 0:
@@ -130,36 +221,29 @@ def evaluate(
     # Get the factory mapper
     mapper = parallel_factory.get_mapper(**mapper_kwargs)
 
-    # Setup the function to apply to the data
-    eval_func = partial(
-        _try_evaluation,
-        evaluation_function=evaluation_function,
-        func_args=func_args,
-        func_kwargs=func_kwargs,
-    )
-
-    # Split the data into rows
-    arg_list = list(to_evaluate.loc[task_ids, df.columns].to_dict("index").items())
-
-    res = []
-    try:
-        # Compute and collect the results
-        for task_id, result, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)):
-            res.append(dict({"df_index": task_id, "exception": exception}, **result))
-
-            # Save the results into the DB
-            if db is not None:
-                db.write(
-                    task_id, result, exception, **to_evaluate.loc[task_id, df.columns].to_dict()
-                )
-
-    except (KeyboardInterrupt, SystemExit) as ex:
-        # To save dataframe even if program is killed
-        logger.warning("Stopping mapper loop. Reason: %r", ex)
-
-    # Gather the results to the output DataFrame
-    res_df = pd.DataFrame(res)
-    res_df.set_index("df_index", inplace=True)
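+    # Dispatch to the dask.dataframe evaluator or to the generic row-based one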
+    if isinstance(parallel_factory, DaskDataFrameFactory):
+        res_df = _evaluate_dataframe(
+            to_evaluate,
+            df,
+            evaluation_function,
+            func_args,
+            func_kwargs,
+            new_columns,
+            mapper,
+            task_ids,
+            db,
+        )
+    else:
+        res_df = _evaluate_basic(
+            to_evaluate,
+            df,
+            evaluation_function,
+            func_args,
+            func_kwargs,
+            mapper,
+            task_ids,
+            db,
+        )
     to_evaluate.loc[res_df.index, res_df.columns] = res_df
 
     return to_evaluate