Skip to content

Commit cc965d6

Browse files
Read function to prepare matrices
1 parent 57db4c8 commit cc965d6

File tree

1 file changed

+119
-30
lines changed

1 file changed

+119
-30
lines changed

deeplc/deeplc.py

Lines changed: 119 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -112,7 +112,9 @@ def __init__(self, X, X_sum, X_global, X_hc, target=None):
112112
self.X_hc = torch.from_numpy(X_hc).float()
113113

114114
if target is not None:
115-
self.target = torch.from_numpy(target).float() # Add target values if provided
115+
self.target = torch.from_numpy(
116+
target
117+
).float() # Add target values if provided
116118
else:
117119
self.target = None # If no target is provided, set it to None
118120

@@ -212,7 +214,8 @@ def fine_tune(self):
212214
val_loader = self.prepare_data(val_dataset, shuffle=False)
213215

214216
optimizer = torch.optim.Adam(
215-
filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.learning_rate
217+
filter(lambda p: p.requires_grad, self.model.parameters()),
218+
lr=self.learning_rate,
216219
)
217220
loss_fn = torch.nn.L1Loss()
218221
best_model_wts = copy.deepcopy(self.model.state_dict())
@@ -241,7 +244,9 @@ def fine_tune(self):
241244
for batch in val_loader:
242245
batch_X, batch_X_sum, batch_X_global, batch_X_hc, target = batch
243246
target = target.view(-1, 1)
244-
outputs = self.model(batch_X, batch_X_sum, batch_X_global, batch_X_hc)
247+
outputs = self.model(
248+
batch_X, batch_X_sum, batch_X_global, batch_X_hc
249+
)
245250
val_loss += loss_fn(outputs, target).item()
246251
avg_val_loss = val_loss / len(val_loader)
247252

