Skip to content

Commit 8ebbd6c

Browse files
committed
Simplify predict function and only predict unique peptidoforms:
- Remove ensemble-model mode (different kernel sizes). - Split off model loading into a separate function. - Make dataset take peptidoforms instead of PSMs. - Get unique peptidoforms before predicting and keep the inverse index for mapping predictions back to the input PSM list.
1 parent 12f332a commit 8ebbd6c

File tree

2 files changed

+107
-87
lines changed

2 files changed

+107
-87
lines changed

deeplc/_data.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import numpy as np
12
import torch
2-
from psm_utils.psm_list import PSMList
3+
from psm_utils import Peptidoform, PSMList
34
from torch.utils.data import Dataset
45

56
from deeplc._features import encode_peptidoform
@@ -8,28 +9,24 @@
89
class DeepLCDataset(Dataset):
910
"""Custom Dataset class for DeepLC used for loading features from peptide sequences."""
1011

11-
def __init__(self, psm_list: PSMList, add_ccs_features: bool = False):
12-
self.psm_list = psm_list
12+
def __init__(
13+
self,
14+
peptidoforms: list[Peptidoform | str],
15+
target_retention_times: np.ndarray | None = None,
16+
add_ccs_features: bool = False
17+
):
18+
self.peptidoforms = peptidoforms
19+
self.target_retention_times = target_retention_times
1320
self.add_ccs_features = add_ccs_features
14-
15-
self._targets = self._get_targets(psm_list)
1621

17-
@staticmethod
18-
def _get_targets(psm_list: PSMList) -> torch.Tensor | None:
19-
retention_times = [psm.retention_time for psm in psm_list]
20-
if None not in retention_times:
21-
return torch.tensor(retention_times, dtype=torch.float32)
22-
else:
23-
return None
24-
2522
def __len__(self):
26-
return len(self.psm_list)
23+
return len(self.peptidoforms)
2724

