@@ -112,7 +112,9 @@ def __init__(self, X, X_sum, X_global, X_hc, target=None):
         self.X_hc = torch.from_numpy(X_hc).float()
 
         if target is not None:
-            self.target = torch.from_numpy(target).float()  # Add target values if provided
+            self.target = torch.from_numpy(
+                target
+            ).float()  # Add target values if provided
         else:
             self.target = None  # If no target is provided, set it to None
 
@@ -212,7 +214,8 @@ def fine_tune(self):
         val_loader = self.prepare_data(val_dataset, shuffle=False)
 
         optimizer = torch.optim.Adam(
-            filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.learning_rate
+            filter(lambda p: p.requires_grad, self.model.parameters()),
+            lr=self.learning_rate,
         )
         loss_fn = torch.nn.L1Loss()
         best_model_wts = copy.deepcopy(self.model.state_dict())
@@ -241,7 +244,9 @@ def fine_tune(self):
                 for batch in val_loader:
                     batch_X, batch_X_sum, batch_X_global, batch_X_hc, target = batch
                     target = target.view(-1, 1)
-                    outputs = self.model(batch_X, batch_X_sum, batch_X_global, batch_X_hc)
+                    outputs = self.model(
+                        batch_X, batch_X_sum, batch_X_global, batch_X_hc
+                    )
                     val_loss += loss_fn(outputs, target).item()
             avg_val_loss = val_loss / len(val_loader)
 
@@ -415,7 +420,9 @@ def __str__(self):
415420 """
416421
417422 @staticmethod
418- def _get_model_paths (passed_model_path : str | None , single_model_mode : bool ) -> list [str ]:
423+ def _get_model_paths (
424+ passed_model_path : str | None , single_model_mode : bool
425+ ) -> list [str ]:
419426 """Get the model paths based on the passed model path and the single model mode."""
420427 if passed_model_path :
421428 return [passed_model_path ]
@@ -425,6 +432,35 @@ def _get_model_paths(passed_model_path: str | None, single_model_mode: bool) ->
 
         return DEFAULT_MODELS
 
+    def _prepare_feature_matrices(self, psm_list):
+        """
+        Extract features in parallel and assemble the four input matrices.
+
+        Parameters
+        ----------
+        psm_list : list of PSM
+            List of peptide-spectrum matches for which to extract features.
+
+        Returns
+        -------
+        X : ndarray, shape (n_peptides, n_features)
+        X_sum : ndarray, shape (n_peptides, n_sum_features)
+        X_global : ndarray, shape (n_peptides, n_global_features * 2)
+        X_hc : ndarray, shape (n_peptides, n_hc_features)
+        """
+        feats = self.do_f_extraction_psm_list_parallel(psm_list)
+        X = np.stack(list(feats["matrix"].values()))
+        X_sum = np.stack(list(feats["matrix_sum"].values()))
+        X_global = np.concatenate(
+            (
+                np.stack(list(feats["matrix_all"].values())),
+                np.stack(list(feats["pos_matrix"].values())),
+            ),
+            axis=1,
+        )
+        X_hc = np.stack(list(feats["matrix_hc"].values()))
+        return X, X_sum, X_global, X_hc
+
     def _extract_features(
         self,
         peptidoforms: list[str | Peptidoform] | PSMList,
@@ -437,24 +473,31 @@ def _extract_features(
         logger.debug("Running feature extraction in single-threaded mode...")
         if self.n_jobs <= 1:
             encodings = [
-                encode_peptidoform(pf, predict_ccs=self.predict_ccs) for pf in peptidoforms
+                encode_peptidoform(pf, predict_ccs=self.predict_ccs)
+                for pf in peptidoforms
             ]
 
         else:
             logger.debug("Preparing feature extraction with Dask")
             # Process peptidoforms in larger chunks to reduce task overhead.
-            peptidoform_strings = [str(pep) for pep in peptidoforms]  # Faster pickling of strings
+            peptidoform_strings = [
+                str(pep) for pep in peptidoforms
+            ]  # Faster pickling of strings
 
             def chunked_encode(chunk):
-                return [encode_peptidoform(pf, predict_ccs=self.predict_ccs) for pf in chunk]
+                return [
+                    encode_peptidoform(pf, predict_ccs=self.predict_ccs) for pf in chunk
+                ]
 
             tasks = [
                 delayed(chunked_encode)(peptidoform_strings[i : i + chunk_size])
                 for i in range(0, len(peptidoform_strings), chunk_size)
             ]
 
             logger.debug("Starting feature extraction with Dask")
-            chunks_encodings = compute(*tasks, scheduler="processes", workers=self.n_jobs)
+            chunks_encodings = compute(
+                *tasks, scheduler="processes", workers=self.n_jobs
+            )
 
             # Flatten the list of lists.
             encodings = [enc for chunk in chunks_encodings for enc in chunk]
@@ -486,7 +529,9 @@ def _apply_calibration_core(
 
         # Use spline model within the range of X
         within_range = (uncal_preds >= cal_min) & (uncal_preds <= cal_max)
-        within_range = within_range.ravel()  # Ensure this is a 1D array for proper indexing
+        within_range = (
+            within_range.ravel()
+        )  # Ensure this is a 1D array for proper indexing
 
         # Create a prediction array initialized with spline predictions
         cal_preds = np.copy(y_pred_spline)
@@ -501,19 +546,27 @@ def _apply_calibration_core(
         else:
             for uncal_pred in uncal_preds:
                 try:
-                    slope, intercept = cal_dict[str(round(uncal_pred, self.bin_distance))]
+                    slope, intercept = cal_dict[
+                        str(round(uncal_pred, self.bin_distance))
+                    ]
                     cal_preds.append(slope * (uncal_pred) + intercept)
                 except KeyError:
                     # outside of the prediction range ... use the last
                     # calibration curve
                     if uncal_pred <= cal_min:
-                        slope, intercept = cal_dict[str(round(cal_min, self.bin_distance))]
+                        slope, intercept = cal_dict[
+                            str(round(cal_min, self.bin_distance))
+                        ]
                         cal_preds.append(slope * (uncal_pred) + intercept)
                     elif uncal_pred >= cal_max:
-                        slope, intercept = cal_dict[str(round(cal_max, self.bin_distance))]
+                        slope, intercept = cal_dict[
+                            str(round(cal_max, self.bin_distance))
+                        ]
                         cal_preds.append(slope * (uncal_pred) + intercept)
                     else:
-                        slope, intercept = cal_dict[str(round(cal_max, self.bin_distance))]
+                        slope, intercept = cal_dict[
+                            str(round(cal_max, self.bin_distance))
+                        ]
                         cal_preds.append(slope * (uncal_pred) + intercept)
 
         return np.array(cal_preds)
@@ -648,7 +701,9 @@ def make_preds(
         elif infile is not None:
             psm_list = _file_to_psm_list(infile)
         else:
-            raise ValueError("Either `psm_list` or `seq_df` or `infile` must be provided.")
+            raise ValueError(
+                "Either `psm_list` or `seq_df` or `infile` must be provided."
+            )
 
         if len(psm_list) == 0:
             logger.warning("No PSMs to predict for.")
@@ -692,7 +747,9 @@ def make_preds(
                 )
             )
             # Average the predictions from all models
-            ret_preds = np.array([sum(a) / len(a) for a in zip(*ret_preds, strict=True)])
+            ret_preds = np.array(
+                [sum(a) / len(a) for a in zip(*ret_preds, strict=True)]
+            )
             # ret_preds = np.mean(model_predictions, axis=0)
 
         else:
@@ -740,7 +797,9 @@ def _calibrate_preds_pygam(
             spline_model = linear_model
             linear_model_right = linear_model
         else:
-            spline = SplineTransformer(degree=4, n_knots=int(len(measured_tr) / 500) + 5)
+            spline = SplineTransformer(
+                degree=4, n_knots=int(len(measured_tr) / 500) + 5
+            )
             spline_model = make_pipeline(spline, LinearRegression())
             spline_model.fit(predicted_tr.reshape(-1, 1), measured_tr)
 
@@ -832,9 +891,13 @@ def _calibrate_preds_piecewise_linear(
832891 "peptides for fitting the calibration curve."
833892 )
834893 if len (mtr_mean ) == 0 :
835- raise CalibrationError ("The measured tr list is empty, not able to calibrate" )
894+ raise CalibrationError (
895+ "The measured tr list is empty, not able to calibrate"
896+ )
836897 if len (ptr_mean ) == 0 :
837- raise CalibrationError ("The predicted tr list is empty, not able to calibrate" )
898+ raise CalibrationError (
899+ "The predicted tr list is empty, not able to calibrate"
900+ )
838901
839902 # calculate calibration curves
840903 for i in range (0 , len (ptr_mean )):
@@ -913,7 +976,9 @@ def calibrate_preds(
         elif infile is not None:
             psm_list = _file_to_psm_list(infile)
         else:
-            raise ValueError("Either `psm_list` or `seq_df` or `infile` must be provided.")
+            raise ValueError(
+                "Either `psm_list` or `seq_df` or `infile` must be provided."
+            )
 
         # Getting measured retention time either from measured_tr or provided PSMs
         if not measured_tr:
@@ -946,7 +1011,9 @@ def calibrate_preds(
             X, X_sum, X_global, X_hc = self._prepare_feature_matrices(psm_list)
             dataset = DeepLCDataset(X, X_sum, X_global, X_hc, np.array(measured_tr))
 
-            base_model_path = self.model[0] if isinstance(self.model, list) else self.model
+            base_model_path = (
+                self.model[0] if isinstance(self.model, list) else self.model
+            )
             base_model = torch.load(
                 base_model_path, weights_only=False, map_location=torch.device("cpu")
             )
@@ -982,29 +1049,42 @@ def calibrate_preds(
 
         for model_name in self.model:
             logger.debug(f"Trying out the following model: {model_name}")
-            predicted_tr = self.make_preds(psm_list, calibrate=False, mod_name=model_name)
+            predicted_tr = self.make_preds(
+                psm_list, calibrate=False, mod_name=model_name
+            )
 
             if self.pygam_calibration:
-                calibrate_output = self._calibrate_preds_pygam(measured_tr, predicted_tr)
+                calibrate_output = self._calibrate_preds_pygam(
+                    measured_tr, predicted_tr
+                )
             else:
                 calibrate_output = self._calibrate_preds_piecewise_linear(
                     measured_tr, predicted_tr, use_median=use_median
                 )
-            self.calibrate_min, self.calibrate_max, self.calibrate_dict = calibrate_output
+            self.calibrate_min, self.calibrate_max, self.calibrate_dict = (
+                calibrate_output
+            )
             # TODO: Currently, calibration dict can be both a dict (linear) or a list of models
             # (PyGAM)... This should be handled better in the future.
 
             # Skip this model if calibrate_dict is empty
             # TODO: Should this do something when using PyGAM and calibrate_dict is a list?
-            if isinstance(self.calibrate_dict, dict) and len(self.calibrate_dict.keys()) == 0:
+            if (
+                isinstance(self.calibrate_dict, dict)
+                and len(self.calibrate_dict.keys()) == 0
+            ):
                 continue
 
             m_name = model_name.split("/")[-1]
 
             # Get new predictions with calibration
-            preds = self.make_preds(psm_list, calibrate=True, seq_df=seq_df, mod_name=model_name)
+            preds = self.make_preds(
+                psm_list, calibrate=True, seq_df=seq_df, mod_name=model_name
+            )
 
-            m_group_name = "deepcallc" if self.deepcallc_mod else "_".join(m_name.split("_")[:-1])
+            m_group_name = (
+                "deepcallc" if self.deepcallc_mod else "_".join(m_name.split("_")[:-1])
+            )
             m = model_name
             try:
                 pred_dict[m_group_name][m] = preds
@@ -1026,13 +1106,18 @@ def calibrate_preds(
                 mod_calibrate_max_dict[m_group_name][m] = self.calibrate_max
 
         for m_name in pred_dict:
-            preds = [sum(a) / len(a) for a in zip(*list(pred_dict[m_name].values()), strict=True)]
+            preds = [
+                sum(a) / len(a)
+                for a in zip(*list(pred_dict[m_name].values()), strict=True)
+            ]
             if len(measured_tr) == 0:
                 perf = sum(abs(seq_df["tr"] - preds))
             else:
                 perf = sum(abs(np.array(measured_tr) - np.array(preds)))
 
-            logger.debug(f"For {m_name} model got a performance of: {perf / len(preds)}")
+            logger.debug(
+                f"For {m_name} model got a performance of: {perf / len(preds)}"
+            )
 
             if perf < best_perf:
                 m_group_name = "deepcallc" if self.deepcallc_mod else m_name
@@ -1072,7 +1157,9 @@ def calibrate_preds(
                 ],
             )
             plotly_return_dict["scatter"] = deeplc.plot.scatter(plotly_df)
-            plotly_return_dict["baseline_dist"] = deeplc.plot.distribution_baseline(plotly_df)
+            plotly_return_dict["baseline_dist"] = deeplc.plot.distribution_baseline(
+                plotly_df
+            )
             return plotly_return_dict
 
         return None
@@ -1124,7 +1211,9 @@ def create_psm(args):
         )
 
     args_list = list(
-        zip(sequences, modifications, identifiers, charges, retention_times, strict=True)
+        zip(
+            sequences, modifications, identifiers, charges, retention_times, strict=True
+        )
    )
     tasks = [delayed(create_psm)(args) for args in args_list]
     list_of_psms = list(compute(*tasks, scheduler="processes"))
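The main structural change in this diff is the new `_prepare_feature_matrices` helper, which centralizes stacking the per-peptide feature dictionaries into the four model inputs. Below is a minimal standalone sketch of that assembly step: the dictionary keys come from the diff above, while the two-peptide dummy data and the feature sizes are made up for illustration only.

```python
import numpy as np

# Hypothetical per-peptide feature dictionaries, mirroring the keys returned by
# do_f_extraction_psm_list_parallel in the diff; sizes are illustrative only.
feats = {
    "matrix": {0: np.zeros(120), 1: np.ones(120)},
    "matrix_sum": {0: np.zeros(20), 1: np.ones(20)},
    "matrix_all": {0: np.zeros(30), 1: np.ones(30)},
    "pos_matrix": {0: np.zeros(30), 1: np.ones(30)},
    "matrix_hc": {0: np.zeros(12), 1: np.ones(12)},
}

# Stack each per-peptide entry into a batch matrix, as _prepare_feature_matrices does.
X = np.stack(list(feats["matrix"].values()))          # shape (2, 120)
X_sum = np.stack(list(feats["matrix_sum"].values()))  # shape (2, 20)
X_global = np.concatenate(
    (
        np.stack(list(feats["matrix_all"].values())),
        np.stack(list(feats["pos_matrix"].values())),
    ),
    axis=1,
)                                                     # shape (2, 60): global + positional
X_hc = np.stack(list(feats["matrix_hc"].values()))    # shape (2, 12)

print(X.shape, X_sum.shape, X_global.shape, X_hc.shape)
```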