
Commit 12f332a

Generate features as part of dataloader for optimized multiprocessing and batching
1 parent 984f563 commit 12f332a

5 files changed, +99 −148 lines changed

deeplc/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -1,8 +1,8 @@
-__all__ = ["DeepLC"]
+# __all__ = ["DeepLC"]
 
-from importlib.metadata import version
+# from importlib.metadata import version
 
-__version__ = version("deeplc")
+# __version__ = version("deeplc")
 
 
-from deeplc.deeplc import DeepLC
+# from deeplc.deeplc import DeepLC

deeplc/_data.py

Lines changed: 34 additions & 39 deletions
@@ -1,49 +1,44 @@
 import torch
+from psm_utils.psm_list import PSMList
 from torch.utils.data import Dataset
 
+from deeplc._features import encode_peptidoform
 
-class DeepLCDataset(Dataset):
-    """
-    Custom Dataset class for DeepLC used for loading features from peptide sequences.
-
-    Parameters
-    ----------
-    X : ndarray
-        Feature matrix for input data.
-    X_sum : ndarray
-        Feature matrix for sum of input data.
-    X_global : ndarray
-        Feature matrix for global input data.
-    X_hc : ndarray
-        Feature matrix for high-order context features.
-    target : ndarray, optional
-        The target retention times. Default is None.
-    """
 
-    def __init__(self, X, X_sum, X_global, X_hc, target=None):
-        self.X = torch.from_numpy(X).float()
-        self.X_sum = torch.from_numpy(X_sum).float()
-        self.X_global = torch.from_numpy(X_global).float()
-        self.X_hc = torch.from_numpy(X_hc).float()
+class DeepLCDataset(Dataset):
+    """Custom Dataset class for DeepLC used for loading features from peptide sequences."""
 
-        if target is not None:
-            self.target = torch.from_numpy(target).float()  # Add target values if provided
+    def __init__(self, psm_list: PSMList, add_ccs_features: bool = False):
+        self.psm_list = psm_list
+        self.add_ccs_features = add_ccs_features
+
+        self._targets = self._get_targets(psm_list)
+
+    @staticmethod
+    def _get_targets(psm_list: PSMList) -> torch.Tensor | None:
+        retention_times = [psm.retention_time for psm in psm_list]
+        if None not in retention_times:
+            return torch.tensor(retention_times, dtype=torch.float32)
         else:
-            self.target = None  # If no target is provided, set it to None
+            return None
 
     def __len__(self):
-        return self.X.shape[0]
+        return len(self.psm_list)
 
-    def __getitem__(self, idx):
-        if self.target is not None:
-            # Return both features and target during training
-            return (
-                self.X[idx],
-                self.X_sum[idx],
-                self.X_global[idx],
-                self.X_hc[idx],
-                self.target[idx],
-            )
-        else:
-            # Return only features during prediction
-            return (self.X[idx], self.X_sum[idx], self.X_global[idx], self.X_hc[idx])
+    def __getitem__(self, idx) -> tuple:
+        if not isinstance(idx, int):
+            raise TypeError(f"Index must be an integer, got {type(idx)} instead.")
+        features = encode_peptidoform(
+            self.psm_list[idx].peptidoform,
+            add_ccs_features=self.add_ccs_features
+        )
+        feature_tuples = (
+            torch.from_numpy(features["matrix"]).to(dtype=torch.float32),
+            torch.from_numpy(features["matrix_sum"]).to(dtype=torch.float32),
+            torch.from_numpy(features["matrix_global"]).to(dtype=torch.float32),
+            torch.from_numpy(features["matrix_hc"]).to(dtype=torch.float32),
+        )
+        targets = self._targets[idx] if self._targets is not None else torch.full_like(
+            feature_tuples[0], fill_value=float('nan'), dtype=torch.float32
+        )
+        return feature_tuples, targets
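
Because encoding now happens per item inside __getitem__, a standard torch.utils.data.DataLoader can run feature generation in worker processes and batch the resulting tensors with its default collate function. The following is a minimal usage sketch, not part of the commit itself; the example PSMs and the batch_size/num_workers values are illustrative assumptions.

from psm_utils.psm import PSM
from psm_utils.psm_list import PSMList
from torch.utils.data import DataLoader

from deeplc._data import DeepLCDataset

# Illustrative input: two peptidoforms with known retention times
psm_list = PSMList(
    psm_list=[
        PSM(peptidoform="ACDEFGHIK/2", spectrum_id="1", retention_time=12.3),
        PSM(peptidoform="LMNPQRST/2", spectrum_id="2", retention_time=23.4),
    ]
)

dataset = DeepLCDataset(psm_list, add_ccs_features=False)

# num_workers > 0 moves the encode_peptidoform calls into worker processes;
# guard this with `if __name__ == "__main__":` on platforms that spawn workers.
loader = DataLoader(dataset, batch_size=2, num_workers=2)
for features, targets in loader:
    matrix, matrix_sum, matrix_global, matrix_hc = features
    print(matrix.shape, targets.shape)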
deeplc/_features.py

Lines changed: 5 additions & 43 deletions
@@ -148,7 +148,7 @@ def _compute_rolling_sum(matrix: np.ndarray, n: int = 2) -> np.ndarray:
 
 def encode_peptidoform(
     peptidoform: Peptidoform | str,
-    predict_ccs: bool = False,
+    add_ccs_features: bool = False,
     padding_length: int = 60,
     positions: set[int] | None = None,
     positions_pos: set[int] | None = None,
@@ -188,7 +188,7 @@ def encode_peptidoform(
 
     matrix_all = np.sum(std_matrix, axis=0)
     matrix_all = np.append(matrix_all, seq_len)
-    if predict_ccs:
+    if add_ccs_features:
         matrix_all = np.append(matrix_all, (seq.count("H")) / seq_len)
         matrix_all = np.append(
             matrix_all, (seq.count("F") + seq.count("W") + seq.count("Y")) / seq_len
@@ -198,50 +198,12 @@ def encode_peptidoform(
     matrix_all = np.append(matrix_all, charge)
 
     matrix_sum = _compute_rolling_sum(std_matrix.T, n=2)[:, ::2].T
+
+    matrix_global = np.concatenate([matrix_all, pos_matrix.flatten()])
 
     return {
         "matrix": std_matrix,
         "matrix_sum": matrix_sum,
-        "matrix_all": matrix_all,
-        "pos_matrix": pos_matrix.flatten(),
+        "matrix_global": matrix_global,
         "matrix_hc": onehot_matrix,
     }
-
-
-def extract_features(
-    peptidoforms: list[str | Peptidoform] | PSMList,
-    predict_ccs: bool = False,
-) -> dict[str, dict[int, np.ndarray]]:
-    """Extract features for all peptidoforms."""
-    if isinstance(peptidoforms, PSMList):
-        peptidoforms = [psm.peptidoform for psm in peptidoforms]
-
-    encodings = [encode_peptidoform(pf, predict_ccs=predict_ccs) for pf in peptidoforms]
-    aggregated_encodings = aggregate_encodings(encodings)
-
-    return aggregated_encodings
-
-
-def aggregate_encodings(
-    encodings: list[dict[str, np.ndarray]],
-) -> dict[str, dict[int, np.ndarray]]:
-    """Aggregate list of encodings into single dictionary."""
-    return {key: {i: enc[key] for i, enc in enumerate(encodings)} for key in encodings[0]}
-
-
-def unpack_features(
-    features: dict[str, np.ndarray],
-) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
-    """Unpack dictionary with features to numpy arrays."""
-    X_sum = np.stack(list(features["matrix_sum"].values()))
-    X_global = np.concatenate(
-        (
-            np.stack(list(features["matrix_all"].values())),
-            np.stack(list(features["pos_matrix"].values())),
-        ),
-        axis=1,
-    )
-    X_hc = np.stack(list(features["matrix_hc"].values()))
-    X_main = np.stack(list(features["matrix"].values()))
-
-    return X_sum, X_global, X_hc, X_main
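
With matrix_all and pos_matrix now concatenated into matrix_global at encoding time, encode_peptidoform returns exactly the four arrays that DeepLCDataset converts to tensors, so the removed extract_features/aggregate_encodings/unpack_features round-trip is no longer needed. A small sketch of inspecting the new return value (the peptidoform string is an arbitrary example):

from deeplc._features import encode_peptidoform

# Arbitrary example peptidoform in ProForma notation, precursor charge 2
features = encode_peptidoform("ACDEFGHIK/2", add_ccs_features=True)

# Expected keys after this commit: matrix, matrix_sum, matrix_global, matrix_hc
for key, array in features.items():
    print(key, array.dtype, array.shape)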
Lines changed: 28 additions & 4 deletions
@@ -15,8 +15,8 @@
 LOGGER = logging.getLogger(__name__)
 
 
-class Calibrator(ABC):
-    """Abstract base class for retention time calibrators."""
+class Calibration(ABC):
+    """Abstract base class for retention time calibration."""
 
     @abstractmethod
     def __init__(self, *args, **kwargs):
@@ -27,9 +27,33 @@ def fit(measured_tr: np.ndarray, predicted_tr: np.ndarray) -> None: ...
 
     @abstractmethod
     def transform(tr: np.ndarray) -> np.ndarray: ...
+
 
+class IdentityCalibration(Calibration):
+    """No calibration, just returns the predicted retention times."""
 
-class PiecewiseLinearCalibrator(Calibrator):
+    def fit(self, measured_tr: np.ndarray, predicted_tr: np.ndarray) -> None:
+        """No fitting required for IdentityCalibration."""
+        pass
+
+    def transform(self, tr: np.ndarray) -> np.ndarray:
+        """
+        Transform the predicted retention times without any calibration.
+
+        Parameters
+        ----------
+        tr
+            Retention times to be transformed.
+
+        Returns
+        -------
+        np.ndarray
+            Transformed retention times (same as input).
+        """
+        return tr
+
+
+class PiecewiseLinearCalibration(Calibration):
     def __init__(
         self,
         split_cal: int = 50,
@@ -202,7 +226,7 @@ def transform(self, tr: np.ndarray) -> np.ndarray:
         return np.array(cal_preds)
 
 
-class SplineTransformerCalibrator(Calibrator):
+class SplineTransformerCalibration(Calibration):
     def __init__(self):
         """SplineTransformer calibration for retention time."""
         super().__init__()
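
The renamed Calibration hierarchy keeps a single fit/transform interface, so callers can swap calibration strategies without branching. Below is a hedged usage sketch: the module path deeplc._calibration is an assumption (the file name is not shown in this extract) and the input arrays are synthetic. Note that IdentityCalibration as committed does not override the abstract __init__, so direct instantiation would raise a TypeError until an __init__ is added.

import numpy as np

# Assumed module path; not shown in this extract
from deeplc._calibration import PiecewiseLinearCalibration

# Synthetic example data: predicted retention times with a linear offset
predicted_tr = np.linspace(0.0, 100.0, 200)
measured_tr = 1.05 * predicted_tr + 2.0

# split_cal=50 is the constructor default shown in the diff
calibration = PiecewiseLinearCalibration(split_cal=50)
calibration.fit(measured_tr, predicted_tr)
calibrated_tr = calibration.transform(predicted_tr)
print(calibrated_tr[:5])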

0 commit comments

Comments
 (0)