@@ -44,26 +44,6 @@ def get_vram(verbose: bool = True):
4444 return free , total
4545
4646
47- def read_pdb (pdbfile ):
48- from Bio import PDB
49-
50- pdb_io = PDB .PDBIO ()
51- pdb_parser = PDB .PDBParser ()
52- structure = pdb_parser .get_structure ('ppp' , pdbfile )
53-
54- new_resnums = [i + 200 for i in range (135 )]
55-
56- print (structure )
57- print (pdbfile )
58-
59- for model in structure :
60- for chain in model :
61- for i , residue in enumerate (chain .get_residues ()):
62- res_id = list (residue .id )
63- #res_id[1] = new_resnums[i]
64- #residue.id = tuple(res_id)
65-
66-
6747def compute_performances (mut_data , mut_sep = ':' , start_i : int = 0 , already_tested_is : list = []):
6848 # Get cpu, gpu or mps device for training.
6949 device = (
@@ -81,14 +61,13 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
8161 esm_base_model , esm_lora_model , esm_tokenizer , esm_optimizer = get_esm_models ()
8262 esm_base_model = esm_base_model .to (device )
8363 MAX_WT_SEQUENCE_LENGTH = 2000
84- N_EPOCHS = 5
8564 get_vram ()
8665 hybrid_perfs = []
8766 plt .figure (figsize = (40 , 12 ))
8867 numbers_of_datasets = [i + 1 for i in range (len (mut_data .keys ()))]
8968 delta_times = []
9069 for i , (dset_key , dset_paths ) in enumerate (mut_data .items ()):
91- if i >= start_i and i not in already_tested_is and i < 21 : # i > 3 and i <21: #i == 18 - 1:
70+ if i >= start_i and i not in already_tested_is : # i > 3 and i <21: #i == 18 - 1:
9271 start_time = time .time ()
9372 print (f'\n { i + 1 } /{ len (mut_data .items ())} \n '
9473 f'===============================================================' )
@@ -103,7 +82,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
10382 print ('MSA path:' , msa_path )
10483 print ('MSA start:' , msa_start , '- MSA end:' , msa_end )
10584 print ('WT sequence (trimmed from MSA start to MSA end):\n ' + wt_seq )
106- read_pdb (pdb )
10785 #if msa_start != 1:
10886 # print('Continuing (TODO: requires cut of PDB input structure residues)...')
10987 # continue
@@ -152,8 +130,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
152130 y_pred_dca = get_delta_e_statistical_model (x_dca , x_wt )
153131 print ('DCA:' , spearmanr (fitnesses , y_pred_dca ), len (fitnesses ))
154132 dca_unopt_perf = spearmanr (fitnesses , y_pred_dca )[0 ]
155- # TF 10,000: DCA: SignificanceResult(statistic=np.float64(0.6486616550552755), pvalue=np.float64(3.647740047145113e-119)) 989
156- # Torch 10,000: DCA: SignificanceResult(statistic=np.float64(0.6799982280150232), pvalue=np.float64(3.583110693136881e-135)) 989
157133
158134 x_esm , esm_attention_mask = esm_tokenize_sequences (sequences , esm_tokenizer , max_length = len (wt_seq ))
159135 y_esm = esm_infer (get_batches (x_esm , dtype = float , batch_size = 1 ), esm_attention_mask , esm_base_model )
@@ -248,7 +224,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
248224 llm_model_input = method ,
249225 x_wt = x_wt
250226 )
251-
252227 y_test_pred = hm .hybrid_prediction (
253228 x_dca = np .array (x_dca_test ),
254229 x_llm = [
@@ -257,7 +232,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
257232 np .asarray (x_llm_test_prosst )
258233 ][i_m ]
259234 )
260-
261235 print (f'Hybrid perf.: { spearmanr (y_test , y_test_pred )[0 ]} ' )
262236 hybrid_perfs .append (spearmanr (y_test , y_test_pred )[0 ])
263237 except RuntimeError : # modeling_prosst.py, line 920, in forward
@@ -336,7 +310,6 @@ def plot_csv_data(csv, plot_name):
336310 train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_hybrid_perfs_dca_1000 ), f'{ np .nanmean (dset_hybrid_perfs_dca_1000 ):.2f} ' , color = 'blueviolet' ))
337311
338312
339-
340313 plt .plot (range (len (tested_dsets )), dset_esm_perfs , 'o--' , markersize = 8 , color = 'tab:green' , label = 'ESM (0)' )
341314 plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_esm_perfs )), color = 'tab:green' , linestyle = '--' )
342315 for i , (p , n_test ) in enumerate (zip (dset_esm_perfs , df ['N_Y_test' ].astype ('Int64' ).to_list ())):
@@ -362,8 +335,6 @@ def plot_csv_data(csv, plot_name):
362335 train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_hybrid_perfs_dca_esm_1000 ), f'{ np .nanmean (dset_hybrid_perfs_dca_esm_1000 ):.2f} ' , color = 'turquoise' ))
363336
364337
365-
366-
367338 plt .plot (range (len (tested_dsets )), dset_prosst_perfs , 'o--' , markersize = 8 , color = 'tab:red' , label = 'ProSST (0)' )
368339 plt .plot (range (len (tested_dsets ) + 1 ), np .full (len (tested_dsets ) + 1 , np .nanmean (dset_prosst_perfs )), color = 'tab:red' , linestyle = '--' )
369340 for i , (p , n_test ) in enumerate (zip (dset_prosst_perfs , df ['N_Y_test' ].astype ('Int64' ).to_list ())):
@@ -389,9 +360,6 @@ def plot_csv_data(csv, plot_name):
389360 train_test_size_texts .append (plt .text (len (tested_dsets ), np .nanmean (dset_hybrid_perfs_dca_prosst_1000 ), f'{ np .nanmean (dset_hybrid_perfs_dca_prosst_1000 ):.2f} ' , color = 'darkred' ))
390361
391362
392-
393-
394-
395363 plt .grid (zorder = - 1 )
396364 plt .xticks (range (len (tested_dsets )), tested_dsets , rotation = 45 , ha = 'right' )
397365 plt .margins (0.01 )
@@ -433,6 +401,13 @@ def plot_csv_data(csv, plot_name):
433401 print (df .columns )
434402 dset_ns_y_test = [
435403 df ['N_Y_test' ].to_list (),
404+ df ['N_Y_test_100' ].to_list (),
405+ df ['N_Y_test_200' ].to_list (),
406+ df ['N_Y_test_1000' ].to_list (),
407+ df ['N_Y_test' ].to_list (),
408+ df ['N_Y_test_100' ].to_list (),
409+ df ['N_Y_test_200' ].to_list (),
410+ df ['N_Y_test_1000' ].to_list (),
436411 df ['N_Y_test' ].to_list (),
437412 df ['N_Y_test_100' ].to_list (),
438413 df ['N_Y_test_200' ].to_list (),
0 commit comments