@@ -232,16 +232,13 @@ def _freeze_layers(self, unfreeze_keywords="33_1"):
232232 for name , param in self .model .named_parameters ():
233233
234234 param .requires_grad = (unfreeze_keywords in name )
235- print (f"[INFO] Trainable parameters:" )
236- for name , param in self .model .named_parameters ():
237- if param .requires_grad :
238- print (f" - { name } " )
235+
239236
def prepare_data(self, data, shuffle=True):
    """
    Wrap a dataset in a ``DataLoader`` that uses this object's batch size.

    Parameters
    ----------
    data : dataset
        Any object accepted as the first argument of ``DataLoader``.
    shuffle : bool, default True
        Forwarded unchanged to ``DataLoader``.

    Returns
    -------
    DataLoader
        Loader over ``data`` with ``batch_size=self.batch_size``.
    """
    loader_options = {"batch_size": self.batch_size, "shuffle": shuffle}
    return DataLoader(data, **loader_options)
242239
243240 def fine_tune (self ):
244- logger .info ("Starting fine-tuning..." )
241+ logger .debug ("Starting fine-tuning..." )
245242 if self .validation_data is None :
246243 # Split the training data into training and validation sets
247244 val_size = int (len (self .train_data ) * self .validation_split )
@@ -290,17 +287,15 @@ def fine_tune(self):
290287 val_loss += loss_fn (outputs , target ).item ()
291288 avg_val_loss = val_loss / len (val_loader )
292289
293- logger .info (f"Epoch { epoch + 1 } /{ self .epochs } , Loss: { avg_loss :.4f} , Validation Loss: { avg_val_loss :.4f} " )
294- print (f"Epoch { epoch + 1 } /{ self .epochs } , Loss: { avg_loss :.4f} , Validation Loss: { avg_val_loss :.4f} " )
290+ logger .debug (f"Epoch { epoch + 1 } /{ self .epochs } , Loss: { avg_loss :.4f} , Validation Loss: { avg_val_loss :.4f} " )
295291 if avg_val_loss < best_val_loss :
296292 best_val_loss = avg_val_loss
297293 best_model_wts = copy .deepcopy (self .model .state_dict ())
298294 epochs_no_improve = 0
299295 else :
300296 epochs_no_improve += 1
301297 if epochs_no_improve >= self .patience :
302- logger .info (f"Early stopping triggered { epoch + 1 } " )
303- print (f"Early stopping triggered { epoch + 1 } " )
298+ logger .debug (f"Early stopping triggered { epoch + 1 } " )
304299 break
305300 self .model .load_state_dict (best_model_wts )
306301 return self .model
@@ -672,6 +667,37 @@ def do_f_extraction_psm_list_parallel(self, psm_list):
672667
673668 return all_feats
674669
def _prepare_feature_matrices(self, psm_list):
    """
    Run feature extraction and build the four stacked input arrays.

    Parameters
    ----------
    psm_list : list of PSM
        Passed unchanged to ``do_f_extraction_psm_list_parallel``.

    Returns
    -------
    X : ndarray
        Stacked values of the ``"matrix"`` feature group.
    X_sum : ndarray
        Stacked values of the ``"matrix_sum"`` feature group.
    X_global : ndarray
        Stacked ``"matrix_all"`` and ``"pos_matrix"`` values, joined
        along axis 1.
    X_hc : ndarray
        Stacked values of the ``"matrix_hc"`` feature group.
    """
    extracted = self.do_f_extraction_psm_list_parallel(psm_list)

    def stacked(group):
        # One array per entry, stacked along a new leading axis in the
        # mapping's own iteration order.
        return np.stack(list(extracted[group].values()))

    X = stacked("matrix")
    X_sum = stacked("matrix_sum")
    X_global = np.concatenate(
        (stacked("matrix_all"), stacked("pos_matrix")),
        axis=1,
    )
    X_hc = stacked("matrix_hc")
    return X, X_sum, X_global, X_hc
698+
699+
700+
675701 def calibration_core (self , uncal_preds , cal_dict , cal_min , cal_max ):
676702 """
677703 Perform calibration on uncalibrated predictions.
@@ -818,18 +844,7 @@ def make_preds_core(
818844 if len (X ) == 0 and len (psm_list ) > 0 :
819845 if self .verbose :
820846 logger .debug ("Extracting features for the CNN model ..." )
821- X = self .do_f_extraction_psm_list_parallel (psm_list )
822-
823- X_sum = np .stack (list (X ["matrix_sum" ].values ()))
824- X_global = np .concatenate (
825- (
826- np .stack (list (X ["matrix_all" ].values ())),
827- np .stack (list (X ["pos_matrix" ].values ())),
828- ),
829- axis = 1 ,
830- )
831- X_hc = np .stack (list (X ["matrix_hc" ].values ()))
832- X = np .stack (list (X ["matrix" ].values ()))
847+ X , X_sum , X_global , X_hc = self ._prepare_feature_matrices (psm_list )
833848 elif len (X ) == 0 and len (psm_list ) == 0 :
834849 return []
835850
@@ -940,17 +955,7 @@ def make_preds(
940955 if self .verbose :
941956 logger .debug ("Extracting features for the CNN model ..." )
942957
943- X = self .do_f_extraction_psm_list_parallel (psm_list_t )
944- X_sum = np .stack (list (X ["matrix_sum" ].values ()))
945- X_global = np .concatenate (
946- (
947- np .stack (list (X ["matrix_all" ].values ())),
948- np .stack (list (X ["pos_matrix" ].values ())),
949- ),
950- axis = 1 ,
951- )
952- X_hc = np .stack (list (X ["matrix_hc" ].values ()))
953- X = np .stack (list (X ["matrix" ].values ()))
958+ X , X_sum , X_global , X_hc = self ._prepare_feature_matrices (psm_list_t )
954959 else :
955960 return []
956961
@@ -1410,19 +1415,9 @@ def calibrate_preds(
14101415 temp_pred = []
14111416
14121417 if self .deeplc_retrain :
1413- logger .info ("Preparing for model fine-tuning..." )
1414-
1415- X = self .do_f_extraction_psm_list_parallel (psm_list )
1416- X_sum = np .stack (list (X ["matrix_sum" ].values ()))
1417- X_global = np .concatenate (
1418- (
1419- np .stack (list (X ["matrix_all" ].values ())),
1420- np .stack (list (X ["pos_matrix" ].values ())),
1421- ),
1422- axis = 1 ,
1423- )
1424- X_hc = np .stack (list (X ["matrix_hc" ].values ()))
1425- X = np .stack (list (X ["matrix" ].values ()))
1418+ logger .debug ("Preparing for model fine-tuning..." )
1419+
1420+ X , X_sum , X_global , X_hc = self ._prepare_feature_matrices (psm_list )
14261421 dataset = DeepLCDataset (X , X_sum , X_global , X_hc , np .array (measured_tr ))
14271422
14281423 base_model_path = self .model [0 ] if isinstance (self .model , list ) else self .model
@@ -1448,7 +1443,6 @@ def calibrate_preds(
14481443
14491444 # Define path to save fine-tuned model
14501445 fine_tuned_model_path = os .path .join (t_dir_models , "fine_tuned_model.pth" )
1451- print ("Saving fine-tuned model to:" , fine_tuned_model_path )
14521446 torch .save (fine_tuned_model , fine_tuned_model_path )
14531447 self .model = [fine_tuned_model_path ]
14541448
0 commit comments