@@ -50,6 +50,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
5050 get_vram ()
5151 MAX_WT_SEQUENCE_LENGTH = 600 # TODO: 1000
5252 MAX_VARIANT_FITNESS_PAIRS = 5000
53+ N_CV = 5
5354 print (f"Maximum sequence length: { MAX_WT_SEQUENCE_LENGTH } " )
5455 print (f"Loading LLM models into { device } device..." )
5556 prosst_base_model , prosst_lora_model , prosst_tokenizer , prosst_optimizer = get_prosst_models ()
@@ -160,11 +161,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
160161 x_esm , esm_attention_mask = esm_tokenize_sequences (
161162 sequences , esm_tokenizer , max_length = len (wt_seq ), verbose = False
162163 )
163- #y_esm = esm_infer(
164- # get_batches(x_esm, dtype=float, batch_size=1),
165- # esm_attention_mask,
166- # esm_base_model
167- #)
168164 y_esm = inference (sequences , 'esm' , model = esm_base_model , verbose = False )
169165 print (f'ESM1v (unsupervised performance): '
170166 f'{ spearmanr (fitnesses , y_esm .cpu ())[0 ]:.3f} ' )
@@ -177,10 +173,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
177173 pdb , prosst_tokenizer , wt_seq , verbose = False
178174 )
179175 x_prosst = prosst_tokenize_sequences (sequences = sequences , vocab = prosst_vocab , verbose = False )
180- #y_prosst = get_logits_from_full_seqs(
181- # x_prosst, prosst_base_model, input_ids, prosst_attention_mask,
182- # structure_input_ids, train=False
183- #)
184176 y_prosst = inference (sequences , 'prosst' , pdb_file = pdb , wt_seq = wt_seq , model = prosst_base_model , verbose = False )
185177 print (f'ProSST (unsupervised performance): '
186178 f'{ spearmanr (fitnesses , y_prosst .cpu ())[0 ]:.3f} ' )
@@ -192,8 +184,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
192184 print ('Both LLM\' s had RunTimeErrors, skipping dataset...' )
193185 continue
194186
195- ns_y_test = [len (variants )]
196- ds = DatasetSplitter (df_or_csv_file = csv_path , n_cv = 5 , mutation_separator = mut_sep )
187+ ds = DatasetSplitter (df_or_csv_file = csv_path , n_cv = N_CV , mutation_separator = mut_sep )
197188 ds .plot_distributions ()
198189 if max_muts >= 2 : # Only using random cross-validation splits
199190 print ("Only performing random splits as data contains multi-substituted variants..." )
@@ -202,17 +193,20 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
202193 print ("Only single substituted variants found, performing random, modulo, and continuous data splits..." )
203194 target_split_indices = ds .get_all_split_indices ()
204195 temp_results = {}
205- # TODO: Get correct indices for full df for multi-muts using DatasetSplitter!
196+ for c in ["Random" , "Modulo" , "Continuous" ]:
197+ temp_results .update ({c : {}})
198+ for s in range (N_CV ):
199+ temp_results [c ].update ({f'Split { s } ' : {}})
200+ for m in ['DCA' , 'ESM1v' , 'ProSST' , 'DCA hybrid' , 'DCA+ESM1v hybrid' , 'DCA+ProSST hybrid' ]:
201+ # Prefill with NaN's
202+ temp_results [c ][f'Split { s } ' ].update ({m : np .nan })
206203 for i_category , (train_indices , test_indices ) in enumerate (target_split_indices ):
207204 category = ["Random" , "Modulo" , "Continuous" ][i_category ]
208205 print (f'Category: { category } ' )
209- temp_results .update ({category : {}})
210206 for i_split , (train_i , test_i ) in enumerate (zip (
211207 train_indices , test_indices
212208 )):
213209 print (f' Split: { i_split + 1 } ' )
214- print (test_i )
215- temp_results [category ].update ({f'Split { i_split } ' : {}})
216210 try :
217211 _train_sequences , test_sequences = np .asarray (sequences )[train_i ], np .asarray (sequences )[test_i ]
218212 x_dca_train , x_dca_test = np .asarray (x_dca )[train_i ], np .asarray (x_dca )[test_i ]
@@ -224,14 +218,10 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
224218 esm_lora_model_2 = copy .deepcopy (esm_lora_model )
225219 esm_optimizer = torch .optim .Adam (esm_lora_model_2 .parameters (), lr = 0.0001 )
226220 train_size , test_size = len (train_i ), len (test_i )
227- #get_vram()
228221 except ValueError as e :
229222 print (f"Only { len (fitnesses )} variant-fitness pairs in total, "
230223 f"cannot split the data in N_Train = { train_size } and N_Test "
231224 f"(N_Total - N_Train) [Excepted error: { e } ]." )
232- for m in ['DCA' , 'ESM1v' , 'ProSST' , 'DCA hybrid' , 'DCA+ESM1v hybrid' , 'DCA+ProSST hybrid' ]:
233- temp_results [category ][f'Split { i_split } ' ].update ({m : np .nan })
234- ns_y_test .append (np .nan )
235225 continue
236226 (
237227 x_dca_train ,
@@ -276,9 +266,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
276266 f"in N_Train = { len (y_train )} and N_Test = { len (y_test )} "
277267 f"results in N_Test <= 50 variants - not getting "
278268 f"performance for N_Train = { len (y_train )} ..." )
279- ns_y_test .append (np .nan )
280- for m in ['DCA' , 'ESM1v' , 'ProSST' , 'DCA hybrid' , 'DCA+ESM1v hybrid' , 'DCA+ProSST hybrid' ]:
281- temp_results [category ][f'Split { i_split } ' ].update ({m : np .nan })
282269 continue
283270
284271 y_test_pred_dca = get_delta_e_statistical_model (x_dca_test , x_wt )
@@ -313,10 +300,8 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
313300 print (f' { m_str } (split { i_split + 1 } ) performance: { spearmanr (y_test , y_test_pred )[0 ]:.3f} '
314301 f'(train size={ train_size } , test_size={ test_size } )' )
315302 temp_results [category ][f'Split { i_split } ' ].update ({m_str : spearmanr (y_test , y_test_pred )[0 ]})
316- except RuntimeError as e : # modeling_prosst.py, line 920, in forward
317- # or UnboundLocalError in prosst_lora_tune.py, line 167
318- temp_results [category ][f'Split { i_split } ' ].update ({m_str : np .nan })
319- ns_y_test .append (len (y_test_pred ))
303+ except RuntimeError as e : # modeling_prosst.py in forward
304+ continue
320305 del prosst_lora_model_2
321306 del esm_lora_model_2
322307 torch .cuda .empty_cache ()
@@ -358,7 +343,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
358343 f'{ int (dt )} \n ' )
359344
360345
361- def plot_csv_data (csv , plot_name ):
346+ def plot_csv_data (csv ):
362347 plt .figure (figsize = (24 , 12 ))
363348 sns .set_style ("whitegrid" )
364349 df = pd .read_csv (csv , sep = ',' )
@@ -487,4 +472,4 @@ def plot_csv_data(csv, plot_name):
487472 ):
488473 fh2 .write (line )
489474
490- plot_csv_data (csv = clean_out_results_csv , plot_name = 'mut_performance' )
475+ plot_csv_data (csv = clean_out_results_csv )
0 commit comments