1+
2+
3+ """
4+ Gaussian process optimization similar to (but less sophisticated than)
5+ Kermut: Composite kernel regression for protein variant effects
6+ Peter Mørch Groth, Mads Herbert Kerrn, Lars Olsen, Jesper Salomon, Wouter Boomsma
7+ 2024, 38th Conference on Neural Information Processing Systems (NeurIPS 2024).
8+ TL;DR: Gaussian process regression model with a novel composite kernel, Kermut, achieves
9+ state-of-the-art variant effect prediction while providing meaningful uncertainties.
10+ https://openreview.net/forum?id=jM9atrvUii
11+ """
12+
13+
114import pandas as pd
215import numpy as np
316from sklearn .model_selection import train_test_split
1023
1124from tqdm import tqdm
1225
13- from pypef .llm .prosst_lora_tune import (
14- get_logits_from_full_seqs , get_prosst_models , get_structure_quantizied ,
15- prosst_tokenize_sequences , prosst_train
26+ from pypef .plm .prosst_lora_tune import (
27+ get_prosst_models , get_structure_quantizied
1628)
17- from pypef .llm .inference import inference
18- from pypef .utils .helpers import get_vram , get_device
29+ from pypef .plm .inference import tokenize_sequences , plm_inference
30+ from pypef .utils .helpers import get_device
1931
2032
2133device = "cuda" if torch .cuda .is_available () else "cpu"
@@ -76,14 +88,16 @@ def extract_prosst_embeddings(
7688 structures = [structures ] * len (sequences )
7789
7890 assert len (sequences ) == len (structures ), \
79- "Number of sequences must match number of structures"
91+ (f"Number of sequences must match number of structures "
92+ f"{ len (sequences )} != { len (structures )} " )
8093
8194 for seq , struct in tqdm (zip (sequences , structures ),
8295 total = len (sequences ),
8396 desc = "Embedding (ProSST)" ):
8497 # Tokenize sequence
8598 tokenized = prosst_tokenizer (
8699 [seq ],
100+ max_length = len (seq ) + 2 ,
87101 return_tensors = "pt" ,
88102 padding = False ,
89103 truncation = False
@@ -124,29 +138,32 @@ def extract_prosst_embeddings(
124138 return X
125139
126140
def gp_fit():
    """Fit a Gaussian process regressor on embedding features.

    TODO(review): placeholder only — no implementation yet; calling this
    silently returns None. The script body below currently performs the
    GPyTorch fit inline; presumably that logic is meant to move here.
    """
    pass
127143
128144
129145
def gp_fit_sklearn():
    """Fit a scikit-learn GaussianProcessRegressor on embedding features.

    TODO(review): placeholder only — no implementation yet; calling this
    silently returns None. The script body below currently performs the
    sklearn fit inline (RBF + WhiteKernel); presumably that logic is meant
    to move here.
    """
    pass
148+
149+
130150
131151
132152if __name__ == '__main__' :
133- wt_seq = list (read_fasta_biopython ('example_data/blat_ecolx/blat_ecolx_wt_seq.fa ' ).values ())[0 ]
134- pdb = 'example_data/blat_ecolx /BLAT_ECOLX.pdb'
153+ wt_seq = list (read_fasta_biopython ('datasets/BLAT_ECOLX/blat_ecolx_wt.fasta ' ).values ())[0 ]
154+ pdb = 'datasets/BLAT_ECOLX /BLAT_ECOLX.pdb'
135155 device = get_device ()
136156 print ("Getting ProSST models" )
137157 prosst_base_model , prosst_lora_model , prosst_tokenizer , prosst_optimizer = get_prosst_models ()
138158 prosst_vocab = prosst_tokenizer .get_vocab ()
139159 prosst_base_model = prosst_base_model .to (device )
140160
141161 print (f"Getting structure tokens..." )
142- input_ids , prosst_attention_mask , structure_input_ids = get_structure_quantizied (
162+ wt_input_ids , prosst_attention_mask , structure_input_ids = get_structure_quantizied (
143163 pdb , prosst_tokenizer , wt_seq , verbose = True
144164 )
145165
146-
147-
148-
149- df = pd .read_csv ('example_data/blat_ecolx/BLAT_ECOLX_Stiffler_2015.csv' )
166+ df = pd .read_csv ('datasets/BLAT_ECOLX/BLAT_ECOLX_Stiffler_2015.csv' )
150167 sequences = df ['mutated_sequence' ].to_list ()
151168 y = df ['DMS_score' ].to_list ()
152169
@@ -156,10 +173,18 @@ def extract_prosst_embeddings(
156173 # --- Step 2: Extract ProSST embeddings ---
157174 print (structure_input_ids )
158175 print ('np.shape(structure_input_ids):' , np .shape (structure_input_ids ))
159- wt_structure_input_ids = structure_input_ids [0 , 1 :- 1 ].tolist () # Remove CLS/EOS
160- X_train = extract_prosst_embeddings (prosst_base_model , prosst_tokenizer , s_train , wt_structure_input_ids )
176+ wt_structure_input_ids = structure_input_ids
177+
178+ X_emb_train = extract_prosst_embeddings (prosst_base_model , prosst_tokenizer , s_train , wt_structure_input_ids [0 , 1 :- 1 ].tolist ()) # Remove CLS/EOS
179+
180+ x_train_2 , prosst_attention_mask_2 = tokenize_sequences (s_train , prosst_tokenizer )
181+ assert len (prosst_attention_mask [0 ]) == len (prosst_attention_mask_2 ), f"{ len (prosst_attention_mask [0 ])} \n !=\n { len (prosst_attention_mask_2 )} "
182+
183+ X_emb_train_2 = plm_inference (x_train_2 , wt_input_ids , prosst_attention_mask , prosst_base_model ,
184+ extract_emb = True , wt_structure_input_ids = wt_structure_input_ids ) #[0, 1:-1])
185+
161186 print ("Embedding extraction done" )
162- print ( np .shape (X_train ) )
187+ assert np .shape (X_emb_train ) == np . shape ( X_emb_train_2 )
163188
164189 # --- Step 3: Fit Gaussian Process ---
165190 if USE_SCIKIT_LEARN :
@@ -168,12 +193,12 @@ def extract_prosst_embeddings(
168193
169194 kernel = 1.0 * RBF (length_scale = 1.0 ) + WhiteKernel (noise_level = 0.1 )
170195 gpr = GaussianProcessRegressor (kernel = kernel , n_restarts_optimizer = 10 , normalize_y = True )
171- gpr .fit (X_train , y_train )
196+ gpr .fit (X_emb_train , y_train )
172197
173198 else : # GPYTORCH
174199 import gpytorch
175200
176- X_train_t = torch .tensor (X_train , dtype = torch .float32 ).to (device )
201+ X_train_t = torch .tensor (X_emb_train , dtype = torch .float32 ).to (device )
177202 y_train_t = torch .tensor (y_train , dtype = torch .float32 ).to (device )
178203
179204 likelihood = gpytorch .likelihoods .GaussianLikelihood ().to (device )
@@ -195,26 +220,31 @@ def extract_prosst_embeddings(
195220
196221 # --- Step 4: Extract ProSST embeddings for test sequences ---
197222 print ("Extracting ProSST embeddings for test sequences..." )
198- X_test = extract_prosst_embeddings (
223+ X_emb_test = extract_prosst_embeddings (
199224 model = prosst_base_model ,
200225 prosst_tokenizer = prosst_tokenizer ,
201226 sequences = s_test ,
202- structures = wt_structure_input_ids , # still using same WT structure
227+ structures = wt_structure_input_ids [ 0 , 1 : - 1 ]. tolist () , # still using same WT structure
203228 device = device
204229 )
205- print ("Test embeddings shape:" , X_test .shape )
230+ print ("Test embeddings shape:" , X_emb_test .shape )
231+
232+ x_test_2 , prosst_attention_mask_2 = tokenize_sequences (s_test , prosst_tokenizer )
233+ X_emb_test_2 = plm_inference (x_test_2 , wt_input_ids , prosst_attention_mask , prosst_base_model ,
234+ extract_emb = True , wt_structure_input_ids = wt_structure_input_ids ) #[0, 1:-1])
206235
207236 # --- Step 5: Predict with Gaussian Process ---
208- if USE_SCIKIT_LEARN :
209- y_mean , y_std = gpr .predict (X_test , return_std = True )
210- else :
211- X_test_t = torch .tensor (X_test , dtype = torch .float32 ).to (device )
212- gp_model .eval ()
213- likelihood .eval ()
214- with torch .no_grad (), gpytorch .settings .fast_pred_var ():
215- pred = likelihood (gp_model (X_test_t ))
216- y_mean = pred .mean .cpu ().numpy ()
217- lower , upper = pred .confidence_region () # optional 95% CI
218-
219- print ("Predicted fitness:" , y_mean )
220- print ("Spearman correlation:" , spearmanr (y_test , y_mean ))
237+ for x_t in [torch .tensor (X_emb_test ).to (device ), X_emb_test_2 ]:
238+ if USE_SCIKIT_LEARN :
239+ y_mean , y_std = gpr .predict (x_t , return_std = True )
240+ else :
241+ #X_test_t = torch.tensor(x_t, dtype=torch.float32).to(device)
242+ gp_model .eval ()
243+ likelihood .eval ()
244+ with torch .no_grad (), gpytorch .settings .fast_pred_var ():
245+ pred = likelihood (gp_model (x_t ))
246+ y_mean = pred .mean .cpu ().numpy ()
247+ lower , upper = pred .confidence_region () # optional 95% CI
248+
249+ print ("Predicted fitness:" , y_mean )
250+ print ("Spearman correlation:" , spearmanr (y_test , y_mean )) # SignificanceResult(statistic=0.8617991109937434, pvalue=0.0)
0 commit comments