@@ -71,7 +71,7 @@ def test_gremlin():
7171 )
7272
7373
74- def test_hybrid_model ():
74+ def test_hybrid_model_dca_esm ():
7575 g = GREMLIN (
7676 alignment = msa_file_aneh ,
7777 char_alphabet = "ARNDCQEGHILKMFPSTWYV-" ,
@@ -90,45 +90,58 @@ def test_hybrid_model():
9090 print (len (train_seqs [0 ]), train_seqs [0 ])
9191 assert len (train_seqs [0 ]) == len (g .wt_seq )
9292 base_model , lora_model , tokenizer , optimizer = get_esm_models ()
93- encoded_seqs_train , attention_masks_train = esm_tokenize_sequences (list ( train_seqs ), tokenizer , max_length = len ( train_seqs [ 0 ]))
94- x_esm_b , attention_masks_b , train_ys_b = (
95- get_batches ( encoded_seqs_train , dtype = int ),
96- get_batches (attention_masks_train , dtype = int ),
93+ x_train_esm , attention_mask_esm = esm_tokenize_sequences (
94+ list ( train_seqs ), tokenizer , max_length = len ( train_seqs [ 0 ]))
95+ x_esm_b , train_ys_b = (
96+ get_batches (x_train_esm , dtype = int ),
9797 get_batches (train_ys , dtype = float )
9898 )
99- y_true , y_pred_esm = esm_test (x_esm_b , attention_masks_b , train_ys_b , loss_fn = corr_loss , model = base_model )
99+ y_true , y_pred_esm = esm_test (
100+ x_esm_b , attention_mask_esm , train_ys_b ,
101+ loss_fn = corr_loss , model = base_model
102+ )
100103 print (spearmanr (
101104 y_true ,
102105 y_pred_esm
103106 ), len (y_true ))
104107
108+ llm_dict_esm = {
109+ 'esm1v' : {
110+ 'llm_base_model' : base_model ,
111+ 'llm_model' : lora_model ,
112+ 'llm_optimizer' : optimizer ,
113+ 'llm_train_function' : esm_train ,
114+ 'llm_inference_function' : esm_infer ,
115+ 'llm_loss_function' : corr_loss ,
116+ 'x_llm_train' : x_train_esm ,
117+ 'llm_attention_mask' : attention_mask_esm
118+ }
119+ }
120+
105121 hm = DCALLMHybridModel (
106122 x_train_dca = np .array (x_dca_train ),
107- x_train_llm = np .array (encoded_seqs_train ),
108- x_train_llm_attention_mask = np .array (attention_masks_train ),
109123 y_train = train_ys ,
110- llm_model = lora_model ,
111- llm_base_model = base_model ,
112- llm_optimizer = optimizer ,
113- llm_train_function = esm_train ,
114- llm_inference_function = esm_infer ,
115- llm_loss_function = corr_loss ,
124+ llm_model_input = llm_dict_esm ,
116125 x_wt = g .x_wt ,
117126 seed = 42
118127 )
119128
120129 x_dca_test = g .get_scores (test_seqs , encode = True )
121- encoded_seqs_test , attention_masks_test = esm_tokenize_sequences (list (test_seqs ), tokenizer , max_length = len (test_seqs [0 ]))
122- y_pred_test = hm .hybrid_prediction (x_dca = x_dca_test , x_llm = encoded_seqs_test , attns_llm = attention_masks_test )
130+ encoded_seqs_test , attention_masks_test = esm_tokenize_sequences (
131+ list (test_seqs ), tokenizer , max_length = len (test_seqs [0 ]))
132+ y_pred_test = hm .hybrid_prediction (x_dca = x_dca_test , x_llm = encoded_seqs_test )
123133 print (hm .beta1 , hm .beta2 , hm .beta3 , hm .beta4 , hm .ridge_opt )
124134 print ('hm.y_dca_ttest' , spearmanr (hm .y_ttest , hm .y_dca_ttest ), len (hm .y_ttest ))
125135 print ('hm.y_dca_ridge_ttest' , spearmanr (hm .y_ttest , hm .y_dca_ridge_ttest ), len (hm .y_ttest ))
126136 print ('hm.y_llm_ttest' , spearmanr (hm .y_ttest , hm .y_llm_ttest ), len (hm .y_ttest ))
127137 print ('hm.y_llm_lora_ttest' , spearmanr (hm .y_ttest , hm .y_llm_lora_ttest ), len (hm .y_ttest ))
128138 print ('Hybrid' , spearmanr (test_ys , y_pred_test ), len (test_ys ))
129- np .testing .assert_almost_equal (spearmanr (hm .y_ttest , hm .y_dca_ttest )[0 ], - 0.5342743713116743 , decimal = 5 )
130- np .testing .assert_almost_equal (spearmanr (hm .y_ttest , hm .y_dca_ridge_ttest )[0 ], 0.717333573331078 , decimal = 5 )
131- np .testing .assert_almost_equal (spearmanr (hm .y_ttest , hm .y_llm_ttest )[0 ], - 0.21761360470606333 , decimal = 5 )
139+ np .testing .assert_almost_equal (spearmanr (
140+ hm .y_ttest , hm .y_dca_ttest )[0 ], - 0.5342743713116743 , decimal = 5 )
141+ np .testing .assert_almost_equal (spearmanr (
142+ hm .y_ttest , hm .y_dca_ridge_ttest )[0 ], 0.717333573331078 , decimal = 5 )
143+ np .testing .assert_almost_equal (spearmanr (
144+ hm .y_ttest , hm .y_llm_ttest )[0 ], - 0.21761360470606333 , decimal = 5 )
132145 # Nondeterministic behavior: the correlation should be about 0.8, so only check that it is not NaN
133146 # Torch reproducibility documentation: https://pytorch.org/docs/stable/notes/randomness.html
134147 assert - 1.0 <= spearmanr (hm .y_ttest , hm .y_llm_lora_ttest )[0 ] <= 1.0
@@ -137,8 +150,12 @@ def test_hybrid_model():
137150
138151def test_dataset_b_results ():
139152 aaindex = "WOLR810101.txt"
140- x_fft_train , _ = AAIndexEncoding (full_aaidx_txt_path (aaindex ), train_seqs ).collect_encoded_sequences ()
141- x_fft_test , _ = AAIndexEncoding (full_aaidx_txt_path (aaindex ), test_seqs ).collect_encoded_sequences ()
153+ x_fft_train , _ = AAIndexEncoding (
154+ full_aaidx_txt_path (aaindex ), train_seqs
155+ ).collect_encoded_sequences ()
156+ x_fft_test , _ = AAIndexEncoding (
157+ full_aaidx_txt_path (aaindex ), test_seqs
158+ ).collect_encoded_sequences ()
142159 performances = get_regressor_performances (
143160 x_learn = x_fft_train ,
144161 x_test = x_fft_test ,
@@ -157,6 +174,6 @@ def test_dataset_b_results():
157174
158175if __name__ == "__main__" :
159176 test_gremlin ()
160- test_hybrid_model ()
177+ test_hybrid_model_dca_esm ()
161178 test_dataset_b_results ()
162179
0 commit comments