Skip to content

Commit 922e19a

Browse files
committed
Add feature_matrices function
Remove redundancy.
1 parent 8e897b0 commit 922e19a

File tree

1 file changed

+40
-46
lines changed

1 file changed

+40
-46
lines changed

deeplc/deeplc.py

Lines changed: 40 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -232,16 +232,13 @@ def _freeze_layers(self, unfreeze_keywords="33_1"):
232232
for name, param in self.model.named_parameters():
233233

234234
param.requires_grad = (unfreeze_keywords in name)
235-
print(f"[INFO] Trainable parameters:")
236-
for name, param in self.model.named_parameters():
237-
if param.requires_grad:
238-
print(f" - {name}")
235+
239236

240237
def prepare_data(self, data, shuffle=True):
    """Wrap *data* in a ``DataLoader`` using the configured batch size.

    Parameters
    ----------
    data : Dataset
        Samples to batch for training or validation.
    shuffle : bool, optional
        Whether to reshuffle the samples every epoch (default ``True``).

    Returns
    -------
    torch.utils.data.DataLoader
        Loader yielding batches of ``self.batch_size`` samples.
    """
    loader = DataLoader(data, batch_size=self.batch_size, shuffle=shuffle)
    return loader
242239

243240
def fine_tune(self):
244-
logger.info("Starting fine-tuning...")
241+
logger.debug("Starting fine-tuning...")
245242
if self.validation_data is None:
246243
# Split the training data into training and validation sets
247244
val_size = int(len(self.train_data) * self.validation_split)
@@ -290,17 +287,15 @@ def fine_tune(self):
290287
val_loss += loss_fn(outputs, target).item()
291288
avg_val_loss = val_loss / len(val_loader)
292289

293-
logger.info(f"Epoch {epoch + 1}/{self.epochs}, Loss: {avg_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
294-
print(f"Epoch {epoch + 1}/{self.epochs}, Loss: {avg_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
290+
logger.debug(f"Epoch {epoch + 1}/{self.epochs}, Loss: {avg_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
295291
if avg_val_loss < best_val_loss:
296292
best_val_loss = avg_val_loss
297293
best_model_wts = copy.deepcopy(self.model.state_dict())
298294
epochs_no_improve = 0
299295
else:
300296
epochs_no_improve += 1
301297
if epochs_no_improve >= self.patience:
302-
logger.info(f"Early stopping triggered {epoch + 1}")
303-
print(f"Early stopping triggered {epoch + 1}")
298+
logger.debug(f"Early stopping triggered {epoch + 1}")
304299
break
305300
self.model.load_state_dict(best_model_wts)
306301
return self.model
@@ -672,6 +667,37 @@ def do_f_extraction_psm_list_parallel(self, psm_list):
672667

673668
return all_feats
674669

670+
def _prepare_feature_matrices(self, psm_list):
671+
"""
672+
Extract features in parallel and assemble the four input matrices.
673+
674+
Parameters
675+
----------
676+
psm_list : list of PSM
677+
List of peptide‐spectrum matches for which to extract features.
678+
679+
Returns
680+
-------
681+
X : ndarray, shape (n_peptides, n_features)
682+
X_sum : ndarray, shape (n_peptides, n_sum_features)
683+
X_global : ndarray, shape (n_peptides, n_global_features * 2)
684+
X_hc : ndarray, shape (n_peptides, n_hc_features)
685+
"""
686+
feats = self.do_f_extraction_psm_list_parallel(psm_list)
687+
X = np.stack(list(feats["matrix"].values()))
688+
X_sum = np.stack(list(feats["matrix_sum"].values()))
689+
X_global = np.concatenate(
690+
(
691+
np.stack(list(feats["matrix_all"].values())),
692+
np.stack(list(feats["pos_matrix"].values())),
693+
),
694+
axis=1,
695+
)
696+
X_hc = np.stack(list(feats["matrix_hc"].values()))
697+
return X, X_sum, X_global, X_hc
698+
699+
700+
675701
def calibration_core(self, uncal_preds, cal_dict, cal_min, cal_max):
676702
"""
677703
Perform calibration on uncalibrated predictions.
@@ -818,18 +844,7 @@ def make_preds_core(
818844
if len(X) == 0 and len(psm_list) > 0:
819845
if self.verbose:
820846
logger.debug("Extracting features for the CNN model ...")
821-
X = self.do_f_extraction_psm_list_parallel(psm_list)
822-
823-
X_sum = np.stack(list(X["matrix_sum"].values()))
824-
X_global = np.concatenate(
825-
(
826-
np.stack(list(X["matrix_all"].values())),
827-
np.stack(list(X["pos_matrix"].values())),
828-
),
829-
axis=1,
830-
)
831-
X_hc = np.stack(list(X["matrix_hc"].values()))
832-
X = np.stack(list(X["matrix"].values()))
847+
X, X_sum, X_global, X_hc = self._prepare_feature_matrices(psm_list)
833848
elif len(X) == 0 and len(psm_list) == 0:
834849
return []
835850

@@ -940,17 +955,7 @@ def make_preds(
940955
if self.verbose:
941956
logger.debug("Extracting features for the CNN model ...")
942957

943-
X = self.do_f_extraction_psm_list_parallel(psm_list_t)
944-
X_sum = np.stack(list(X["matrix_sum"].values()))
945-
X_global = np.concatenate(
946-
(
947-
np.stack(list(X["matrix_all"].values())),
948-
np.stack(list(X["pos_matrix"].values())),
949-
),
950-
axis=1,
951-
)
952-
X_hc = np.stack(list(X["matrix_hc"].values()))
953-
X = np.stack(list(X["matrix"].values()))
958+
X, X_sum, X_global, X_hc = self._prepare_feature_matrices(psm_list_t)
954959
else:
955960
return []
956961

@@ -1410,19 +1415,9 @@ def calibrate_preds(
14101415
temp_pred = []
14111416

14121417
if self.deeplc_retrain:
1413-
logger.info("Preparing for model fine-tuning...")
1414-
1415-
X = self.do_f_extraction_psm_list_parallel(psm_list)
1416-
X_sum = np.stack(list(X["matrix_sum"].values()))
1417-
X_global = np.concatenate(
1418-
(
1419-
np.stack(list(X["matrix_all"].values())),
1420-
np.stack(list(X["pos_matrix"].values())),
1421-
),
1422-
axis=1,
1423-
)
1424-
X_hc = np.stack(list(X["matrix_hc"].values()))
1425-
X = np.stack(list(X["matrix"].values()))
1418+
logger.debug("Preparing for model fine-tuning...")
1419+
1420+
X, X_sum, X_global, X_hc = self._prepare_feature_matrices(psm_list)
14261421
dataset = DeepLCDataset(X, X_sum, X_global, X_hc, np.array(measured_tr))
14271422

14281423
base_model_path = self.model[0] if isinstance(self.model, list) else self.model
@@ -1448,7 +1443,6 @@ def calibrate_preds(
14481443

14491444
# Define path to save fine-tuned model
14501445
fine_tuned_model_path = os.path.join(t_dir_models, "fine_tuned_model.pth")
1451-
print("Saving fine-tuned model to:", fine_tuned_model_path)
14521446
torch.save(fine_tuned_model, fine_tuned_model_path)
14531447
self.model = [fine_tuned_model_path]
14541448

0 commit comments

Comments
 (0)