33import sqlite3
44import sys
55import traceback
6- from collections import defaultdict
76from functools import partial
87from pathlib import Path
98
1514logger = logging .getLogger (__name__ )
1615
1716
18- def _try_evaluation (task , evaluation_function = None ):
17+ def _try_evaluation (task , evaluation_function , db_filename , func_args , func_kwargs ):
1918 """Encapsulate the evaluation function into a try/except and isolate to record exceptions."""
2019 task_id , task_args = task
2120 try :
22- result = evaluation_function (task_args )
23- exception = ""
21+ result = evaluation_function (task_args , * func_args , ** func_kwargs )
22+ exception = None
2423 except Exception : # pylint: disable=broad-except
2524 result = None
2625 exception = "" .join (traceback .format_exception (* sys .exc_info ()))
2726 logger .exception ("Exception for ID=%s: %s" , task_id , exception )
27+
28+ # Save the results into the DB
29+ if db_filename is not None :
30+ _write_to_sql (db_filename , task_id , result , exception )
2831 return task_id , result , exception
2932
3033
31- def _create_database (df , new_columns , db_filename = "db.sql" ):
34+ def _create_database (df , db_filename = "db.sql" ):
3235 """Create a sqlite database from dataframe."""
33- df ["exception" ] = None
34- for new_column in new_columns :
35- df [new_column [0 ]] = new_column [1 ]
36- df ["to_run_" + new_column [0 ]] = 1
3736 with sqlite3 .connect (str (db_filename )) as db :
38- df .to_sql ("df" , db , if_exists = "replace" , index_label = "index" )
39- return df
37+ df .to_sql ("df" , db , if_exists = "replace" , index_label = "df_index" )
4038
4139
4240def _load_database_to_dataframe (db_filename = "db.sql" ):
4341 """Load an SQL database and construct the dataframe."""
4442 with sqlite3 .connect (str (db_filename )) as db :
45- out = pd .read_sql ("SELECT * FROM df" , db , index_col = "index" )
46- return out
43+ return pd .read_sql ("SELECT * FROM df" , db , index_col = "df_index" )
4744
4845
49- def _write_to_sql (db_filename , task_id , results , new_columns , exception ):
46+ def _write_to_sql (db_filename , task_id , results , exception ):
5047 """Write row data to SQL."""
5148 with sqlite3 .connect (str (db_filename )) as db :
52- for new_column in new_columns :
53- res = results [new_column [0 ]] if results is not None else None
54- db .execute (
55- "UPDATE df SET " + new_column [0 ] + "=?, "
56- "exception=?, to_run_" + new_column [0 ] + "=? WHERE `index`=?" ,
57- (res , exception , 0 , task_id ),
58- )
49+ if results is not None :
50+ keys , vals = zip (* results .items ())
51+ query_keys = ", " .join ([f"{ k } =?" for k in keys ])
52+ else :
53+ query_keys = "exception=?"
54+ vals = [exception ]
55+ db .execute (
56+ "UPDATE df SET " + query_keys + " WHERE df_index=?" ,
57+ list (vals ) + [task_id ],
58+ )
5959
6060
6161def evaluate (
@@ -65,88 +65,111 @@ def evaluate(
6565 resume = False ,
6666 parallel_factory = None ,
6767 db_filename = None ,
68+ func_args = None ,
69+ func_kwargs = None ,
6870):
6971 """Evaluate and save results in a sqlite database on the fly and return dataframe.
7072
7173 Args:
72- df (DataFrame): each row contains information for the computation
74+ df (DataFrame): each row contains information for the computation.
7375 evaluation_function (function): function used to evaluate each row,
7476 should have a single argument as list-like containing values of the rows of df,
75- and return a dict with keys corresponding to the names in new_columns
77+ and return a dict with keys corresponding to the names in new_columns.
7678 new_columns (list): list of names of new column and empty value to save evaluation results,
77- i.e.: [['result', 0.0], ['valid', False]]
79+ i.e.: [['result', 0.0], ['valid', False]].
7880 resume (bool): if True, it will use only compute the empty rows of the database,
79- if False, it will ecrase or generate the database
80- parallel_factory (ParallelFactory): parallel factory instance
81+ if False, it will ecrase or generate the database.
82+ parallel_factory (ParallelFactory): parallel factory instance.
8183 db_filename (str): if a file path is given, SQL backend will be enabled and will use this
8284 path for the SQLite database. Should not be used when evaluations are numerous and
8385 fast, in order to avoid the overhead of communication with SQL database.
86+ func_args (list): the arguments to pass to the evaluation_function.
87+ func_kwargs (dict): the keyword arguments to pass to the evaluation_function.
8488
8589 Return:
86- pandas.DataFrame: dataframe with new columns containing computed results
90+ pandas.DataFrame: dataframe with new columns containing the computed results.
8791 """
92+ # Initialize the parallel factory
8893 if isinstance (parallel_factory , str ) or parallel_factory is None :
8994 parallel_factory = init_parallel_factory (parallel_factory )
9095
91- task_ids = df .index
96+ # Set default args
97+ if func_args is None :
98+ func_args = []
99+
100+ # Set default kwargs
101+ if func_kwargs is None :
102+ func_kwargs = {}
92103
104+ # Shallow copy the given DataFrame to add internal rows
105+ to_evaluate = df .copy ()
106+ task_ids = to_evaluate .index
107+
108+ # Set default new columns
93109 if new_columns is None :
94110 new_columns = [["data" , "" ]]
95111
112+ # Setup internal and new columns
113+ to_evaluate ["exception" ] = None
114+ for new_column in new_columns :
115+ to_evaluate [new_column [0 ]] = new_column [1 ]
116+
117+ # Create the database if required and get the task ids to run
96118 if db_filename is None :
97119 logger .info ("Not using SQL backend to save iterations" )
98- to_evaluate = df
99120 elif resume :
100121 logger .info ("Load data from SQL database" )
101122 if Path (db_filename ).exists ():
102- to_evaluate = _load_database_to_dataframe (db_filename = db_filename )
103- task_ids = task_ids .intersection (to_evaluate .index )
123+ previous_results = _load_database_to_dataframe (db_filename = db_filename )
124+ previous_idx = previous_results .index
125+ bad_cols = [
126+ col
127+ for col in df .columns
128+ if not to_evaluate .loc [previous_idx , col ].equals (previous_results [col ])
129+ ]
130+ if bad_cols :
131+ raise ValueError (
132+ f"The following columns have different values from the DataBase: { bad_cols } "
133+ )
134+ to_evaluate .loc [previous_results .index ] = previous_results .loc [previous_results .index ]
135+ task_ids = task_ids .difference (previous_results .index )
104136 else :
105- to_evaluate = _create_database (df , new_columns , db_filename = db_filename )
106-
107- # Find tasks to run
108- should_run = (
109- to_evaluate .loc [task_ids , ["to_run_" + col [0 ] for col in new_columns ]] == 1
110- ).any (axis = 1 )
111- task_ids = should_run .loc [should_run ].index
137+ _create_database (to_evaluate , db_filename = db_filename )
112138 else :
113139 logger .info ("Create SQL database" )
114- to_evaluate = _create_database (df , new_columns , db_filename = db_filename )
140+ _create_database (to_evaluate , db_filename = db_filename )
115141
142+ # Log the number of tasks to run
116143 if len (task_ids ) > 0 :
117144 logger .info ("%s rows to compute." , str (len (task_ids )))
118145 else :
119146 logger .warning ("WARNING: No row to compute, something may be wrong" )
120- return _load_database_to_dataframe ( db_filename )
147+ return to_evaluate
121148
149+ # Get the factory mapper
122150 mapper = parallel_factory .get_mapper ()
123151
124- eval_func = partial (_try_evaluation , evaluation_function = evaluation_function )
125- arg_list = to_evaluate .to_dict ("index" ).items ()
152+ # Setup the function to apply to the data
153+ eval_func = partial (
154+ _try_evaluation ,
155+ evaluation_function = evaluation_function ,
156+ db_filename = db_filename ,
157+ func_args = func_args ,
158+ func_kwargs = func_kwargs ,
159+ )
126160
127- if db_filename is None :
128- _results = defaultdict ( dict )
161+ # Split the data into rows
162+ arg_list = list ( to_evaluate . loc [ task_ids ]. to_dict ( "index" ). items () )
129163
130164 try :
131165 for task_id , results , exception in tqdm (mapper (eval_func , arg_list ), total = len (task_ids )):
132- if db_filename is None :
133- for new_column , _ in new_columns :
134- _results [new_column ][task_id ] = (
135- results [new_column ] if results is not None else None
136- )
137- else :
138- _write_to_sql (
139- db_filename ,
140- task_id ,
141- results ,
142- new_columns ,
143- exception ,
144- )
166+ # Save the results into the DataFrame
167+ if results is not None :
168+ to_evaluate .loc [task_id , results .keys ()] = list (results .values ())
169+ elif exception is not None :
170+ to_evaluate .loc [task_id , "exception" ] = exception
145171 except (KeyboardInterrupt , SystemExit ) as ex :
146172 # To save dataframe even if program is killed
147173 logger .warning ("Stopping mapper loop. Reason: %r" , ex )
148174
149- if db_filename is None :
150- to_evaluate = pd .concat ([to_evaluate , pd .DataFrame (_results )], axis = 1 )
151- return to_evaluate
152- return _load_database_to_dataframe (db_filename )
175+ return to_evaluate
0 commit comments