2020from pypef .hybrid .hybrid_model import DCALLMHybridModel
2121
2222
# Pin the RNG state up front so every test in this module is reproducible
# (Torch reproducibility notes: https://pytorch.org/docs/stable/notes/randomness.html).
torch.manual_seed(42)
np.random.seed(42)

# Absolute path to the AVGFP MSA shipped with the test datasets.
msa_file_avgfp = os.path.abspath(
    os.path.join(__file__, '../../datasets/AVGFP/uref100_avgfp_jhmmer_119.a2m')
)
4447 os .path .join (__file__ , '../../datasets/ANEH/TS_B.fasl'
4548))
4649
# Load the ANEH learning (LS_B) and test (TS_B) splits once at module scope;
# each call yields (sequences, variant labels, fitness values).
train_seqs_aneh, _train_vars_aneh, train_ys_aneh = get_sequences_from_file(ls_b)
test_seqs_aneh, _test_vars_aneh, test_ys_aneh = get_sequences_from_file(ts_b)
5252
5353
def test_gremlin_aneh():
    """Fit a GREMLIN DCA model on the ANEH MSA and pin its outputs.

    Checks the wild-type statistical energy against a reference value,
    the internal consistency of the three WT-score accessors, and the
    Spearman correlation of predicted scores with measured fitness over
    the combined train + test ANEH variants.
    """
    model = GREMLIN(
        alignment=msa_file_aneh,
        char_alphabet="ARNDCQEGHILKMFPSTWYV-",
        wt_seq=None,
        optimize=True,
        gap_cutoff=0.5,
        eff_cutoff=0.8,
        opt_iter=100
    )
    wt = model.get_wt_score()
    np.testing.assert_almost_equal(wt, 1743.2087199198131, decimal=7)
    # The accessor, the cached attribute, and the sum over the encoded
    # wild type must all agree exactly.
    assert wt == model.wt_score == np.sum(model.x_wt)
    all_seqs = np.append(train_seqs_aneh, test_seqs_aneh)
    all_ys = np.append(train_ys_aneh, test_ys_aneh)
    rho = spearmanr(all_ys, model.get_scores(all_seqs))[0]
    np.testing.assert_almost_equal(rho, -0.5528510930046211, decimal=7)
73+
74+
def test_gremlin_avgfp():
    """Fit a GREMLIN DCA model on the AVGFP MSA and pin the WT score.

    Verifies the wild-type statistical energy against a reference value
    and that the three WT-score accessors stay mutually consistent.
    """
    model = GREMLIN(
        alignment=msa_file_avgfp,
        char_alphabet="ARNDCQEGHILKMFPSTWYV-",
        wt_seq=None,
        optimize=True,
        gap_cutoff=0.5,
        eff_cutoff=0.8,
        opt_iter=100
    )
    wt = model.get_wt_score()
    np.testing.assert_almost_equal(wt, 952.1102220697624, decimal=7)
    # Accessor, cached attribute, and the encoded-WT sum must agree exactly.
    assert wt == model.wt_score == np.sum(model.x_wt)
