11"""Module to evaluate generic functions on rows of dataframe."""
22import logging
3- import sqlite3
43import sys
54import traceback
65from functools import partial
7- from pathlib import Path
86
9- import pandas as pd
107from tqdm import tqdm
118
9+ from bluepyparallel .database import DataBase
1210from bluepyparallel .parallel import init_parallel_factory
1311
1412logger = logging .getLogger (__name__ )
1513
1614
17- def _try_evaluation (task , evaluation_function , db_filename , func_args , func_kwargs ):
15+ def _try_evaluation (task , evaluation_function , func_args , func_kwargs ):
1816 """Encapsulate the evaluation function into a try/except and isolate to record exceptions."""
1917 task_id , task_args = task
2018 try :
@@ -24,47 +22,16 @@ def _try_evaluation(task, evaluation_function, db_filename, func_args, func_kwar
2422 result = None
2523 exception = "" .join (traceback .format_exception (* sys .exc_info ()))
2624 logger .exception ("Exception for ID=%s: %s" , task_id , exception )
27-
28- # Save the results into the DB
29- if db_filename is not None :
30- _write_to_sql (db_filename , task_id , result , exception )
3125 return task_id , result , exception
3226
3327
34- def _create_database (df , db_filename = "db.sql" ):
35- """Create a sqlite database from dataframe."""
36- with sqlite3 .connect (str (db_filename )) as db :
37- df .to_sql ("df" , db , if_exists = "replace" , index_label = "df_index" )
38-
39-
40- def _load_database_to_dataframe (db_filename = "db.sql" ):
41- """Load an SQL database and construct the dataframe."""
42- with sqlite3 .connect (str (db_filename )) as db :
43- return pd .read_sql ("SELECT * FROM df" , db , index_col = "df_index" )
44-
45-
46- def _write_to_sql (db_filename , task_id , results , exception ):
47- """Write row data to SQL."""
48- with sqlite3 .connect (str (db_filename )) as db :
49- if results is not None :
50- keys , vals = zip (* results .items ())
51- query_keys = ", " .join ([f"{ k } =?" for k in keys ])
52- else :
53- query_keys = "exception=?"
54- vals = [exception ]
55- db .execute (
56- "UPDATE df SET " + query_keys + " WHERE df_index=?" ,
57- list (vals ) + [task_id ],
58- )
59-
60-
6128def evaluate (
6229 df ,
6330 evaluation_function ,
6431 new_columns = None ,
6532 resume = False ,
6633 parallel_factory = None ,
67- db_filename = None ,
34+ db_url = None ,
6835 func_args = None ,
6936 func_kwargs = None ,
7037):
@@ -80,9 +47,11 @@ def evaluate(
8047 resume (bool): if True, it will only compute the empty rows of the database,
8148 if False, it will ecrase or generate the database.
8249 parallel_factory (ParallelFactory): parallel factory instance.
83- db_filename (str): if a file path is given, SQL backend will be enabled and will use this
84- path for the SQLite database. Should not be used when evaluations are numerous and
85- fast, in order to avoid the overhead of communication with SQL database.
50+ db_url (str): should be DB URL that can be interpreted by SQLAlchemy or can be a file path
51+ that is interpreted as a SQLite database. If a URL is given, the SQL backend will be
52+ enabled to store results, allowing future resume. Should not be used when
53+ evaluations are numerous and fast, in order to avoid the overhead of communication with
54+ SQL database.
8655 func_args (list): the arguments to pass to the evaluation_function.
8756 func_kwargs (dict): the keyword arguments to pass to the evaluation_function.
8857
@@ -115,12 +84,16 @@ def evaluate(
11584 to_evaluate [new_column [0 ]] = new_column [1 ]
11685
11786 # Create the database if required and get the task ids to run
118- if db_filename is None :
87+ if db_url is None :
11988 logger .info ("Not using SQL backend to save iterations" )
120- elif resume :
121- logger .info ("Load data from SQL database" )
122- if Path (db_filename ).exists ():
123- previous_results = _load_database_to_dataframe (db_filename = db_filename )
89+ db = None
90+ else :
91+ db = DataBase (db_url )
92+
93+ if resume and db .exists ("df" ):
94+ logger .info ("Load data from SQL database" )
95+ db .reflect ("df" )
96+ previous_results = db .load ()
12497 previous_idx = previous_results .index
12598 bad_cols = [
12699 col
@@ -134,10 +107,10 @@ def evaluate(
134107 to_evaluate .loc [previous_results .index ] = previous_results .loc [previous_results .index ]
135108 task_ids = task_ids .difference (previous_results .index )
136109 else :
137- _create_database ( to_evaluate , db_filename = db_filename )
138- else :
139- logger . info ( "Create SQL database" )
140- _create_database ( to_evaluate , db_filename = db_filename )
110+ logger . info ( "Create SQL database" )
111+ db . create ( to_evaluate )
112+
113+ db_url = db . get_url ( )
141114
142115 # Log the number of tasks to run
143116 if len (task_ids ) > 0 :
@@ -153,16 +126,21 @@ def evaluate(
153126 eval_func = partial (
154127 _try_evaluation ,
155128 evaluation_function = evaluation_function ,
156- db_filename = db_filename ,
157129 func_args = func_args ,
158130 func_kwargs = func_kwargs ,
159131 )
160132
161133 # Split the data into rows
162- arg_list = list (to_evaluate .loc [task_ids ].to_dict ("index" ).items ())
134+ arg_list = list (to_evaluate .loc [task_ids , df . columns ].to_dict ("index" ).items ())
163135
164136 try :
165137 for task_id , results , exception in tqdm (mapper (eval_func , arg_list ), total = len (task_ids )):
138+ # Save the results into the DB
139+ if db is not None :
140+ db .write (
141+ task_id , results , exception , ** to_evaluate .loc [task_id , df .columns ].to_dict ()
142+ )
143+
166144 # Save the results into the DataFrame
167145 if results is not None :
168146 to_evaluate .loc [task_id , results .keys ()] = list (results .values ())
0 commit comments