import torch
from psm_utils import PSM, Peptidoform, PSMList
from psm_utils.io import read_file
+ from rich.progress import track
+ from torch.nn import Module
from torch.utils.data import DataLoader

- from deeplc.calibration import Calibration, SplineTransformerCalibration
- from deeplc._data import DeepLCDataset
+ from deeplc._data import DeepLCDataset, get_targets
from deeplc._finetune import DeepLCFineTuner
+ from deeplc.calibration import Calibration, SplineTransformerCalibration

# If CLI/GUI/frozen: disable warnings before importing
IS_CLI_GUI = os.path.basename(sys.argv[0]) in ["deeplc", "deeplc-gui"]
- # Default models, will be used if no other is specified. If no best model is
- # selected during calibration, the first model in the list will be used.
+ # Default model, used if no other is specified.
DEEPLC_DIR = os.path.dirname(os.path.realpath(__file__))
- DEFAULT_MODELS = [
-     "mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.pt",
-     "mods/full_hc_PXD005573_pub_8c22d89667368f2f02ad996469ba157e.pt",
-     "mods/full_hc_PXD005573_pub_cb975cfdd4105f97efa0b3afffe075cc.pt",
- ]
- DEFAULT_MODELS = [os.path.join(DEEPLC_DIR, m) for m in DEFAULT_MODELS]
+ DEFAULT_MODEL = "mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.pt"
+ DEFAULT_MODEL = os.path.join(DEEPLC_DIR, DEFAULT_MODEL)


logger = logging.getLogger(__name__)


def predict(
    psm_list: PSMList | None = None,
-     model_files: str | list[str] | None = None,
-     calibrator: Calibration | None = None,
+     model: str | list[str] | None = None,
+     num_workers: int = 4,
    batch_size: int = 1024,
-     single_model_mode: bool = False,
):
    """
    Make predictions for sequences, in batches if required.
@@ -74,68 +71,49 @@ def predict(
    ----------
    psm_list
        PSMList object containing the peptidoforms to predict for.
-     model_files
-         Model file (or files) to use for prediction. If None, the default model is used.
-     calibrator
-         Calibrator object to use for calibration. If None, no calibration is performed.
+     model
+         Model file to use for prediction. If None, the default model is used.
+     num_workers
+         Number of worker processes for the DataLoader (default: 4).
    batch_size
-         How many samples per batch to load (default: 1).
-     single_model_mode
-         Whether to use a single model instead of multiple default models. Only applies if
-         model_file is None.
+         How many samples per batch to load (default: 1024).

    Returns
    -------
    np.array
        predictions

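+     Examples
+     --------
+     Illustrative sketch only (assumes ``psm_list`` was created beforehand, e.g. with
+     ``psm_utils.io.read_file``):
+ 
+     >>> retention_times = predict(psm_list, batch_size=512)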
    """
-     if len(psm_list) == 0:
-         return []
+     # Shortcut if empty PSMList is provided
+     if not psm_list:
+         return np.array([])

-     # Setup dataset and dataloader
-     dataset = DeepLCDataset(psm_list)
-     loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
-
-     if model_files is not None:
-         if isinstance(model_files, str):
-             model_files = [model_files]
-         elif isinstance(model_files, list):
-             model_files = model_files
-         else:
-             raise ValueError("Invalid model name provided.")
-     else:
-         model_files = [DEFAULT_MODELS[0]] if single_model_mode else DEFAULT_MODELS
+     # Avoid predicting repeated PSMs
+     unique_peptidoforms, inverse_indices = _get_unique_peptidoforms(psm_list)

-     # Get predictions; iterate over models if multiple were selected
-     model_predictions = []
-     for model_f in model_files:
-         # Load model
-         model = torch.load(model_f, weights_only=False, map_location=torch.device("cpu"))
-         model.eval()
+     # Setup dataset and dataloader
+     dataset = DeepLCDataset(unique_peptidoforms, target_retention_times=None)
+     loader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
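+     # Note: the loader yields (features, target) batches; the target component is
+     # ignored below, since this function only performs inference.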

-         # Predict
-         ret_preds = []
-         with torch.no_grad():
-             for features, _ in loader:
-                 batch_preds = model(*features)
-                 ret_preds.append(batch_preds.detach().cpu().numpy())
-         raise Exception()
+     # Fall back to the default model if none was provided
+     model = model or DEFAULT_MODEL

-         # Concatenate predictions
-         ret_preds = np.concatenate(ret_preds, axis=0)
+     # Check device availability
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-         # TODO: Bring outside of model loop?
-         # Calibrate
-         if calibrator is not None:
-             ret_preds = calibrator.transform(ret_preds)
+     # Load model on specified device
+     model = _load_model(model=model, device=device, eval=True)

-         model_predictions.append(ret_preds)
+     # Predict
+     predictions = []
+     with torch.no_grad():
+         for features, _ in track(loader):
+             features = [feature_tensor.to(device) for feature_tensor in features]
+             batch_preds = model(*features)
+             predictions.append(batch_preds.detach().cpu().numpy())

-     # Average the predictions from all models
-     averaged_predictions = np.mean(model_predictions, axis=0)
+     # Concatenate predictions and reorder to match original PSMList order
+     predictions = np.concatenate(predictions, axis=0)[inverse_indices]

-     return averaged_predictions
+     return predictions


# TODO: Split off transfer learning?
@@ -144,13 +122,12 @@ def calibrate(
    model_files: str | list[str] | None = None,
    location_retraining_models: str = "",
    sample_for_calibration_curve: int | None = None,
-     return_plotly_report=False,
    n_jobs: int | None = None,
    batch_size: int = int(1e6),
    fine_tune: bool = False,
    n_epochs: int = 20,
    calibrator: Calibration | None = None,
- ) -> dict | None:
+ ) -> tuple[str, dict[str, Calibration]]:
    """
    Find best model and calibrate.

@@ -159,14 +136,12 @@ def calibrate(
    psm_list
        PSMList object containing the peptidoforms to predict for.
    model_files
-         Path to one or mode models to test and calibrat for. If a list of models is passed,
+         Path to one or more models to test and calibrate for. If a list of models is passed,
        the best performing one on the calibration data will be selected.
    location_retraining_models
        Location to save the retraining models; if None, a temporary directory is used.
    sample_for_calibration_curve
        Number of PSMs to sample for calibration curve; if None, all provided PSMs are used.
-     return_plotly_report
-         Whether to return a plotly report with the calibration results.
    n_jobs
        Number of jobs to use for parallel processing; if None, the number of CPU cores is used.
    batch_size
@@ -180,8 +155,6 @@ def calibrate(

    Returns
    -------
-     dict | None
-         Dictionary with plotly report information or None.
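+     tuple[str, dict[str, Calibration]]
+         The selected best-performing model and the corresponding fitted calibrator(s).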

    """
    if None in psm_list["retention_time"]:
@@ -193,7 +166,7 @@ def calibrate(
        calibrator = SplineTransformerCalibration()

-     # Ensuring self.model is list of strings
+     # Ensure model_files is a list of strings
-     model_files = model_files or DEFAULT_MODELS
+     model_files = model_files or DEFAULT_MODEL
    if isinstance(model_files, str):
        model_files = [model_files]

@@ -239,13 +212,13 @@ def calibrate(

    best_perf = float("inf")
    best_calibrator = {}
-     mod_calibrator = {}
+     model_calibrators = {}
    pred_dict = {}
    mod_dict = {}

    for model_name in model_files:
        logger.debug(f"Trying out the following model: {model_name}")
-         predicted_tr = predict(psm_list, calibrate=False, model_name=model_name)
+         predicted_tr = predict(psm_list, model=model_name)

        model_calibrator = copy.deepcopy(calibrator)

@@ -265,7 +238,7 @@ def calibrate(
            m_group_name = "_".join(m_name.split("_")[:-1])
        pred_dict.setdefault(m_group_name, {})[model_name] = preds
        mod_dict.setdefault(m_group_name, {})[model_name] = model_name
-         mod_calibrator.setdefault(m_group_name, {})[model_name] = model_calibrator
+         model_calibrators.setdefault(m_group_name, {})[model_name] = model_calibrator

    # Find best-performing model, including each model's calibration
    for m_name in pred_dict:
@@ -279,13 +252,12 @@ def calibrate(
            m_group_name = m_name

            # TODO is deepcopy really required?
-             best_calibrator = copy.deepcopy(mod_calibrator[m_group_name])
+             best_calibrator = copy.deepcopy(model_calibrators[m_group_name])
            best_model = copy.deepcopy(mod_dict[m_group_name])
            best_perf = perf

    logger.debug(f"Model with the best performance got selected: {best_model}")

-
    return best_model, best_calibrator


@@ -304,3 +276,42 @@ def _file_to_psm_list(input_file: str | Path) -> PSMList:
    psm_list.rename_modifications(mapper)

    return psm_list
+
+
+ def _get_unique_peptidoforms(psm_list: PSMList) -> tuple[np.ndarray, np.ndarray]:
+     """Get unique peptidoform strings and the inverse indices mapping back to the PSM list."""
+     peptidoform_strings = np.array([str(psm.peptidoform) for psm in psm_list])
+     unique_peptidoforms, inverse_indices = np.unique(peptidoform_strings, return_inverse=True)
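+     # inverse_indices maps each PSM in the original list to its position in the unique
+     # array, so predictions on the unique peptidoforms can be expanded back to all PSMs.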
+     return unique_peptidoforms, inverse_indices
+
+
+ def _load_model(
+     model: Module | Path | str | None = None,
+     device: torch.device | str | None = None,
+     eval: bool = False,
+ ) -> Module:
+     """Load a model from a file or return the default model if none is provided."""
+     # If no model is provided, use the default model
+     model = model or DEFAULT_MODEL
+
+     # If device is not specified, use the default device (GPU if available, else CPU)
+     device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Load model from file if a path is provided
+     if isinstance(model, str | Path):
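+         # Note: weights_only=False unpickles the full model object; only load files
+         # from a trusted source.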
+         model = torch.load(model, weights_only=False, map_location=device)
+     elif not isinstance(model, Module):
+         raise TypeError(f"Expected a PyTorch Module or a file path, got {type(model)} instead.")
+
+     # Ensure the model is on the specified device
+     model.to(device)
+
+     # Set the model to evaluation or training mode based on the eval flag
+     if eval:
+         model.eval()
+     else:
+         model.train()
+
+     return model
+
+