7288
7389
7490def test_hybrid_model_dca_llm ():
@@ -81,43 +97,43 @@ def test_hybrid_model_dca_llm():
8197 eff_cutoff = 0.8 ,
8298 opt_iter = 100
8399 )
84- x_dca_train = g .get_scores (train_seqs , encode = True )
100+ x_dca_train = g .get_scores (train_seqs_aneh , encode = True )
85101 np .testing .assert_almost_equal (
86- spearmanr (train_ys , np .sum (x_dca_train , axis = 1 ))[0 ],
102+ spearmanr (train_ys_aneh , np .sum (x_dca_train , axis = 1 ))[0 ],
87103 - 0.5556053466180598 ,
88- decimal = 6
104+ decimal = 7
89105 )
90- assert len (train_seqs [0 ]) == len (g .wt_seq )
106+ assert len (train_seqs_aneh [0 ]) == len (g .wt_seq )
91107
92- y_pred_esm = inference (train_seqs , 'esm' )
108+ y_pred_esm = inference (train_seqs_aneh , 'esm' )
93109 np .testing .assert_almost_equal (
94- spearmanr (train_ys , y_pred_esm )[0 ],
110+ spearmanr (train_ys_aneh , y_pred_esm )[0 ],
95111 - 0.21073416060442696 ,
96- decimal = 6
112+ decimal = 7
97113 )
98114 aneh_wt_seq = get_wt_sequence (wt_seq_file_aneh )
99115 y_pred_prosst = inference (
100- train_seqs , 'prosst' ,
116+ train_seqs_aneh , 'prosst' ,
101117 pdb_file = pdb_file_aneh , wt_seq = aneh_wt_seq
102118 )
103119 np .testing .assert_almost_equal (
104- spearmanr (train_ys , y_pred_prosst )[0 ],
120+ spearmanr (train_ys_aneh , y_pred_prosst )[0 ],
105121 - 0.7425657069861902 ,
106- decimal = 6
122+ decimal = 7
107123 )
108124
109- x_dca_test = g .get_scores (test_seqs , encode = True )
125+ x_dca_test = g .get_scores (test_seqs_aneh , encode = True )
110126 for i , setup in enumerate ([esm_setup , prosst_setup ]):
111127 print (['~~~ ESM ~~~' , '~~~ ProSST ~~~' ][i ])
112128 if setup == esm_setup :
113- llm_dict = setup (sequences = train_seqs )
129+ llm_dict = setup (sequences = train_seqs_aneh )
114130 else : # elif setup == prosst_setup:
115131 llm_dict = setup (
116- aneh_wt_seq , pdb_file_aneh , sequences = train_seqs )
117- x_llm_test = llm_embedder (llm_dict , test_seqs )
132+ aneh_wt_seq , pdb_file_aneh , sequences = train_seqs_aneh )
133+ x_llm_test = llm_embedder (llm_dict , test_seqs_aneh )
118134 hm = DCALLMHybridModel (
119135 x_train_dca = np .array (x_dca_train ),
120- y_train = train_ys ,
136+ y_train = train_ys_aneh ,
121137 llm_model_input = llm_dict ,
122138 x_wt = g .x_wt ,
123139 seed = 42
@@ -129,56 +145,66 @@ def test_hybrid_model_dca_llm():
129145 print ('hm.y_dca_ridge_ttest:' , spearmanr (hm .y_ttest , hm .y_dca_ridge_ttest ), len (hm .y_ttest ))
130146 print ('hm.y_llm_ttest:' , spearmanr (hm .y_ttest , hm .y_llm_ttest ), len (hm .y_ttest ))
131147 print ('hm.y_llm_lora_ttest:' , spearmanr (hm .y_ttest , hm .y_llm_lora_ttest ), len (hm .y_ttest ))
132- print ('Hybrid prediction:' , spearmanr (test_ys , y_pred_test ), len (test_ys ))
148+ print ('Hybrid prediction:' , spearmanr (test_ys_aneh , y_pred_test ), len (test_ys_aneh ))
133149 np .testing .assert_almost_equal (
134150 spearmanr (hm .y_ttest , hm .y_dca_ttest )[0 ], - 0.5342743713116743 ,
135- decimal = 5
151+ decimal = 7
136152 )
137153 np .testing .assert_almost_equal (
138154 spearmanr (hm .y_ttest , hm .y_dca_ridge_ttest )[0 ], 0.717333573331078 ,
139- decimal = 5
155+ decimal = 7
140156 )
141157 np .testing .assert_almost_equal (
142158 spearmanr (hm .y_ttest , hm .y_llm_ttest )[0 ],
143159 [- 0.21761360470606333 , - 0.8330644449247571 ][i ],
144- decimal = 5
160+ decimal = 7
145161 )
146162 # Nondeterministic behavior (without setting seed), should be about ~0.7 to ~0.9,
147163 # but as sample size is so low the following is only checking if not NaN / >=-1.0 and <=1.0,
148164 # Torch reproducibility documentation: https://pytorch.org/docs/stable/notes/randomness.html
149165 assert - 1.0 <= spearmanr (hm .y_ttest , hm .y_llm_lora_ttest )[0 ] <= 1.0
150- assert - 1.0 <= spearmanr (test_ys , y_pred_test )[0 ] <= 1.0
166+ assert - 1.0 <= spearmanr (test_ys_aneh , y_pred_test )[0 ] <= 1.0
151167 # With seed 42 for numpy and torch for implemented LLM's:
152168 if setup == esm_setup :
153169 np .testing .assert_almost_equal (
154- spearmanr (hm .y_ttest , hm .y_llm_lora_ttest )[0 ], 0.7772102863835341 , decimal = 5
170+ spearmanr (hm .y_ttest , hm .y_llm_lora_ttest )[0 ], 0.7772102863835341 , decimal = 7
155171 )
156172 np .testing .assert_almost_equal (
157- spearmanr (test_ys , y_pred_test )[0 ], 0.8004896406836318 , decimal = 5
173+ spearmanr (test_ys_aneh , y_pred_test )[0 ], 0.8004896406836318 , decimal = 7
158174 )
159175 elif setup == prosst_setup :
176+ try :
177+ np .testing .assert_almost_equal (
178+ spearmanr (hm .y_ttest , hm .y_llm_lora_ttest )[0 ], 0.7770124558338013 , decimal = 7
179+ )
180+ except AssertionError as ae1 :
181+ try :
182+ np .testing .assert_almost_equal ( # Different values on different machines
183+ spearmanr (hm .y_ttest , hm .y_llm_lora_ttest )[0 ], 0.7239938685054149 , decimal = 7
184+ ) # (TODO) has to be investigated
185+ except AssertionError as ae2 :
186+ raise AssertionError (
187+ f"Neither condition passed:\n First comparison failed:\n { ae1 } \n "
188+ f"Second comparison failed:\n { ae2 } "
189+ )
160190 np .testing .assert_almost_equal (
161- spearmanr (hm . y_ttest , hm . y_llm_lora_ttest )[0 ], 0.7770124558338013 , decimal = 5
191+ spearmanr (test_ys_aneh , y_pred_test )[0 ], 0.8291977762544377 , decimal = 7
162192 )
163- np .testing .assert_almost_equal (
164- spearmanr (test_ys , y_pred_test )[0 ], 0.8291977762544377 , decimal = 5
165- )
166-
167193
168194
169195def test_dataset_b_results ():
170196 aaindex = "WOLR810101.txt"
171197 x_fft_train , _ = AAIndexEncoding (
172- full_aaidx_txt_path (aaindex ), train_seqs
198+ full_aaidx_txt_path (aaindex ), train_seqs_aneh
173199 ).collect_encoded_sequences ()
174200 x_fft_test , _ = AAIndexEncoding (
175- full_aaidx_txt_path (aaindex ), test_seqs
201+ full_aaidx_txt_path (aaindex ), test_seqs_aneh
176202 ).collect_encoded_sequences ()
177203 performances = get_regressor_performances (
178- x_learn = x_fft_train ,
179- x_test = x_fft_test ,
180- y_learn = train_ys ,
181- y_test = test_ys ,
204+ x_learn = x_fft_train ,
205+ x_test = x_fft_test ,
206+ y_learn = train_ys_aneh ,
207+ y_test = test_ys_aneh ,
182208 regressor = 'pls_loocv'
183209 )
184210 # Dataset B PLS_LOOCV results: R², RMSE, NRMSE, Pearson's r, Spearman's rho
@@ -191,7 +217,8 @@ def test_dataset_b_results():
191217
192218
if __name__ == "__main__":
    # Allow running the suite directly (outside pytest), in the same order
    # pytest would collect the tests.
    for run_test in (
        test_gremlin_aneh,
        test_gremlin_avgfp,
        test_hybrid_model_dca_llm,
        test_dataset_b_results,
    ):
        run_test()
197224
0 commit comments