2825
def __getitem__(self, idx) -> tuple:
2926
if not isinstance(idx, int):
3027
raise TypeError(f"Index must be an integer, got {type(idx)} instead.")
3128
features = encode_peptidoform(
32-
self.psm_list[idx].peptidoform,
29+
self.peptidoforms[idx],
3330
add_ccs_features=self.add_ccs_features
3431
)
3532
feature_tuples = (
@@ -38,7 +35,19 @@ def __getitem__(self, idx) -> tuple:
3835
torch.from_numpy(features["matrix_global"]).to(dtype=torch.float32),
3936
torch.from_numpy(features["matrix_hc"]).to(dtype=torch.float32),
4037
)
41-
targets = self._targets[idx] if self._targets is not None else torch.full_like(
42-
feature_tuples[0], fill_value=float('nan'), dtype=torch.float32
38+
targets = (
39+
self.target_retention_times[idx]
40+
if self.target_retention_times is not None
41+
else torch.full_like(
42+
feature_tuples[0], fill_value=float('nan'), dtype=torch.float32
43+
)
4344
)
4445
return feature_tuples, targets
46+
47+
48+
def get_targets(psm_list: PSMList) -> np.ndarray | None:
49+
retention_times = psm_list["retention_time"]
50+
if None not in retention_times:
51+
return torch.tensor(retention_times, dtype=torch.float32)
52+
else:
53+
return None

deeplc/deeplc.py

Lines changed: 81 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,13 @@
3434
import torch
3535
from psm_utils import PSM, Peptidoform, PSMList
3636
from psm_utils.io import read_file
37+
from rich.progress import track
38+
from torch.nn import Module
3739
from torch.utils.data import DataLoader
3840

39-
from deeplc.calibration import Calibration, SplineTransformerCalibration
40-
from deeplc._data import DeepLCDataset
41+
from deeplc._data import DeepLCDataset, get_targets
4142
from deeplc._finetune import DeepLCFineTuner
43+
from deeplc.calibration import Calibration, SplineTransformerCalibration
4244

4345
# If CLI/GUI/frozen: disable warnings before importing
4446
IS_CLI_GUI = os.path.basename(sys.argv[0]) in ["deeplc", "deeplc-gui"]
@@ -49,23 +51,18 @@
4951
# Default models, will be used if no other is specified. If no best model is
5052
# selected during calibration, the first model in the list will be used.
5153
DEEPLC_DIR = os.path.dirname(os.path.realpath(__file__))
52-
DEFAULT_MODELS = [
53-
"mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.pt",
54-
"mods/full_hc_PXD005573_pub_8c22d89667368f2f02ad996469ba157e.pt",
55-
"mods/full_hc_PXD005573_pub_cb975cfdd4105f97efa0b3afffe075cc.pt",
56-
]
57-
DEFAULT_MODELS = [os.path.join(DEEPLC_DIR, m) for m in DEFAULT_MODELS]
54+
DEFAULT_MODEL = "mods/full_hc_PXD005573_pub_1fd8363d9af9dcad3be7553c39396960.pt"
55+
DEFAULT_MODEL = os.path.join(DEEPLC_DIR, DEFAULT_MODEL)
5856

5957

6058
logger = logging.getLogger(__name__)
6159

6260

6361
def predict(
6462
psm_list: PSMList | None = None,
65-
model_files: str | list[str] | None = None,
66-
calibrator: Calibration | None = None,
63+
model: str | list[str] | None = None,
64+
num_workers: int = 4,
6765
batch_size: int = 1024,
68-
single_model_mode: bool = False,
6966
):
7067
"""
7168
Make predictions for sequences, in batches if required.
@@ -74,68 +71,49 @@ def predict(
7471
----------
7572
psm_list
7673
PSMList object containing the peptidoforms to predict for.
77-
model_files
78-
Model file (or files) to use for prediction. If None, the default model is used.
79-
calibrator
80-
Calibrator object to use for calibration. If None, no calibration is performed.
74+
model
75+
Model file to use for prediction. If None, the default model is used.
8176
batch_size
82-
How many samples per batch to load (default: 1).
83-
single_model_mode
84-
Whether to use a single model instead of multiple default models. Only applies if
85-
model_file is None.
77+
How many samples per batch to load (default: 1024).
8678
8779
Returns
8880
-------
8981
np.array
9082
predictions
9183
9284
"""
93-
if len(psm_list) == 0:
94-
return []
85+
# Shortcut if empty PSMList is provided
86+
if not psm_list:
87+
return np.array([])
9588

96-
# Setup dataset and dataloader
97-
dataset = DeepLCDataset(psm_list)
98-
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
99-
100-
if model_files is not None:
101-
if isinstance(model_files, str):
102-
model_files = [model_files]
103-
elif isinstance(model_files, list):
104-
model_files = model_files
105-
else:
106-
raise ValueError("Invalid model name provided.")
107-
else:
108-
model_files = [DEFAULT_MODELS[0]] if single_model_mode else DEFAULT_MODELS
89+
# Avoid predicting the same peptidoform more than once
90+
unique_peptidoforms, inverse_indices = _get_unique_peptidoforms(psm_list)
10991

110-
# Get predictions; iterate over models if multiple were selected
111-
model_predictions = []
112-
for model_f in model_files:
113-
# Load model
114-
model = torch.load(model_f, weights_only=False, map_location=torch.device("cpu"))
115-
model.eval()
92+
# Setup dataset and dataloader
93+
dataset = DeepLCDataset(unique_peptidoforms, target_retention_times=None)
94+
loader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
11695

117-
# Predict
118-
ret_preds = []
119-
with torch.no_grad():
120-
for features, _ in loader:
121-
batch_preds = model(*features)
122-
ret_preds.append(batch_preds.detach().cpu().numpy())
123-
raise Exception()
96+
# Get model files
97+
model = model or DEFAULT_MODEL
12498

125-
# Concatenate predictions
126-
ret_preds = np.concatenate(ret_preds, axis=0)
99+
# Check device availability
100+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
127101

128-
# TODO: Bring outside of model loop?
129-
# Calibrate
130-
if calibrator is not None:
131-
ret_preds = calibrator.transform(ret_preds)
102+
# Load model on specified device
103+
model = _load_model(model=model, device=device, eval=True)
132104

133-
model_predictions.append(ret_preds)
105+
# Predict
106+
predictions = []
107+
with torch.no_grad():
108+
for features, _ in track(loader):
109+
features = [feature_tensor.to(device) for feature_tensor in features]
110+
batch_preds = model(*features)
111+
predictions.append(batch_preds.detach().cpu().numpy())
134112

135-
# Average the predictions from all models
136-
averaged_predictions = np.mean(model_predictions, axis=0)
113+
# Concatenate predictions and reorder to match original PSMList order
114+
predictions = np.concatenate(predictions, axis=0)[inverse_indices]
137115

138-
return averaged_predictions
116+
return predictions
139117

140118

141119
# TODO: Split off transfer learning?
@@ -144,13 +122,12 @@ def calibrate(
144122
model_files: str | list[str] | None = None,
145123
location_retraining_models: str = "",
146124
sample_for_calibration_curve: int | None = None,
147-
return_plotly_report=False,
148125
n_jobs: int | None = None,
149126
batch_size: int = int(1e6),
150127
fine_tune: bool = False,
151128
n_epochs: int = 20,
152129
calibrator: Calibration | None = None,
153-
) -> dict | None:
130+
) -> tuple[str, dict[str, Calibration]]:
154131
"""
155132
Find best model and calibrate.
156133
@@ -159,14 +136,12 @@ def calibrate(
159136
psm_list
160137
PSMList object containing the peptidoforms to predict for.
161138
model_files
162-
Path to one or mode models to test and calibrat for. If a list of models is passed,
139+
Path to one or more models to test and calibrate for. If a list of models is passed,
163140
the best performing one on the calibration data will be selected.
164141
location_retraining_models
165142
Location to save the retraining models; if None, a temporary directory is used.
166143
sample_for_calibration_curve
167144
Number of PSMs to sample for calibration curve; if None, all provided PSMs are used.
168-
return_plotly_report
169-
Whether to return a plotly report with the calibration results.
170145
n_jobs
171146
Number of jobs to use for parallel processing; if None, the number of CPU cores is used.
172147
batch_size
@@ -180,8 +155,6 @@ def calibrate(
180155
181156
Returns
182157
-------
183-
dict | None
184-
Dictionary with plotly report information or None.
185158
186159
"""
187160
if None in psm_list["retention_time"]:
@@ -193,7 +166,7 @@ def calibrate(
193166
calibrator = SplineTransformerCalibration()
194167

195168
# Ensure model_files is a list of strings
196-
model_files = model_files or DEFAULT_MODELS
169+
model_files = model_files or DEFAULT_MODEL
197170
if isinstance(model_files, str):
198171
model_files = [model_files]
199172

@@ -239,13 +212,13 @@ def calibrate(
239212

240213
best_perf = float("inf")
241214
best_calibrator = {}
242-
mod_calibrator = {}
215+
model_calibrators = {}
243216
pred_dict = {}
244217
mod_dict = {}
245218

246219
for model_name in model_files:
247220
logger.debug(f"Trying out the following model: {model_name}")
248-
predicted_tr = predict(psm_list, calibrate=False, model_name=model_name)
221+
predicted_tr = predict(psm_list, calibrator=calibrator, model_name=model_name)
249222

250223
model_calibrator = copy.deepcopy(calibrator)
251224

@@ -265,7 +238,7 @@ def calibrate(
265238
m_group_name = "_".join(m_name.split("_")[:-1])
266239
pred_dict.setdefault(m_group_name, {})[model_name] = preds
267240
mod_dict.setdefault(m_group_name, {})[model_name] = model_name
268-
mod_calibrator.setdefault(m_group_name, {})[model_name] = model_calibrator
241+
model_calibrators.setdefault(m_group_name, {})[model_name] = model_calibrator
269242

270243
# Find best-performing model, including each model's calibration
271244
for m_name in pred_dict:
@@ -279,13 +252,12 @@ def calibrate(
279252
m_group_name = m_name
280253

281254
# TODO is deepcopy really required?
282-
best_calibrator = copy.deepcopy(mod_calibrator[m_group_name])
255+
best_calibrator = copy.deepcopy(model_calibrators[m_group_name])
283256
best_model = copy.deepcopy(mod_dict[m_group_name])
284257
best_perf = perf
285258

286259
logger.debug(f"Model with the best performance got selected: {best_model}")
287260

288-
289261
return best_model, best_calibrator
290262

291263

@@ -304,3 +276,42 @@ def _file_to_psm_list(input_file: str | Path) -> PSMList:
304276
psm_list.rename_modifications(mapper)
305277

306278
return psm_list
279+
280+
281+
def _get_unique_peptidoforms(psm_list: PSMList) -> tuple[PSMList, np.ndarray]:
282+
"""Get PSMs with unique peptidoforms and their inverse indices."""
283+
peptidoform_strings = np.array([str(psm.peptidoform) for psm in psm_list])
284+
unique_peptidoforms, inverse_indices = np.unique(peptidoform_strings, return_inverse=True)
285+
return unique_peptidoforms, inverse_indices
286+
287+
288+
def _load_model(
289+
model: Module | Path | str | None = None,
290+
device: str | None = None,
291+
eval: bool = False,
292+
) -> Module:
293+
"""Load a model from a file or return the default model if none is provided."""
294+
# If no model is provided, use the default model
295+
model = model or DEFAULT_MODEL
296+
297+
# If device is not specified, use the default device (GPU if available, else CPU)
298+
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
299+
300+
# Load model from file if a path is provided
301+
if isinstance(model, str | Path):
302+
model = torch.load(model, weights_only=False, map_location=device)
303+
elif not isinstance(model, Module):
304+
raise TypeError(f"Expected a PyTorch Module or a file path, got {type(model)} instead.")
305+
306+
# Ensure the model is on the specified device
307+
model.to(device)
308+
309+
# Set the model to evaluation or training mode based on the eval flag
310+
if eval:
311+
model.eval()
312+
else:
313+
model.train()
314+
315+
return model
316+
317+

0 commit comments

Comments
 (0)