@@ -69,7 +69,6 @@ def get_logits_from_full_seqs(
6969 attention_mask = attention_mask ,
7070 ss_input_ids = structure_input_ids
7171 )
72-
7372 logits = torch .log_softmax (outputs .logits [:, 1 :- 1 ], dim = - 1 ).squeeze ()
7473 for i_s , sequence in enumerate (tqdm (xs , disable = not verbose , desc = 'Getting ProSST sequence logits' )):
7574 for i_aa , x_aa in enumerate (sequence ):
@@ -84,9 +83,6 @@ def get_logits_from_full_seqs(
8483 return log_probs
8584
8685
87-
88-
89-
def checkpoint(model, filename):
    """Persist *model*'s learnable parameters to *filename*.

    Only the state dict (weights and buffers) is written, not the full
    module, so restoring requires an instantiated model of the same
    architecture plus ``load_state_dict``.
    """
    state = model.state_dict()
    torch.save(state, filename)
9288
@@ -107,7 +103,6 @@ def prosst_train(
107103 print (f'ProSST training using { device .upper ()} device (N_Train={ len (torch .flatten (score_batches ))} )...' )
108104 x_sequence_batches = x_sequence_batches .to (device )
109105 score_batches = score_batches .to (device )
110-
111106 pbar_epochs = tqdm (range (1 , n_epochs + 1 ))
112107 epoch_spearman_1 = 0.0
113108 did_not_improve_counter = 0
@@ -191,7 +186,6 @@ def get_structure_quantizied(pdb_file, tokenizer, wt_seq):
191186 return input_ids , attention_mask , structure_input_ids
192187
193188
194-
195189def prosst_setup (wt_seq , pdb_file , sequences , device : str | None = None ):
196190 prosst_base_model , prosst_lora_model , prosst_tokenizer , prosst_optimizer = get_prosst_models ()
197191 prosst_vocab = prosst_tokenizer .get_vocab ()
@@ -215,66 +209,3 @@ def prosst_setup(wt_seq, pdb_file, sequences, device: str | None = None):
215209 }
216210 }
217211 return llm_dict_prosst
218-
219-
if __name__ == '__main__':
    import copy

    import pandas as pd
    import matplotlib.pyplot as plt
    from sklearn.model_selection import train_test_split

    # Benchmark script: evaluate ProSST zero-shot and fine-tuned Spearman
    # performance on the GRB2_HUMAN_Faure_2021 DMS dataset across several
    # batch sizes and training-set sizes.
    # Reference zero-shot results on this dataset:
    #   full df: SignificanceResult(statistic=0.6997442598613315, pvalue=0.0)
    #   get_logits_from_full_seqs on df['mutated_sequence']: statistic≈0.72167
    wt_seq = "MEAIAKYDFKATADDELSFKRGDILKVLNEECDQNWYKAELNGKDGFIPKNYIEMKPHPWFFGKIPRAKAEEMLSKQRHDGAFLIRESESAPGDFSLSVKFGNDVQHFKVLRDGAGKYFLWVVKFNSLNELVDYHRSTSVSRNQQIFLRDIEQVPQQPTYVQALFDFDPQEDGELGFRRGDFIHVMDNSDPNWWKGACHGQTGMFPRNYVTPVNRNV"
    grb2_folder = os.path.abspath(os.path.join(pypef_path, '..', 'datasets', 'GRB2'))
    pdb_file = os.path.join(grb2_folder, 'GRB2_HUMAN.pdb')
    csv_file = os.path.join(grb2_folder, 'GRB2_HUMAN_Faure_2021.csv')
    df = pd.read_csv(csv_file)
    print(df)

    prosst_base_model, prosst_lora_model, tokenizer, optimizer = get_prosst_models()
    vocab = tokenizer.get_vocab()

    # Quantize the structure and offset its tokens past the 3 special tokens,
    # then frame with BOS (1) / EOS (2) to mirror the sequence tokenization.
    structure_sequence = PdbQuantizer()(pdb_file=pdb_file)
    structure_sequence_offset = [i + 3 for i in structure_sequence]
    tokenized_res = tokenizer([wt_seq], return_tensors='pt')
    input_ids = tokenized_res['input_ids']
    attention_mask = tokenized_res['attention_mask']
    structure_input_ids = torch.tensor(
        [1, *structure_sequence_offset, 2], dtype=torch.long
    ).unsqueeze(0)

    x_sequences = prosst_tokenize_sequences(df['mutated_sequence'], vocab=vocab)
    train_sizes = [200, 1000, 10000]
    for batch_size in [5, 10, 25, 50, 100]:
        train_perfs_unsup, test_perfs_unsup = [], []
        train_perfs, test_perfs = [], []
        for train_size in train_sizes:
            # Fresh model copy per run so earlier fine-tuning cannot leak
            # into later runs.
            prosst_model_copy = copy.deepcopy(prosst_base_model)
            # NOTE(review): `optimizer` was built by get_prosst_models() for the
            # ORIGINAL model's parameters; confirm prosst_train rebinds it to
            # prosst_model_copy, otherwise training updates the wrong weights.
            x_train, x_test, scores_train, scores_test = train_test_split(
                x_sequences, df['DMS_score'].to_numpy().astype(float),
                train_size=train_size, random_state=42
            )
            print(f"\n=========================\nTRAIN SIZE: {train_size} TEST SIZE: {len(x_test)} -- BATCH SIZE: {batch_size}\n=========================")

            # Zero-shot (untrained) performance on the held-out test split.
            y_pred = get_logits_from_full_seqs(
                x_test, prosst_model_copy, input_ids, attention_mask, structure_input_ids, train=False)
            print(f'Train-->Test UNTRAINED Performance (N={len(y_pred.flatten())}):', spearmanr(scores_test, y_pred.detach().cpu().numpy()))
            # Fix: store only the correlation statistic ([0]) — the other three
            # lists store floats, and plotting whole SignificanceResult tuples
            # breaks the figure below.
            test_perfs_unsup.append(spearmanr(scores_test, y_pred.detach().cpu().numpy())[0])

            # Zero-shot performance on the training split.
            y_preds_train_unsup = get_logits_from_full_seqs(
                x_train, prosst_model_copy, input_ids, attention_mask, structure_input_ids, train=False, verbose=False)
            y_preds_train_unsup = y_preds_train_unsup.cpu().numpy()
            print(f'Train-->Train UNTRAINED Performance (N={len(y_preds_train_unsup)}):', spearmanr(scores_train, y_preds_train_unsup))
            train_perfs_unsup.append(spearmanr(scores_train, y_preds_train_unsup)[0])

            # TRAINING
            x_train_b = get_batches(x_train, dtype=int, batch_size=batch_size, verbose=True)
            scores_train_b = get_batches(scores_train, dtype=float, batch_size=batch_size, verbose=True)
            y_preds_train = prosst_train(x_train_b, scores_train_b, corr_loss, prosst_model_copy, optimizer, pdb_file, n_epochs=500)
            print(f'Train-->Train Performance (N={len(y_preds_train)}):', spearmanr(scores_train, y_preds_train))
            train_perfs.append(spearmanr(scores_train, y_preds_train)[0])

            # Fine-tuned performance on the held-out test split.
            y_pred = get_logits_from_full_seqs(
                x_test, prosst_model_copy, input_ids, attention_mask, structure_input_ids, train=False)
            print(f'Train-->Test Performance (N={len(y_pred.flatten())}):', spearmanr(scores_test, y_pred.detach().cpu().numpy()))
            test_perfs.append(spearmanr(scores_test, y_pred.detach().cpu().numpy())[0])

        # One curve per metric for this batch size; distinct legend labels
        # (previously all four curves shared the identical label).
        for metric_name, perfs in [
            ('train unsup', train_perfs_unsup),
            ('train', train_perfs),
            ('test unsup', test_perfs_unsup),
            ('test', test_perfs),
        ]:
            plt.plot(range(len(perfs)), perfs, label=f'Batch size: {batch_size} ({metric_name})')
        # Fix: tick labels must match tick positions — the old hard-coded
        # [100, 200, 1000, 10000] had 4 labels for 3 positions (ValueError).
        plt.xticks(range(len(train_sizes)), train_sizes)
        plt.legend()
        plt.savefig('1.png')
0 commit comments