1313
1414import torch
1515import torch .optim as optim
16+ from torch .optim .lr_scheduler import ReduceLROnPlateau
1617
1718from qlib .data .dataset .weight import Reweighter
1819
@@ -136,6 +137,10 @@ def __init__(
136137 else :
137138 raise NotImplementedError ("optimizer {} is not supported!" .format (optimizer ))
138139
140+ # === ReduceLROnPlateau learning rate scheduler ===
141+ self .lr_scheduler = ReduceLROnPlateau (
142+ self .train_optimizer , mode = "min" , factor = 0.5 , patience = 5 , min_lr = 1e-6 , threshold = 1e-5
143+ )
139144 self .fitted = False
140145 self .dnn_model .to (self .device )
141146
@@ -154,15 +159,15 @@ def loss_fn(self, pred, label, weight=None):
154159 weight = torch .ones_like (label )
155160
156161 if self .loss == "mse" :
157- return self .mse (pred [mask ], label [mask ], weight [mask ])
162+ return self .mse (pred [mask ], label [mask ]. view ( - 1 , 1 ) , weight [mask ])
158163
159164 raise ValueError ("unknown loss `%s`" % self .loss )
160165
161166 def metric_fn (self , pred , label ):
162167 mask = torch .isfinite (label )
163168
164169 if self .metric in ("" , "loss" ):
165- return - self .loss_fn (pred [mask ], label [mask ])
170+ return self .loss_fn (pred [mask ], label [mask ])
166171
167172 raise ValueError ("unknown metric `%s`" % self .metric )
168173
@@ -238,6 +243,8 @@ def fit(
238243
239244 dl_train = dataset .prepare ("train" , col_set = ["feature" , "label" ], data_key = DataHandlerLP .DK_L )
240245 dl_valid = dataset .prepare ("valid" , col_set = ["feature" , "label" ], data_key = DataHandlerLP .DK_L )
246+ self .logger .info (f"Train samples: { len (dl_train )} " )
247+ self .logger .info (f"Valid samples: { len (dl_valid )} " )
241248 if dl_train .empty or dl_valid .empty :
242249 raise ValueError ("Empty data from dataset, please check your dataset config." )
243250
@@ -279,7 +286,7 @@ def fit(
279286
280287 stop_steps = 0
281288 train_loss = 0
282- best_score = - np .inf
289+ best_score = np .inf
283290 best_epoch = 0
284291 evals_result ["train" ] = []
285292 evals_result ["valid" ] = []
@@ -295,13 +302,18 @@ def fit(
295302 self .logger .info ("evaluating..." )
296303 train_loss , train_score = self .test_epoch (train_loader )
297304 val_loss , val_score = self .test_epoch (valid_loader )
298- self .logger .info ("train %.6f, valid %.6f" % (train_score , val_score ))
305+ self .logger .info ("Epoch%d: train %.6f, valid %.6f" % (step , train_score , val_score ))
299306 evals_result ["train" ].append (train_score )
300307 evals_result ["valid" ].append (val_score )
301308
309+ # current_lr = self.train_optimizer.param_groups[0]["lr"]
310+ # self.logger.info("Current learning rate: %.6e" % current_lr)
311+
312+ self .lr_scheduler .step (val_score )
313+
302314 if step == 0 :
303315 best_param = copy .deepcopy (self .dnn_model .state_dict ())
304- if val_score > best_score :
316+ if val_score < best_score :
305317 best_score = val_score
306318 stop_steps = 0
307319 best_epoch = step
@@ -312,7 +324,7 @@ def fit(
312324 self .logger .info ("early stop" )
313325 break
314326
315- self .logger .info ("best score: %.6lf @ %d" % (best_score , best_epoch ))
327+ self .logger .info ("best score: %.6lf @ %d epoch " % (best_score , best_epoch ))
316328 self .dnn_model .load_state_dict (best_param )
317329 torch .save (best_param , save_path )
318330
@@ -329,6 +341,7 @@ def predict(
329341 raise ValueError ("model is not fitted yet!" )
330342
331343 dl_test = dataset .prepare ("test" , col_set = ["feature" , "label" ], data_key = DataHandlerLP .DK_I )
344+ self .logger .info (f"Test samples: { len (dl_test )} " )
332345
333346 if isinstance (dataset , TSDatasetH ):
334347 dl_test .config (fillna_type = "ffill+bfill" ) # process nan brought by dataloader
0 commit comments