@@ -65,7 +65,8 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
6565 plt .figure (figsize = (40 , 12 ))
6666 numbers_of_datasets = [i + 1 for i in range (len (mut_data .keys ()))]
6767 for i , (dset_key , dset_paths ) in enumerate (mut_data .items ()):
68- if i >= start_i and i not in already_tested_is : # i > 3 and i <21: #i == 18 - 1:
68+ if i >= start_i and i not in already_tested_is and i != 19 : # i > 3 and i <21: #i == 18 - 1:
69+ # Skipping dataset 20 (BRCA1_HUMAN_Findlay_2018) because the LLMs raise RuntimeErrors on it
6970 start_time = time .time ()
7071 print (f'\n { i + 1 } /{ len (mut_data .items ())} \n '
7172 f'===============================================================' )
@@ -131,21 +132,27 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
131132 print ('DCA:' , spearmanr (fitnesses , y_pred_dca ), len (fitnesses ))
132133 dca_unopt_perf = spearmanr (fitnesses , y_pred_dca )[0 ]
133134
134- x_esm , esm_attention_mask = esm_tokenize_sequences (sequences , esm_tokenizer , max_length = len (wt_seq ))
135- y_esm = esm_infer (get_batches (x_esm , dtype = float , batch_size = 1 ), esm_attention_mask , esm_base_model )
136- print ('ESM1v:' , spearmanr (fitnesses , y_esm .cpu ()))
135+ try :
136+ x_esm , esm_attention_mask = esm_tokenize_sequences (sequences , esm_tokenizer , max_length = len (wt_seq ))
137+ y_esm = esm_infer (get_batches (x_esm , dtype = float , batch_size = 1 ), esm_attention_mask , esm_base_model )
138+ print ('ESM1v:' , spearmanr (fitnesses , y_esm .cpu ()))
139+ esm_unopt_perf = spearmanr (fitnesses , y_esm .cpu ())[0 ]
140+ except RuntimeError :
141+ esm_unopt_perf = np .nan
137142
138- input_ids , prosst_attention_mask , structure_input_ids = get_structure_quantizied (pdb , prosst_tokenizer , wt_seq )
139- x_prosst = prosst_tokenize_sequences (sequences = sequences , vocab = prosst_vocab )
140143 try :
144+ input_ids , prosst_attention_mask , structure_input_ids = get_structure_quantizied (pdb , prosst_tokenizer , wt_seq )
145+ x_prosst = prosst_tokenize_sequences (sequences = sequences , vocab = prosst_vocab )
141146 y_prosst = get_logits_from_full_seqs (
142- x_prosst , prosst_base_model , input_ids , prosst_attention_mask , structure_input_ids , train = False )
147+ x_prosst , prosst_base_model , input_ids , prosst_attention_mask , structure_input_ids , train = False )
143148 print ('ProSST:' , spearmanr (fitnesses , y_prosst .cpu ()))
144149 prosst_unopt_perf = spearmanr (fitnesses , y_prosst .cpu ())[0 ]
145150 except RuntimeError :
146151 prosst_unopt_perf = np .nan
147152
148- esm_unopt_perf = spearmanr (fitnesses , y_esm .cpu ())[0 ]
153+ if np .isnan (esm_unopt_perf ) and np .isnan (prosst_unopt_perf ):
154+ print ('Both LLM\' s had RunTimeErrors, skipping dataset...' )
155+ continue
149156
150157 ns_y_test = [len (variants )]
151158 for i_t , train_size in enumerate ([100 , 200 , 1000 ]):
@@ -431,7 +438,7 @@ def plot_csv_data(csv, plot_name):
431438 r'$\overline{|\rho|}=$' + f'{ np .nanmean (dset_hybrid_perfs_dca_prosst_1000 ):.2f} '
432439 ][n ]
433440 )
434- plt .text ( # N_Y_test,N_Y_test_100,N_Y_test_200,N_Y_test_1000
441+ plt .text (
435442 n + 0.15 , - 0.05 ,
436443 r'$\overline{N_{Y_\mathrm{test}}}=$' + f'{ int (np .nanmean (np .array (dset_ns_y_test )[n ]))} '
437444 )
@@ -496,11 +503,11 @@ def plot_csv_data(csv, plot_name):
496503 already_tested_is = []
497504
498505
499- # compute_performances(
500- # mut_data=combined_mut_data,
501- # start_i=start_i,
502- # already_tested_is=already_tested_is
503- # )
506+ compute_performances (
507+ mut_data = combined_mut_data ,
508+ start_i = start_i ,
509+ already_tested_is = already_tested_is
510+ )
504511
505512
506513 with open (out_results_csv , 'r' ) as fh :
0 commit comments