@@ -35,7 +35,7 @@ def prepped_array_to_df(data_array, dates, ids, col_names):
     return df
 
 
-def take_first_half(df):
+def take_half(df, first_half=True):
     """
     filter out the second half of the dates in the predictions. this is to
     retain a "test" set of the i/o data for evaluation
@@ -47,9 +47,12 @@ def take_first_half(df):
     df.sort_index(inplace=True)
     unique_dates = df.index.unique()
     halfway_date = unique_dates[int(len(unique_dates) / 2)]
-    df_first_half = df.loc[:halfway_date]
-    df_first_half.reset_index(inplace=True)
-    return df_first_half
+    if first_half:
+        df_half = df.loc[:halfway_date]
+    else:
+        df_half = df.loc[halfway_date:]
+    df_half.reset_index(inplace=True)
+    return df_half
 
 
 def unscale_output(y_scl, y_std, y_mean, data_cols, logged_q=False):
@@ -197,11 +200,16 @@ def predict(model, io_data, partition, outfile, logged_q=False, half_tst=False):
197200 """
198201 io_data = get_data_if_file (io_data )
199202
200- # evaluate training
201- if partition == "trn" or partition == "tst" :
203+ if partition in ["trn" , "tst" , "ver" ]:
202204 pass
203205 else :
204- raise ValueError ('partition arg needs to be "trn" or "tst"' )
206+ raise ValueError ('partition arg needs to be "trn" or "tst" or "ver"' )
207+
208+ if partition == "ver" :
209+ partition = "tst"
210+ tst_partition = "ver"
211+ elif partition == "tst" :
212+ tst_partition = "tst"
205213
206214 num_segs = len (np .unique (io_data ["ids_trn" ]))
207215 y_pred = model .predict (io_data [f"x_{ partition } " ], batch_size = num_segs )
@@ -220,8 +228,12 @@ def predict(model, io_data, partition, outfile, logged_q=False, half_tst=False):
         logged_q,
     )
 
-    if half_tst and partition == "tst":
-        y_pred_pp = take_first_half(y_pred_pp)
+    if partition == "tst":
+        if half_tst and tst_partition == "tst":
+            y_pred_pp = take_half(y_pred_pp, first_half=True)
+
+        if half_tst and tst_partition == "ver":
+            y_pred_pp = take_half(y_pred_pp, first_half=False)
 
     y_pred_pp.to_feather(outfile)
     return y_pred_pp
@@ -372,13 +384,14 @@ def overall_metrics(
 
 
 def combined_metrics(
-    pred_trn, pred_tst, obs_temp, obs_flow, grp=None, outfile=None
+    pred_trn, pred_tst, obs_temp, obs_flow, pred_ver=None, grp=None, outfile=None
 ):
     """
     calculate the metrics for flow and temp and training and test sets for a
     given grouping
     :param pred_trn: [str] path to training prediction feather file
     :param pred_tst: [str] path to testing prediction feather file
+    :param pred_ver: [str] path to verification prediction feather file
     :param obs_temp: [str] path to observations temperature zarr file
     :param obs_flow: [str] path to observations flow zarr file
     :param group: [str or list] which group the metrics should be computed for.
@@ -393,6 +406,10 @@ def combined_metrics(
     tst_temp = overall_metrics(pred_tst, obs_temp, "temp", "tst", grp)
     tst_flow = overall_metrics(pred_tst, obs_flow, "flow", "tst", grp)
     df_all = [trn_temp, tst_temp, trn_flow, tst_flow]
+    if pred_ver:
+        ver_temp = overall_metrics(pred_ver, obs_temp, "temp", "ver", grp)
+        ver_flow = overall_metrics(pred_ver, obs_flow, "flow", "ver", grp)
+        df_all.extend([ver_temp, ver_flow])
     df_all = pd.concat(df_all, axis=0)
     if outfile:
         df_all.to_csv(outfile, index=False)
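
For orientation, a minimal usage sketch of the new verification ("ver") partition introduced in this diff. The trained `model`, the prepped `io_data` file, and all file paths below are placeholders for illustration; only `predict`, `take_half`, and `combined_metrics` come from this commit.

```python
# Hypothetical end-to-end sketch (the model object and every path are assumptions).

# Train and test predictions as before. With half_tst=True, the "tst" partition
# keeps only the first half of the test dates (take_half(..., first_half=True)).
predict(model, "io_data.npz", "trn", "pred_trn.feather")
predict(model, "io_data.npz", "tst", "pred_tst.feather", half_tst=True)

# The new "ver" partition reuses the x_tst inputs but, with half_tst=True, keeps
# the second half of the test dates (take_half(..., first_half=False)), so "tst"
# and "ver" split the held-out period between them.
predict(model, "io_data.npz", "ver", "pred_ver.feather", half_tst=True)

# Metrics for trn/tst as before; passing pred_ver adds "ver" rows to the output.
combined_metrics(
    "pred_trn.feather",
    "pred_tst.feather",
    "obs_temp.zarr",
    "obs_flow.zarr",
    pred_ver="pred_ver.feather",
    outfile="combined_metrics.csv",
)
```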