Skip to content

Commit 40a7848

Browse files
committed
Add seeds to tests
for hybrid model performance reproducibility
1 parent a728f9d commit 40a7848

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

tests/test_api_functions.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
import os.path
1010
import numpy as np
1111
from scipy.stats import spearmanr
12+
import torch
1213

1314
from pypef.ml.regression import AAIndexEncoding, full_aaidx_txt_path, get_regressor_performances
1415
from pypef.dca.gremlin_inference import GREMLIN
1516
from pypef.utils.variant_data import get_sequences_from_file, get_wt_sequence
1617
from pypef.llm.esm_lora_tune import esm_setup
1718
from pypef.llm.prosst_lora_tune import prosst_setup
18-
from pypef.llm.utils import corr_loss, get_batches
1919
from pypef.llm.inference import inference, llm_embedder
2020
from pypef.hybrid.hybrid_model import DCALLMHybridModel
2121

@@ -47,6 +47,9 @@
4747
train_seqs, _train_vars, train_ys = get_sequences_from_file(ls_b)
4848
test_seqs, _test_vars, test_ys = get_sequences_from_file(ts_b)
4949

50+
torch.manual_seed(42)
51+
np.random.seed(42)
52+
5053

5154
def test_gremlin():
5255
g = GREMLIN(
@@ -68,7 +71,7 @@ def test_gremlin():
6871
)
6972

7073

71-
def test_hybrid_model_dca_esm():
74+
def test_hybrid_model_dca_llm():
7275
g = GREMLIN(
7376
alignment=msa_file_aneh,
7477
char_alphabet="ARNDCQEGHILKMFPSTWYV-",
@@ -140,11 +143,27 @@ def test_hybrid_model_dca_esm():
140143
[-0.21761360470606333, -0.8330644449247571][i],
141144
decimal=5
142145
)
143-
# Nondeterministic behavior, should be about ~0.7 to ~0.9, but as sample size is so low
144-
# the following is only checking if not NaN / >=-1.0 and <=1.0,
146+
# Nondeterministic behavior (without setting a seed); should be about 0.7 to 0.9,
147+
# but as the sample size is so low, the following only checks that the value is not NaN, i.e., >= -1.0 and <= 1.0.
145148
# Torch reproducibility documentation: https://pytorch.org/docs/stable/notes/randomness.html
146149
assert -1.0 <= spearmanr(hm.y_ttest, hm.y_llm_lora_ttest)[0] <= 1.0
147-
assert -1.0 <= spearmanr(test_ys, y_pred_test)[0] <= 1.0
150+
assert -1.0 <= spearmanr(test_ys, y_pred_test)[0] <= 1.0
151+
# With seed 42 set for both numpy and torch, the implemented LLMs give:
152+
if setup == esm_setup:
153+
np.testing.assert_almost_equal(
154+
spearmanr(hm.y_ttest, hm.y_llm_lora_ttest)[0], 0.7772102863835341, decimal=5
155+
)
156+
np.testing.assert_almost_equal(
157+
spearmanr(test_ys, y_pred_test)[0], 0.8004896406836318, decimal=5
158+
)
159+
elif setup == prosst_setup:
160+
np.testing.assert_almost_equal(
161+
spearmanr(hm.y_ttest, hm.y_llm_lora_ttest)[0], 0.7770124558338013, decimal=5
162+
)
163+
np.testing.assert_almost_equal(
164+
spearmanr(test_ys, y_pred_test)[0], 0.8291977762544377, decimal=5
165+
)
166+
148167

149168

150169
def test_dataset_b_results():
@@ -173,6 +192,6 @@ def test_dataset_b_results():
173192

174193
if __name__ == "__main__":
175194
test_gremlin()
176-
test_hybrid_model_dca_esm()
195+
test_hybrid_model_dca_llm()
177196
test_dataset_b_results()
178197

0 commit comments

Comments
 (0)