 import sqlite3
 import sys
 import traceback
+from collections import defaultdict
 from functools import partial
 from pathlib import Path
 
 import pandas as pd
 from tqdm import tqdm
 
+from bluepyparallel.parallel import init_parallel_factory
+
 logger = logging.getLogger(__name__)
 
 
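Note: the new `defaultdict` import supports the reworked no-SQL results collection in the last hunk below. A dict-of-dicts keyed as `_results[column][task_id]` converts directly into a dataframe whose columns are the outer keys and whose index is the inner keys. A minimal sketch of that conversion (hypothetical values):

    from collections import defaultdict
    import pandas as pd

    _results = defaultdict(dict)
    _results["result"][0] = 1.0  # _results[column][task_id] = value
    _results["result"][1] = 4.0
    pd.DataFrame(_results)       # index [0, 1], single column "result"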
@@ -18,35 +21,34 @@ def _try_evaluation(task, evaluation_function=None):
     try:
         result = evaluation_function(task_args)
         exception = ""
-
     except Exception:  # pylint: disable=broad-except
         result = None
         exception = "".join(traceback.format_exception(*sys.exc_info()))
-        logger.exception("Exception for combo %s", exception)
+        logger.exception("Exception for ID=%s: %s", task_id, exception)
     return task_id, result, exception
 
 
 def _create_database(df, new_columns, db_filename="db.sql"):
     """Create a sqlite database from dataframe."""
-    df.loc[:, "exception"] = None
+    df["exception"] = None
     for new_column in new_columns:
-        df.loc[:, new_column[0]] = new_column[1]
-        df.loc[:, "to_run_" + new_column[0]] = 1
-    with sqlite3.connect(db_filename) as db:
+        df[new_column[0]] = new_column[1]
+        df["to_run_" + new_column[0]] = 1
+    with sqlite3.connect(str(db_filename)) as db:
         df.to_sql("df", db, if_exists="replace", index_label="index")
     return df
 
 
 def _load_database_to_dataframe(db_filename="db.sql"):
-    """Load an sql database and construct the dataframe."""
-    with sqlite3.connect(db_filename) as db:
+    """Load an SQL database and construct the dataframe."""
+    with sqlite3.connect(str(db_filename)) as db:
         out = pd.read_sql("SELECT * FROM df", db, index_col="index")
     return out
 
 
 def _write_to_sql(db_filename, task_id, results, new_columns, exception):
-    """Write row data to sql."""
-    with sqlite3.connect(db_filename) as db:
+    """Write row data to SQL."""
+    with sqlite3.connect(str(db_filename)) as db:
         for new_column in new_columns:
             res = results[new_column[0]] if results is not None else None
             db.execute(
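These helpers now wrap `db_filename` in `str()`, so `pathlib.Path` values work as well (`sqlite3.connect` accepts path-like objects natively only from Python 3.7 on). A minimal sketch of the round trip they implement, with a hypothetical input column `x` and result column `result`:

    import sqlite3
    import pandas as pd

    df = pd.DataFrame({"x": [1.0, 2.0]})
    df["result"] = 0.0       # empty value for the new column
    df["to_run_result"] = 1  # 1 flags rows still to compute
    with sqlite3.connect("db.sql") as db:
        df.to_sql("df", db, if_exists="replace", index_label="index")
    with sqlite3.connect("db.sql") as db:
        out = pd.read_sql("SELECT * FROM df", db, index_col="index")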
@@ -60,11 +62,9 @@ def evaluate(
     df,
     evaluation_function,
     new_columns=None,
-    task_ids=None,
-    continu=False,
+    resume=False,
     parallel_factory=None,
-    db_filename="db.sql",
-    no_sql=False,
+    db_filename=None,
 ):
     """Evaluate and save results in a sqlite database on the fly and return dataframe.
 
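The `task_ids`, `continu`, and `no_sql` keywords are gone: `continu` is renamed `resume`, the SQL backend is now toggled by whether `db_filename` is given, and callers that used `task_ids` can subset the dataframe before calling. A hypothetical call under the old signature such as

    evaluate(df, func, task_ids=ids, continu=True, no_sql=False, db_filename="db.sql")

would now read

    evaluate(df.loc[ids], func, resume=True, db_filename="db.sql")

and omitting `db_filename` replaces `no_sql=True`.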
@@ -75,68 +75,61 @@ def evaluate(
             and return a dict with keys corresponding to the names in new_columns
         new_columns (list): list of names of new columns and empty values to save evaluation results,
             i.e.: [['result', 0.0], ['valid', False]]
-        task_ids (int): index of dataframe to compute, if None, all will be computed
-        continu (bool): if True, it will use only compute the empty rows of the database,
+        resume (bool): if True, it will only compute the empty rows of the database,
             if False, it will overwrite or generate the database
         parallel_factory (ParallelFactory): parallel factory instance
-        db_filename (str): filename for the sqlite database
-        no_sql (bool): is True, sql backend will be disabled. To use when evaluations are numerous
-            and fast, to avoid the overhead of communication with sql database.
+        db_filename (str): if a file path is given, the SQL backend will be enabled and will use
+            this path for the SQLite database. Should not be used when evaluations are numerous
+            and fast, in order to avoid the overhead of communication with the SQL database.
+
     Return:
         pandas.DataFrame: dataframe with new columns containing computed results
     """
-    if task_ids is None:
-        task_ids = df.index
-    else:
-        df = df.loc[task_ids]
+    if isinstance(parallel_factory, str) or parallel_factory is None:
+        parallel_factory = init_parallel_factory(parallel_factory)
+
+    task_ids = df.index
+
     if new_columns is None:
         new_columns = [["data", ""]]
 
-    if no_sql:
-        logger.info("Not using sql backend to save iterations")
+    if db_filename is None:
+        logger.info("Not using SQL backend to save iterations")
         to_evaluate = df
-    elif continu:
-        logger.info("Load data from sql database")
+    elif resume:
+        logger.info("Load data from SQL database")
         if Path(db_filename).exists():
             to_evaluate = _load_database_to_dataframe(db_filename=db_filename)
+            task_ids = task_ids.intersection(to_evaluate.index)
         else:
             to_evaluate = _create_database(df, new_columns, db_filename=db_filename)
-        for new_column in new_columns:
-            task_ids = task_ids[
-                to_evaluate.loc[task_ids, "to_run_" + new_column[0]].to_numpy() == 1
-            ]
+
+        # Find tasks to run
+        should_run = (
+            to_evaluate.loc[task_ids, ["to_run_" + col[0] for col in new_columns]] == 1
+        ).any(axis=1)
+        task_ids = should_run.loc[should_run].index
     else:
-        logger.info("Create sql database")
+        logger.info("Create SQL database")
         to_evaluate = _create_database(df, new_columns, db_filename=db_filename)
 
-    # this is a hack to make it work, otherwise it does not update the entries correctly
-    to_evaluate = _load_database_to_dataframe(db_filename)
-    to_evaluate = _create_database(to_evaluate, new_columns, db_filename=db_filename)
-
     if len(task_ids) > 0:
         logger.info("%s rows to compute.", str(len(task_ids)))
     else:
-        logger.warning("WARNING: No rows to compute, something may be wrong")
+        logger.warning("WARNING: No row to compute, something may be wrong")
         return _load_database_to_dataframe(db_filename)
 
-    if parallel_factory is None:
-        mapper = map
-    else:
-        mapper = parallel_factory.get_mapper()
+    mapper = parallel_factory.get_mapper()
 
     eval_func = partial(_try_evaluation, evaluation_function=evaluation_function)
-    arg_list = enumerate(
-        dict(zip(to_evaluate.columns, row)) for row in to_evaluate.loc[task_ids].values
-    )
+    arg_list = to_evaluate.loc[task_ids].to_dict("index").items()
 
-    if no_sql:
-        _results = {}
-        for new_column, new_column_empty in new_columns:
-            _results[new_column] = len(task_ids) * [new_column_empty]
+    if db_filename is None:
+        _results = defaultdict(dict)
 
     try:
         for task_id, results, exception in tqdm(mapper(eval_func, arg_list), total=len(task_ids)):
-            if no_sql:
+            if db_filename is None:
                 for new_column, _ in new_columns:
                     _results[new_column][task_id] = (
                         results[new_column] if results is not None else None
@@ -149,15 +142,11 @@ def evaluate(
                     new_columns,
                     exception,
                 )
-
-        if no_sql:
-            for new_column, data in _results.items():
-                to_evaluate.loc[:, new_column] = data
-
-    # to save dataframe even if program is killed
     except (KeyboardInterrupt, SystemExit) as ex:
+        # To save dataframe even if program is killed
        logger.warning("Stopping mapper loop. Reason: %r", ex)
 
-    if no_sql:
+    if db_filename is None:
+        to_evaluate = pd.concat([to_evaluate, pd.DataFrame(_results)], axis=1)
         return to_evaluate
     return _load_database_to_dataframe(db_filename)
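A minimal end-to-end sketch of the new interface (the function `eval_one` and the column `x` are hypothetical; `parallel_factory` may be None, a factory name string, or a ready instance, per the `isinstance` check above, with None assumed to fall back to a serial factory via `init_parallel_factory`):

    import pandas as pd

    def eval_one(row):
        # `row` is one dataframe row as a dict, per arg_list above
        return {"result": row["x"] ** 2, "valid": row["x"] > 0}

    df = pd.DataFrame({"x": [1.0, -2.0, 3.0]})
    result_df = evaluate(
        df,
        eval_one,
        new_columns=[["result", 0.0], ["valid", False]],
        resume=True,            # only recompute rows still flagged to_run_*
        parallel_factory=None,  # assumed serial fallback via init_parallel_factory
        db_filename="db.sql",   # enables the SQLite backend
    )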