99import os .path
1010import numpy as np
1111from scipy .stats import spearmanr
12+ import torch
1213
1314from pypef .ml .regression import AAIndexEncoding , full_aaidx_txt_path , get_regressor_performances
1415from pypef .dca .gremlin_inference import GREMLIN
1516from pypef .utils .variant_data import get_sequences_from_file , get_wt_sequence
1617from pypef .llm .esm_lora_tune import esm_setup
1718from pypef .llm .prosst_lora_tune import prosst_setup
18- from pypef .llm .utils import corr_loss , get_batches
1919from pypef .llm .inference import inference , llm_embedder
2020from pypef .hybrid .hybrid_model import DCALLMHybridModel
2121
4747train_seqs , _train_vars , train_ys = get_sequences_from_file (ls_b )
4848test_seqs , _test_vars , test_ys = get_sequences_from_file (ts_b )
4949
50+ torch .manual_seed (42 )
51+ np .random .seed (42 )
52+
5053
5154def test_gremlin ():
5255 g = GREMLIN (
@@ -68,7 +71,7 @@ def test_gremlin():
6871 )
6972
7073
71- def test_hybrid_model_dca_esm ():
74+ def test_hybrid_model_dca_llm ():
7275 g = GREMLIN (
7376 alignment = msa_file_aneh ,
7477 char_alphabet = "ARNDCQEGHILKMFPSTWYV-" ,
@@ -140,11 +143,27 @@ def test_hybrid_model_dca_esm():
140143 [- 0.21761360470606333 , - 0.8330644449247571 ][i ],
141144 decimal = 5
142145 )
143- # Nondeterministic behavior, should be about ~0.7 to ~0.9, but as sample size is so low
144- # the following is only checking if not NaN / >=-1.0 and <=1.0,
146+ # Nondeterministic behavior (without setting seed) , should be about ~0.7 to ~0.9,
147+ # but as sample size is so low the following is only checking if not NaN / >=-1.0 and <=1.0,
145148 # Torch reproducibility documentation: https://pytorch.org/docs/stable/notes/randomness.html
146149 assert - 1.0 <= spearmanr (hm .y_ttest , hm .y_llm_lora_ttest )[0 ] <= 1.0
147- assert - 1.0 <= spearmanr (test_ys , y_pred_test )[0 ] <= 1.0
150+ assert - 1.0 <= spearmanr (test_ys , y_pred_test )[0 ] <= 1.0
151+ # With seed 42 for numpy and torch for implemented LLM's:
152+ if setup == esm_setup :
153+ np .testing .assert_almost_equal (
154+ spearmanr (hm .y_ttest , hm .y_llm_lora_ttest )[0 ], 0.7772102863835341 , decimal = 5
155+ )
156+ np .testing .assert_almost_equal (
157+ spearmanr (test_ys , y_pred_test )[0 ], 0.8004896406836318 , decimal = 5
158+ )
159+ elif setup == prosst_setup :
160+ np .testing .assert_almost_equal (
161+ spearmanr (hm .y_ttest , hm .y_llm_lora_ttest )[0 ], 0.7770124558338013 , decimal = 5
162+ )
163+ np .testing .assert_almost_equal (
164+ spearmanr (test_ys , y_pred_test )[0 ], 0.8291977762544377 , decimal = 5
165+ )
166+
148167
149168
150169def test_dataset_b_results ():
@@ -173,6 +192,6 @@ def test_dataset_b_results():
173192
174193if __name__ == "__main__" :
175194 test_gremlin ()
176- test_hybrid_model_dca_esm ()
195+ test_hybrid_model_dca_llm ()
177196 test_dataset_b_results ()
178197
0 commit comments