1717from pypef .ml .regression import AAIndexEncoding , full_aaidx_txt_path , get_regressor_performances
1818from pypef .dca .gremlin_inference import GREMLIN
1919from pypef .utils .variant_data import get_sequences_from_file , get_wt_sequence
20- from pypef .plm .esm_lora_tune import esm_infer , esm_infer_pll , esm_setup , esm_train
20+ from pypef .plm .esm_lora_tune import esm_infer , plm_inference , esm_setup , esm_train
2121from pypef .plm .prosst_lora_tune import prosst_setup
2222from pypef .plm .inference import inference , llm_tokenizer
2323from pypef .hybrid .hybrid_model import DCALLMHybridModel
2626)
2727from pypef .plm .prosst_lora_tune import (
2828 get_logits_from_full_seqs , get_prosst_models , get_structure_quantizied ,
29- prosst_tokenize_sequences
29+ prosst_simple_vocab_aa_tokenizer
3030)
3131from pypef .utils .helpers import get_device
3232
@@ -258,10 +258,6 @@ def test_plm_corr_blat_ecolx():
258258 prosst_base_model = prosst_base_model .to (device )
259259 df = pd .read_csv (csv_blat_ecolx_stiffler2015 )
260260 sequences = df ['mutated_sequence' ].to_list ()
261- print (sequences [0 ][23 ])
262- print (sequences [1 ][23 ])
263- print ('len(sequences[0]):' , len (sequences [0 ]))
264- print ('len(blat_ecolx_wt_seq):' , len (blat_ecolx_wt_seq ))
265261 y_true = df ['DMS_score' ].to_list ()
266262 for x in ['facebook/esm1v_t33_650M_UR90S_3' ]:
267263 esm_base_model , _esm_lora_model , esm_tokenizer , esm_optimizer = get_esm_models (model = x )
@@ -275,7 +271,7 @@ def test_plm_corr_blat_ecolx():
275271 max_length = len (blat_ecolx_wt_seq ) + 2
276272 )
277273 wt_tokens = torch .tensor (wt_tokens [0 ], dtype = torch .long ) # shape (L,)
278- y_esm = esm_infer_pll (
274+ y_esm = plm_inference (
279275 xs = x_esm ,
280276 wt_input_ids = wt_tokens ,
281277 attention_mask = esm_attention_mask ,
@@ -289,7 +285,7 @@ def test_plm_corr_blat_ecolx():
289285 print (f'{ x } : ESM1v (unsupervised performance): '
290286 f'{ spearmanr (y_true , y_esm .cpu ())[0 ]} ' )
291287 np .testing .assert_almost_equal (spearmanr (y_true , y_esm .cpu ())[0 ], 0.6367826285982324 , decimal = 6 )
292- y_esm = esm_infer_pll (
288+ y_esm = plm_inference (
293289 xs = x_esm ,
294290 wt_input_ids = wt_tokens ,
295291 attention_mask = esm_attention_mask ,
@@ -303,7 +299,7 @@ def test_plm_corr_blat_ecolx():
303299 print (f'{ x } : ESM1v (unsupervised performance): '
304300 f'{ spearmanr (y_true , y_esm .cpu ())[0 ]} ' )
305301 np .testing .assert_almost_equal (spearmanr (y_true , y_esm .cpu ())[0 ], 0.6498987261125897 , decimal = 6 )
306- #y_esm = esm_infer_pll (
302+ #y_esm = plm_inference (
307303 # xs=x_esm,
308304 # wt_input_ids=wt_tokens,
309305 # attention_mask=esm_attention_mask,
@@ -317,31 +313,34 @@ def test_plm_corr_blat_ecolx():
317313 #print(f'{x}: ESM1v (unsupervised performance): '
318314 # f'{spearmanr(y_true, y_esm.cpu())[0]}')
319315 #np.testing.assert_almost_equal(spearmanr(y_true, y_esm.cpu())[0], 0.666666666666666, decimal=6)
320-
321316 wt_input_ids , prosst_attention_mask , wt_structure_input_ids = get_structure_quantizied (
322317 pdb_blat_ecolx , prosst_tokenizer , blat_ecolx_wt_seq )
323- x_prosst = tokenize_sequences (sequences = sequences , tokenizer = prosst_tokenizer )
324- y_prosst = get_logits_from_full_seqs (
325- x_prosst , prosst_base_model , wt_input_ids , prosst_attention_mask ,
326- wt_structure_input_ids , train = False , verbose = True
318+ x_prosst2 = prosst_simple_vocab_aa_tokenizer (sequences , prosst_vocab )
319+ x_prosst , prosst_attention_mask_ = tokenize_sequences (
320+ sequences = sequences ,
321+ tokenizer = prosst_tokenizer ,
322+ max_length = len (blat_ecolx_wt_seq ) + 2
327323 )
328- print (f'ProSST (unsupervised performance): ' # ProteinGym: ProSST: 0.760
329- f'{ spearmanr (y_true , y_prosst .cpu ())[0 ]:.3f} ' )
324+ assert x_prosst [0 ][1 :- 1 ] == x_prosst2 .tolist ()[0 ], (
325+ f"{ x_prosst [0 ][1 :- 1 ]} != { x_prosst2 .tolist ()[0 ]} " )
326+ assert prosst_attention_mask .tolist ()[0 ] == prosst_attention_mask_ , (
327+ f"{ prosst_attention_mask .tolist ()[0 ]} != { prosst_attention_mask_ } " )
330328
331- y_prosst = esm_infer_pll (
329+ y_prosst = plm_inference (
332330 xs = x_prosst ,
333- wt_input_ids = ( wt_input_ids , wt_structure_input_ids ), ## TODO
331+ wt_input_ids = wt_input_ids ,
334332 attention_mask = prosst_attention_mask ,
335333 model = prosst_base_model ,
336334 mask_token_id = prosst_tokenizer .mask_token_id ,
337- inference_type = 'prosst' , ## TODO
335+ inference_type = 'unmasked' ,
336+ wt_structure_input_ids = wt_structure_input_ids ,
338337 batch_size = 5 ,
339338 train = False ,
340339 verbose = True
341340 )
342341 print (f'ProSST (unsupervised performance): ' # ProteinGym: ProSST: 0.760
343- f'{ spearmanr (y_true , y_prosst .cpu ())[0 ]:.3f } ' )
344- # ACTUAL OLD VERSION: 0.743
342+ f'{ spearmanr (y_true , y_prosst .cpu ())[0 ]} ' )
343+ np . testing . assert_almost_equal ( spearmanr ( y_true , y_prosst . cpu ())[ 0 ], 0.7430279087189432 , decimal = 6 )
345344
346345
347346
0 commit comments