@@ -140,13 +140,13 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
140140 gremlin = GREMLIN (alignment = msa_path , opt_iter = 100 , optimize = True )
141141 sequences_batched = get_batches (sequences , batch_size = 1000 ,
142142 dtype = str , keep_remaining = True , verbose = True )
143- x_dca = []
144- for seq_b in tqdm (sequences_batched , desc = "Getting GREMLIN sequence encodings" ):
143+ x_dca = [] # required later on also
144+ for seq_b in tqdm (sequences_batched , desc = "Getting GREMLIN sequence encodings" , disable = True ):
145145 for x in gremlin .collect_encoded_sequences (seq_b ):
146146 x_dca .append (x )
147147 x_wt = gremlin .x_wt
148148 y_pred_dca = get_delta_e_statistical_model (x_dca , x_wt )
149- print (f'DCA (unsupervised performance): { spearmanr (fitnesses , y_pred_dca )[0 ]:.3f} ' )
149+ print (f'DCA (unsupervised performance): { spearmanr (fitnesses , y_pred_dca )[0 ]:.3f} ' )
150150 dca_unopt_perf = spearmanr (fitnesses , y_pred_dca )[0 ]
151151 # ESM unsupervised
152152 try :
@@ -158,7 +158,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
158158 # esm_attention_mask,
159159 # esm_base_model
160160 #)
161- y_esm = inference (sequences , 'esm' , model = esm_base_model )
161+ y_esm = inference (sequences , 'esm' , model = esm_base_model , verbose = False )
162162 print (f'ESM1v (unsupervised performance): '
163163 f'{ spearmanr (fitnesses , y_esm .cpu ())[0 ]:.3f} ' )
164164 esm_unopt_perf = spearmanr (fitnesses , y_esm .cpu ())[0 ]
@@ -167,13 +167,14 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
167167 # ProSST unsupervised
168168 try :
169169 input_ids , prosst_attention_mask , structure_input_ids = get_structure_quantizied (
170- pdb , prosst_tokenizer , wt_seq )
170+ pdb , prosst_tokenizer , wt_seq , verbose = False
171+ )
171172 x_prosst = prosst_tokenize_sequences (sequences = sequences , vocab = prosst_vocab , verbose = False )
172173 #y_prosst = get_logits_from_full_seqs(
173174 # x_prosst, prosst_base_model, input_ids, prosst_attention_mask,
174175 # structure_input_ids, train=False
175176 #)
176- y_prosst = inference (sequences , 'prosst' , pdb_file = pdb , wt_seq = wt_seq , model = prosst_base_model )
177+ y_prosst = inference (sequences , 'prosst' , pdb_file = pdb , wt_seq = wt_seq , model = prosst_base_model , verbose = False )
177178 print (f'ProSST (unsupervised performance): '
178179 f'{ spearmanr (fitnesses , y_prosst .cpu ())[0 ]:.3f} ' )
179180 prosst_unopt_perf = spearmanr (fitnesses , y_prosst .cpu ())[0 ]
@@ -195,9 +196,10 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
195196 for i_split , (train_i , test_i ) in enumerate (zip (
196197 train_indices , test_indices
197198 )):
198- print (f'Split: { i_split + 1 } ' )
199+ print (f' Split: { i_split + 1 } ' )
199200 temp_results [category ].update ({f'Split { i_split } ' : {}})
200201 try :
202+ _train_sequences , test_sequences = np .asarray (sequences )[train_i ], np .asarray (sequences )[test_i ]
201203 x_dca_train , x_dca_test = np .asarray (x_dca )[train_i ], np .asarray (x_dca )[test_i ]
202204 x_llm_train_prosst , x_llm_test_prosst = np .asarray (x_prosst )[train_i ], np .asarray (x_prosst )[test_i ]
203205 x_llm_train_esm , x_llm_test_esm = np .asarray (x_esm )[train_i ], np .asarray (x_esm )[test_i ]
@@ -253,7 +255,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
253255 'structure_input_ids' : structure_input_ids
254256 }
255257 }
256- print (f'Train: { len (np .array (y_train ))} --> Test: { len (np .array (y_test ))} ' )
258+ print (f' Train: { len (np .array (y_train ))} --> Test: { len (np .array (y_test ))} ' )
257259 if len (y_test ) <= 20 : # TODO: 50
258260 print (f"Only { len (fitnesses )} in total, splitting the data "
259261 f"in N_Train = { len (y_train )} and N_Test = { len (y_test )} "
@@ -264,6 +266,17 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
264266 ns_y_test .append (np .nan )
265267 continue
266268 #get_vram()
269+
270+ y_test_pred_dca = get_delta_e_statistical_model (x_dca_test , x_wt )
271+ temp_results [category ][f'Split { i_split } ' ].update ({'DCA' : spearmanr (y_test , y_test_pred_dca )[0 ]})
272+ print (f' DCA ZeroShot (split { i_split + 1 } ) performance: { spearmanr (y_test , y_test_pred_dca )[0 ]:.3f} ' )
273+ y_test_pred_esm = inference (test_sequences , 'esm' , model = esm_base_model , verbose = False )
274+ temp_results [category ][f'Split { i_split } ' ].update ({'ESM1v' : spearmanr (y_test , y_test_pred_esm )[0 ]})
275+ print (f' ESM1v ZeroShot (split { i_split + 1 } ) performance: { spearmanr (y_test , y_test_pred_esm )[0 ]:.3f} ' )
276+ y_test_pred_prosst = inference (test_sequences , 'prosst' , model = prosst_base_model , pdb_file = pdb , wt_seq = wt_seq , verbose = False )
277+ temp_results [category ][f'Split { i_split } ' ].update ({'ProSST' : spearmanr (y_test , y_test_pred_prosst )[0 ]})
278+ print (f' ProSST ZeroShot (split { i_split + 1 } ) performance: { spearmanr (y_test , y_test_pred_prosst )[0 ]:.3f} ' )
279+
267280 for i_m , method in enumerate ([None , llm_dict_esm , llm_dict_prosst ]):
268281 m_str = ['DCA hybrid' , 'DCA+ESM1v hybrid' , 'DCA+ProSST hybrid' ][i_m ]
269282 #print('\n~~~ ' + m_str + ' ~~~')
@@ -284,7 +297,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
284297 ][i_m ],
285298 verbose = False
286299 )
287- print (f'{ m_str } (split { i_split + 1 } ) performance: { spearmanr (y_test , y_test_pred )[0 ]:.3f} '
300+ print (f' { m_str } (split { i_split + 1 } ) performance: { spearmanr (y_test , y_test_pred )[0 ]:.3f} '
288301 f'(train size={ train_size } , test_size={ test_size } )' )
289302 temp_results [category ][f'Split { i_split } ' ].update ({m_str : spearmanr (y_test , y_test_pred )[0 ]})
290303 except RuntimeError as e : # modeling_prosst.py, line 920, in forward
0 commit comments