@@ -415,7 +420,9 @@ def __str__(self):
415420
"""
416421

417422
@staticmethod
418-
def _get_model_paths(passed_model_path: str | None, single_model_mode: bool) -> list[str]:
423+
def _get_model_paths(
424+
passed_model_path: str | None, single_model_mode: bool
425+
) -> list[str]:
419426
"""Get the model paths based on the passed model path and the single model mode."""
420427
if passed_model_path:
421428
return [passed_model_path]
@@ -425,6 +432,35 @@ def _get_model_paths(passed_model_path: str | None, single_model_mode: bool) ->
425432

426433
return DEFAULT_MODELS
427434

435+
def _prepare_feature_matrices(self, psm_list):
436+
"""
437+
Extract features in parallel and assemble the four input matrices.
438+
439+
Parameters
440+
----------
441+
psm_list : list of PSM
442+
List of peptide-spectrum matches for which to extract features.
443+
444+
Returns
445+
-------
446+
X : ndarray, shape (n_peptides, n_features)
447+
X_sum : ndarray, shape (n_peptides, n_sum_features)
448+
X_global : ndarray, shape (n_peptides, n_global_features * 2)
449+
X_hc : ndarray, shape (n_peptides, n_hc_features)
450+
"""
451+
feats = self.do_f_extraction_psm_list_parallel(psm_list)
452+
X = np.stack(list(feats["matrix"].values()))
453+
X_sum = np.stack(list(feats["matrix_sum"].values()))
454+
X_global = np.concatenate(
455+
(
456+
np.stack(list(feats["matrix_all"].values())),
457+
np.stack(list(feats["pos_matrix"].values())),
458+
),
459+
axis=1,
460+
)
461+
X_hc = np.stack(list(feats["matrix_hc"].values()))
462+
return X, X_sum, X_global, X_hc
463+
428464
def _extract_features(
429465
self,
430466
peptidoforms: list[str | Peptidoform] | PSMList,
@@ -437,24 +473,31 @@ def _extract_features(
437473
logger.debug("Running feature extraction in single-threaded mode...")
438474
if self.n_jobs <= 1:
439475
encodings = [
440-
encode_peptidoform(pf, predict_ccs=self.predict_ccs) for pf in peptidoforms
476+
encode_peptidoform(pf, predict_ccs=self.predict_ccs)
477+
for pf in peptidoforms
441478
]
442479

443480
else:
444481
logger.debug("Preparing feature extraction with Dask")
445482
# Process peptidoforms in larger chunks to reduce task overhead.
446-
peptidoform_strings = [str(pep) for pep in peptidoforms] # Faster pickling of strings
483+
peptidoform_strings = [
484+
str(pep) for pep in peptidoforms
485+
] # Faster pickling of strings
447486

448487
def chunked_encode(chunk):
449-
return [encode_peptidoform(pf, predict_ccs=self.predict_ccs) for pf in chunk]
488+
return [
489+
encode_peptidoform(pf, predict_ccs=self.predict_ccs) for pf in chunk
490+
]
450491

451492
tasks = [
452493
delayed(chunked_encode)(peptidoform_strings[i : i + chunk_size])
453494
for i in range(0, len(peptidoform_strings), chunk_size)
454495
]
455496

456497
logger.debug("Starting feature extraction with Dask")
457-
chunks_encodings = compute(*tasks, scheduler="processes", workers=self.n_jobs)
498+
chunks_encodings = compute(
499+
*tasks, scheduler="processes", workers=self.n_jobs
500+
)
458501

459502
# Flatten the list of lists.
460503
encodings = [enc for chunk in chunks_encodings for enc in chunk]
@@ -486,7 +529,9 @@ def _apply_calibration_core(
486529

487530
# Use spline model within the range of X
488531
within_range = (uncal_preds >= cal_min) & (uncal_preds <= cal_max)
489-
within_range = within_range.ravel() # Ensure this is a 1D array for proper indexing
532+
within_range = (
533+
within_range.ravel()
534+
) # Ensure this is a 1D array for proper indexing
490535

491536
# Create a prediction array initialized with spline predictions
492537
cal_preds = np.copy(y_pred_spline)
@@ -501,19 +546,27 @@ def _apply_calibration_core(
501546
else:
502547
for uncal_pred in uncal_preds:
503548
try:
504-
slope, intercept = cal_dict[str(round(uncal_pred, self.bin_distance))]
549+
slope, intercept = cal_dict[
550+
str(round(uncal_pred, self.bin_distance))
551+
]
505552
cal_preds.append(slope * (uncal_pred) + intercept)
506553
except KeyError:
507554
# outside of the prediction range ... use the last
508555
# calibration curve
509556
if uncal_pred <= cal_min:
510-
slope, intercept = cal_dict[str(round(cal_min, self.bin_distance))]
557+
slope, intercept = cal_dict[
558+
str(round(cal_min, self.bin_distance))
559+
]
511560
cal_preds.append(slope * (uncal_pred) + intercept)
512561
elif uncal_pred >= cal_max:
513-
slope, intercept = cal_dict[str(round(cal_max, self.bin_distance))]
562+
slope, intercept = cal_dict[
563+
str(round(cal_max, self.bin_distance))
564+
]
514565
cal_preds.append(slope * (uncal_pred) + intercept)
515566
else:
516-
slope, intercept = cal_dict[str(round(cal_max, self.bin_distance))]
567+
slope, intercept = cal_dict[
568+
str(round(cal_max, self.bin_distance))
569+
]
517570
cal_preds.append(slope * (uncal_pred) + intercept)
518571

519572
return np.array(cal_preds)
@@ -648,7 +701,9 @@ def make_preds(
648701
elif infile is not None:
649702
psm_list = _file_to_psm_list(infile)
650703
else:
651-
raise ValueError("Either `psm_list` or `seq_df` or `infile` must be provided.")
704+
raise ValueError(
705+
"Either `psm_list` or `seq_df` or `infile` must be provided."
706+
)
652707

653708
if len(psm_list) == 0:
654709
logger.warning("No PSMs to predict for.")
@@ -692,7 +747,9 @@ def make_preds(
692747
)
693748
)
694749
# Average the predictions from all models
695-
ret_preds = np.array([sum(a) / len(a) for a in zip(*ret_preds, strict=True)])
750+
ret_preds = np.array(
751+
[sum(a) / len(a) for a in zip(*ret_preds, strict=True)]
752+
)
696753
# ret_preds = np.mean(model_predictions, axis=0)
697754

698755
else:
@@ -740,7 +797,9 @@ def _calibrate_preds_pygam(
740797
spline_model = linear_model
741798
linear_model_right = linear_model
742799
else:
743-
spline = SplineTransformer(degree=4, n_knots=int(len(measured_tr) / 500) + 5)
800+
spline = SplineTransformer(
801+
degree=4, n_knots=int(len(measured_tr) / 500) + 5
802+
)
744803
spline_model = make_pipeline(spline, LinearRegression())
745804
spline_model.fit(predicted_tr.reshape(-1, 1), measured_tr)
746805

@@ -832,9 +891,13 @@ def _calibrate_preds_piecewise_linear(
832891
"peptides for fitting the calibration curve."
833892
)
834893
if len(mtr_mean) == 0:
835-
raise CalibrationError("The measured tr list is empty, not able to calibrate")
894+
raise CalibrationError(
895+
"The measured tr list is empty, not able to calibrate"
896+
)
836897
if len(ptr_mean) == 0:
837-
raise CalibrationError("The predicted tr list is empty, not able to calibrate")
898+
raise CalibrationError(
899+
"The predicted tr list is empty, not able to calibrate"
900+
)
838901

839902
# calculate calibration curves
840903
for i in range(0, len(ptr_mean)):
@@ -913,7 +976,9 @@ def calibrate_preds(
913976
elif infile is not None:
914977
psm_list = _file_to_psm_list(infile)
915978
else:
916-
raise ValueError("Either `psm_list` or `seq_df` or `infile` must be provided.")
979+
raise ValueError(
980+
"Either `psm_list` or `seq_df` or `infile` must be provided."
981+
)
917982

918983
# Getting measured retention time either from measured_tr or provided PSMs
919984
if not measured_tr:
@@ -946,7 +1011,9 @@ def calibrate_preds(
9461011
X, X_sum, X_global, X_hc = self._prepare_feature_matrices(psm_list)
9471012
dataset = DeepLCDataset(X, X_sum, X_global, X_hc, np.array(measured_tr))
9481013

949-
base_model_path = self.model[0] if isinstance(self.model, list) else self.model
1014+
base_model_path = (
1015+
self.model[0] if isinstance(self.model, list) else self.model
1016+
)
9501017
base_model = torch.load(
9511018
base_model_path, weights_only=False, map_location=torch.device("cpu")
9521019
)
@@ -982,29 +1049,42 @@ def calibrate_preds(
9821049

9831050
for model_name in self.model:
9841051
logger.debug(f"Trying out the following model: {model_name}")
985-
predicted_tr = self.make_preds(psm_list, calibrate=False, mod_name=model_name)
1052+
predicted_tr = self.make_preds(
1053+
psm_list, calibrate=False, mod_name=model_name
1054+
)
9861055

9871056
if self.pygam_calibration:
988-
calibrate_output = self._calibrate_preds_pygam(measured_tr, predicted_tr)
1057+
calibrate_output = self._calibrate_preds_pygam(
1058+
measured_tr, predicted_tr
1059+
)
9891060
else:
9901061
calibrate_output = self._calibrate_preds_piecewise_linear(
9911062
measured_tr, predicted_tr, use_median=use_median
9921063
)
993-
self.calibrate_min, self.calibrate_max, self.calibrate_dict = calibrate_output
1064+
self.calibrate_min, self.calibrate_max, self.calibrate_dict = (
1065+
calibrate_output
1066+
)
9941067
# TODO: Currently, calibration dict can be either a dict (linear) or a list of models
9951068
# (PyGAM)... This should be handled better in the future.
9961069

9971070
# Skip this model if calibrate_dict is empty
9981071
# TODO: Should this do something when using PyGAM and calibrate_dict is a list?
999-
if isinstance(self.calibrate_dict, dict) and len(self.calibrate_dict.keys()) == 0:
1072+
if (
1073+
isinstance(self.calibrate_dict, dict)
1074+
and len(self.calibrate_dict.keys()) == 0
1075+
):
10001076
continue
10011077

10021078
m_name = model_name.split("/")[-1]
10031079

10041080
# Get new predictions with calibration
1005-
preds = self.make_preds(psm_list, calibrate=True, seq_df=seq_df, mod_name=model_name)
1081+
preds = self.make_preds(
1082+
psm_list, calibrate=True, seq_df=seq_df, mod_name=model_name
1083+
)
10061084

1007-
m_group_name = "deepcallc" if self.deepcallc_mod else "_".join(m_name.split("_")[:-1])
1085+
m_group_name = (
1086+
"deepcallc" if self.deepcallc_mod else "_".join(m_name.split("_")[:-1])
1087+
)
10081088
m = model_name
10091089
try:
10101090
pred_dict[m_group_name][m] = preds
@@ -1026,13 +1106,18 @@ def calibrate_preds(
10261106
mod_calibrate_max_dict[m_group_name][m] = self.calibrate_max
10271107

10281108
for m_name in pred_dict:
1029-
preds = [sum(a) / len(a) for a in zip(*list(pred_dict[m_name].values()), strict=True)]
1109+
preds = [
1110+
sum(a) / len(a)
1111+
for a in zip(*list(pred_dict[m_name].values()), strict=True)
1112+
]
10301113
if len(measured_tr) == 0:
10311114
perf = sum(abs(seq_df["tr"] - preds))
10321115
else:
10331116
perf = sum(abs(np.array(measured_tr) - np.array(preds)))
10341117

1035-
logger.debug(f"For {m_name} model got a performance of: {perf / len(preds)}")
1118+
logger.debug(
1119+
f"For {m_name} model got a performance of: {perf / len(preds)}"
1120+
)
10361121

10371122
if perf < best_perf:
10381123
m_group_name = "deepcallc" if self.deepcallc_mod else m_name
@@ -1072,7 +1157,9 @@ def calibrate_preds(
10721157
],
10731158
)
10741159
plotly_return_dict["scatter"] = deeplc.plot.scatter(plotly_df)
1075-
plotly_return_dict["baseline_dist"] = deeplc.plot.distribution_baseline(plotly_df)
1160+
plotly_return_dict["baseline_dist"] = deeplc.plot.distribution_baseline(
1161+
plotly_df
1162+
)
10761163
return plotly_return_dict
10771164

10781165
return None
@@ -1124,7 +1211,9 @@ def create_psm(args):
11241211
)
11251212

11261213
args_list = list(
1127-
zip(sequences, modifications, identifiers, charges, retention_times, strict=True)
1214+
zip(
1215+
sequences, modifications, identifiers, charges, retention_times, strict=True
1216+
)
11281217
)
11291218
tasks = [delayed(create_psm)(args) for args in args_list]
11301219
list_of_psms = list(compute(*tasks, scheduler="processes"))

0 commit comments

Comments
 (0)