1111from pypef .gaussian_process .gp_pmpnn_test import HellingerRBFKernel , get_probs_from_mutations
1212from pypef .gaussian_process .gp_prosst_test import (extract_prosst_embeddings , get_prosst_models ,
1313 get_structure_quantizied , read_fasta_biopython )
14- from pypef .gaussian_process .metrics import spearman_soft , spearman_corr_differentiable , spearmanr2
14+ from pypef .plm .utils import spearman_soft , correlation_loss , hybrid_corr_mse_loss , pearson_loss
15+
1516
1617class CombinedKernel (gpytorch .kernels .Kernel ):
1718 """
@@ -47,30 +48,26 @@ def forward(self, X):
4748 return gpytorch .distributions .MultivariateNormal (mean_x , covar_x )
4849
4950
50-
51-
52-
53-
# -----------------------------
# Load and preprocess data
# -----------------------------
# DMS dataset: one row per variant with its mutation code, full mutated
# sequence, and measured fitness score.
df = pd.read_csv('datasets/BLAT_ECOLX/BLAT_ECOLX_Stiffler_2015.csv')

print(df.columns)
mutants = df['mutant'].to_list()
sequences = df['mutated_sequence'].to_list()
y = df['DMS_score'].to_list()

# Small fixed-size train/test subsets (100 + 100) for a quick GP experiment;
# random_state pins the split for reproducibility.
m_train, m_test, s_train, s_test, y_train, y_test = train_test_split(
    mutants, sequences, y, train_size=100, test_size=100, random_state=42
)
6764
# Structural probability features are currently disabled; only the ProSST
# sequence embeddings are used downstream.
# X_struct = get_probs_from_mutations(m_train)  # [N, 20]


print("Getting ProSST models")
pdb = 'datasets/BLAT_ECOLX/BLAT_ECOLX.pdb'
# Wild-type sequence: first (and only) record of the FASTA file.
wt_seq = list(read_fasta_biopython('datasets/BLAT_ECOLX/blat_ecolx_wt.fasta').values())[0]
prosst_base_model, prosst_lora_model, prosst_tokenizer, prosst_optimizer = get_prosst_models()
prosst_vocab = prosst_tokenizer.get_vocab()
# NOTE(review): hard-coded "cuda" — this script requires a GPU as written.
prosst_base_model = prosst_base_model.to("cuda")
@@ -87,7 +84,7 @@ def forward(self, X):
y_test = torch.tensor(y_test).float()

# Concatenate features: the GP kernel takes a single tensor, not a tuple.
# NOTE(review): X_seq is concatenated with itself because the structural
# features (X_struct) are disabled above — the second half of X_combined
# mirrors the first so the combined-kernel dimensionality stays intact.
X_combined = torch.cat([X_seq, X_seq], dim=-1)
d_seq = X_seq.shape[1]
9390# -----------------------------
@@ -121,10 +118,10 @@ def forward(self, X):
# -----------------------------
# Test
# -----------------------------
# Structural test features disabled, mirroring the training-feature setup.
# X_struct_test = get_probs_from_mutations(m_test)
# X_seq_test = torch.tensor(extract_esm_embeddings(s_test)).float()
X_seq_test = torch.tensor(extract_prosst_embeddings(prosst_base_model, prosst_tokenizer, s_test, wt_structure_input_ids))
# Self-concatenation matches X_combined's layout (structure half disabled).
X_test_combined = torch.cat([X_seq_test, X_seq_test], dim=-1)

model.eval()
likelihood.eval()
@@ -143,15 +140,23 @@ def forward(self, X):
def _report_metrics(split, y_true, y_pred_np):
    """Print the correlation metrics for one data split.

    The TRAIN and TEST reporting code was previously duplicated line-for-line;
    this helper prints the identical six metric lines, parameterized only by
    the split label.

    Parameters
    ----------
    split : str
        Label embedded in the printed output ("TRAIN" or "TEST").
    y_true : torch.Tensor
        Ground-truth fitness values for the split.
    y_pred_np : numpy.ndarray
        GP posterior-mean predictions for the same samples.

    Returns
    -------
    tuple
        ``(rho, p)`` from ``scipy.stats.spearmanr`` so the caller can keep
        the module-level ``rho``/``p`` variables this script exposes.
    """
    y_pred_t = torch.from_numpy(y_pred_np)
    rho, p = spearmanr(y_true, y_pred_np)
    print(f"Spearman rho SciPy {split}:", rho)
    print(f"Spearman soft {split}:", spearman_soft(y_true, y_pred_t).item())
    print(f"Correlation loss Spearman {split}:", correlation_loss(y_true, y_pred_t, method="spearman"))
    print(f"Correlation hybrid MSE loss Spearman {split}:", hybrid_corr_mse_loss(y_true, y_pred_t))
    print(f"Correlation loss Pearson {split}:", correlation_loss(y_true, y_pred_t, method="pearson"))
    print(f"Correlation loss Pearson 2 {split}:", pearson_loss(y_true, y_pred_t))
    return rho, p


rho, p = _report_metrics("TRAIN", y_train, y_pred_train)
y_train_t = y_train.float().unsqueeze(0)  # shape (1, n)
y_pred_train_t = torch.from_numpy(y_pred_train).float().unsqueeze(0)  # shape (1, n)
# Disabled alternative Spearman implementations, kept for reference:
# print("Spearman corr diff (ChatGPT) TRAIN:", spearman_corr_differentiable(y_train_t, y_pred_train_t).item())
# print("Spearman2 torchsort TRAIN:", spearmanr2(y_train_t, y_pred_train_t).item())

rho, p = _report_metrics("TEST", y_test, y_pred)
y_test_t = y_test.float().unsqueeze(0)  # shape (1, n)
y_pred_t = torch.from_numpy(y_pred).float().unsqueeze(0)  # shape (1, n)
# print("Spearman corr diff (ChatGPT) TEST:", spearman_corr_differentiable(y_test_t, y_pred_t).item())
# print("Spearman2 torchsort TEST:", spearmanr2(y_test_t, y_pred_t).item())
0 commit comments