Skip to content

Commit 7d6465c

Browse files
committed
Update gremlin (dev): implement msa_start & msa_end (II)
1 parent 03b88f8 commit 7d6465c

File tree

5 files changed

+91
-72
lines changed

5 files changed

+91
-72
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
sleep 3
5151
pip install .[gui]
5252
echo $(which pypef)
53-
python -m pytest ./tests/ -v -m "not main_script_specific" --capture=tee-sys --log-cli-level=INFO
53+
python -m pytest ./tests/ -v -m "not main_script_specific" --log-cli-level=INFO
5454
5555
windows:
5656
name: windows
@@ -88,5 +88,5 @@ jobs:
8888
$env:PYTHONPATH = ""
8989
pip install .[gui]
9090
echo (Get-Command pypef).Source
91-
python -m pytest .\tests -v -m "not main_script_specific" --capture=tee-sys --log-cli-level=INFO
91+
python -m pytest .\tests -v -m "not main_script_specific" --log-cli-level=INFO
9292

pypef/dca/gremlin_inference.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def __init__(
145145
self.optimize = optimize
146146
if self.optimize:
147147
self.run_optimization()
148+
self.wt_score = self.get_wt_score()
148149
self.x_wt = self.collect_encoded_sequences(np.atleast_1d(self.wt_seq))
149150

150151
def get_sequences_from_msa(self, msa_file: str):
@@ -437,17 +438,17 @@ def get_scores(self, seqs, v=None, w=None, v_idx=None, encode=False, h_wt_seq=0.
437438
v_idx = self.v_idx
438439
seqs_int = self.seq2int(seqs)
439440
wt_seq_len = len(self.wt_seq)
440-
#if np.shape(seqs_int)[1] != wt_seq_len:
441-
# raise RuntimeError(
442-
# f"Input sequence shape (length: {np.shape(seqs_int)[1]}) does not match GREMLIN "
443-
# f"MSA shape (common sequence length: {wt_seq_len}) inferred from the MSA."
444-
# )
441+
if np.shape(seqs_int)[1] != wt_seq_len:
442+
raise RuntimeError(
443+
f"Input sequence shape (length: {np.shape(seqs_int)[1]}) does not match GREMLIN "
444+
f"MSA shape (common sequence length: {wt_seq_len}) inferred from the MSA."
445+
)
445446
# Check the number of mutations relative to the first MSA/WT sequence and give a warning if too divergent from the MSA sequence
446447
for i, seq in enumerate(seqs):
447448
n_mismatches, mismatches = get_mismatches(self.wt_seq, seq)
448449
if n_mismatches / wt_seq_len > 0.05:
449450
logger.warning(
450-
f"Sequence {mismatches} contains more than 5% sequence mismatches to the "
451+
f"Sequence {i + 1}: {mismatches} contains more than 5% sequence mismatches to the "
451452
f"first MSA/\"WT\" sequence. Effect predictions will likely be incorrect!"
452453
)
453454
try:
@@ -496,8 +497,8 @@ def get_scores(self, seqs, v=None, w=None, v_idx=None, encode=False, h_wt_seq=0.
496497
def get_wt_score(self, wt_seq=None, encode=False):
497498
if wt_seq is None:
498499
wt_seq = self.wt_seq
499-
wt_seq = np.array(wt_seq, dtype=str)
500-
return self.get_scores(wt_seq, encode=encode)
500+
wt_seq = np.atleast_1d(np.array(wt_seq, dtype=str))
501+
return self.get_scores(wt_seq, encode=encode)[0]
501502

502503
def collect_encoded_sequences(self, seqs, v=None, w=None, v_idx=None):
503504
"""

scripts/ProteinGym_runs/protgym_hybrid_perf_test_crossval.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -144,12 +144,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
144144

145145
print('GREMLIN-DCA: optimization...')
146146
gremlin = GREMLIN(alignment=msa_path, opt_iter=100, optimize=True)
147-
sequences_batched = get_batches(sequences, batch_size=1000,
148-
dtype=str, keep_remaining=True, verbose=True)
149-
x_dca = [] # required later on also
150-
for seq_b in tqdm(sequences_batched, desc="Getting GREMLIN sequence encodings", disable=True):
151-
for x in gremlin.collect_encoded_sequences(seq_b):
152-
x_dca.append(x)
147+
x_dca = gremlin.collect_encoded_sequences(sequences)
153148
x_wt = gremlin.x_wt
154149
y_pred_dca = get_delta_e_statistical_model(x_dca, x_wt)
155150
print(f'DCA (unsupervised performance): {spearmanr(fitnesses, y_pred_dca)[0]:.3f}')
@@ -453,7 +448,7 @@ def plot_csv_data(csv):
453448
if not JUST_PLOT_RESULTS:
454449
compute_performances(
455450
mut_data=combined_mut_data,
456-
start_i=start_i,
451+
start_i=5,#start_i,
457452
already_tested_is=already_tested_is
458453
)
459454

scripts/ProteinGym_runs/protgym_hybrid_perf_test_low_n.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
import sys # Use local directory PyPEF files
2323
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
2424
from pypef.dca.gremlin_inference import GREMLIN
25+
from pypef.llm.utils import get_batches
2526
from pypef.llm.esm_lora_tune import (
2627
get_esm_models, esm_tokenize_sequences,
27-
get_batches, esm_train, esm_infer, corr_loss
28+
esm_train, esm_infer, corr_loss
2829
)
2930
from pypef.llm.prosst_lora_tune import (
3031
get_logits_from_full_seqs, get_prosst_models, get_structure_quantizied,
@@ -117,7 +118,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
117118
f'{max_muts},Sequence too long ({len(wt_seq)} > {MAX_WT_SEQUENCE_LENGTH})\n'
118119
)
119120
continue
120-
ratio_input_vars_at_gaps = count_gap_variants / len(variants)
121+
_ratio_input_vars_at_gaps = count_gap_variants / len(variants)
121122
pdb_seq = str(list(SeqIO.parse(pdb, "pdb-atom"))[0].seq)
122123
try:
123124
assert wt_seq == pdb_seq # pdb_seq.startswith(wt_seq)
@@ -135,12 +136,7 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
135136

136137
print('GREMLIN-DCA: optimization...')
137138
gremlin = GREMLIN(alignment=msa_path, opt_iter=100, optimize=True)
138-
sequences_batched = get_batches(sequences, batch_size=1000,
139-
dtype=str, keep_remaining=True, verbose=True)
140-
x_dca = []
141-
for seq_b in tqdm(sequences_batched, desc="Getting GREMLIN sequence encodings"):
142-
for x in gremlin.collect_encoded_sequences(seq_b):
143-
x_dca.append(x)
139+
x_dca = gremlin.collect_encoded_sequences(sequences)
144140
x_wt = gremlin.x_wt
145141
y_pred_dca = get_delta_e_statistical_model(x_dca, x_wt)
146142
print(f'DCA (unsupervised performance): {spearmanr(fitnesses, y_pred_dca)[0]:.3f}')

tests/test_api_functions.py

Lines changed: 74 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from pypef.hybrid.hybrid_model import DCALLMHybridModel
2121

2222

23+
torch.manual_seed(42)
24+
np.random.seed(42)
25+
2326
msa_file_avgfp = os.path.abspath(os.path.join(
2427
__file__, '../../datasets/AVGFP/uref100_avgfp_jhmmer_119.a2m'
2528
))
@@ -44,31 +47,44 @@
4447
os.path.join(__file__, '../../datasets/ANEH/TS_B.fasl'
4548
))
4649

47-
train_seqs, _train_vars, train_ys = get_sequences_from_file(ls_b)
48-
test_seqs, _test_vars, test_ys = get_sequences_from_file(ts_b)
49-
50-
torch.manual_seed(42)
51-
np.random.seed(42)
50+
train_seqs_aneh, _train_vars_aneh, train_ys_aneh = get_sequences_from_file(ls_b)
51+
test_seqs_aneh, _test_vars_aneh, test_ys_aneh = get_sequences_from_file(ts_b)
5252

5353

54-
def test_gremlin():
54+
def test_gremlin_aneh():
5555
g = GREMLIN(
56-
alignment=msa_file_avgfp,
56+
alignment=msa_file_aneh,
5757
char_alphabet="ARNDCQEGHILKMFPSTWYV-",
5858
wt_seq=None,
5959
optimize=True,
6060
gap_cutoff=0.5,
6161
eff_cutoff=0.8,
6262
opt_iter=100
6363
)
64-
wt_score = g.get_wt_score() # only 1 decimal place for Torch result
65-
np.testing.assert_almost_equal(wt_score, 952.1, decimal=1)
66-
y_pred = g.get_scores(np.append(train_seqs, test_seqs))
64+
wt_score = g.get_wt_score()
65+
np.testing.assert_almost_equal(wt_score, 1743.2087199198131, decimal=7)
66+
assert wt_score == g.wt_score == np.sum(g.x_wt)
67+
y_pred = g.get_scores(np.append(train_seqs_aneh, test_seqs_aneh))
6768
np.testing.assert_almost_equal(
68-
spearmanr(np.append(train_ys, test_ys), y_pred)[0],
69-
0.4516502675400598,
70-
decimal=3
69+
spearmanr(np.append(train_ys_aneh, test_ys_aneh), y_pred)[0],
70+
-0.5528510930046211,
71+
decimal=7
72+
)
73+
74+
75+
def test_gremlin_avgfp():
76+
g = GREMLIN(
77+
alignment=msa_file_avgfp,
78+
char_alphabet="ARNDCQEGHILKMFPSTWYV-",
79+
wt_seq=None,
80+
optimize=True,
81+
gap_cutoff=0.5,
82+
eff_cutoff=0.8,
83+
opt_iter=100
7184
)
85+
wt_score = g.get_wt_score()
86+
np.testing.assert_almost_equal(wt_score, 952.1102220697624, decimal=7)
87+
assert wt_score == g.wt_score == np.sum(g.x_wt)
7288

7389

7490
def test_hybrid_model_dca_llm():
@@ -81,43 +97,43 @@ def test_hybrid_model_dca_llm():
8197
eff_cutoff=0.8,
8298
opt_iter=100
8399
)
84-
x_dca_train = g.get_scores(train_seqs, encode=True)
100+
x_dca_train = g.get_scores(train_seqs_aneh, encode=True)
85101
np.testing.assert_almost_equal(
86-
spearmanr(train_ys, np.sum(x_dca_train, axis=1))[0],
102+
spearmanr(train_ys_aneh, np.sum(x_dca_train, axis=1))[0],
87103
-0.5556053466180598,
88-
decimal=6
104+
decimal=7
89105
)
90-
assert len(train_seqs[0]) == len(g.wt_seq)
106+
assert len(train_seqs_aneh[0]) == len(g.wt_seq)
91107

92-
y_pred_esm = inference(train_seqs, 'esm')
108+
y_pred_esm = inference(train_seqs_aneh, 'esm')
93109
np.testing.assert_almost_equal(
94-
spearmanr(train_ys, y_pred_esm)[0],
110+
spearmanr(train_ys_aneh, y_pred_esm)[0],
95111
-0.21073416060442696,
96-
decimal=6
112+
decimal=7
97113
)
98114
aneh_wt_seq = get_wt_sequence(wt_seq_file_aneh)
99115
y_pred_prosst = inference(
100-
train_seqs, 'prosst',
116+
train_seqs_aneh, 'prosst',
101117
pdb_file=pdb_file_aneh, wt_seq=aneh_wt_seq
102118
)
103119
np.testing.assert_almost_equal(
104-
spearmanr(train_ys, y_pred_prosst)[0],
120+
spearmanr(train_ys_aneh, y_pred_prosst)[0],
105121
-0.7425657069861902,
106-
decimal=6
122+
decimal=7
107123
)
108124

109-
x_dca_test = g.get_scores(test_seqs, encode=True)
125+
x_dca_test = g.get_scores(test_seqs_aneh, encode=True)
110126
for i, setup in enumerate([esm_setup, prosst_setup]):
111127
print(['~~~ ESM ~~~', '~~~ ProSST ~~~'][i])
112128
if setup == esm_setup:
113-
llm_dict = setup(sequences=train_seqs)
129+
llm_dict = setup(sequences=train_seqs_aneh)
114130
else: # elif setup == prosst_setup:
115131
llm_dict = setup(
116-
aneh_wt_seq, pdb_file_aneh, sequences=train_seqs)
117-
x_llm_test = llm_embedder(llm_dict, test_seqs)
132+
aneh_wt_seq, pdb_file_aneh, sequences=train_seqs_aneh)
133+
x_llm_test = llm_embedder(llm_dict, test_seqs_aneh)
118134
hm = DCALLMHybridModel(
119135
x_train_dca=np.array(x_dca_train),
120-
y_train=train_ys,
136+
y_train=train_ys_aneh,
121137
llm_model_input=llm_dict,
122138
x_wt=g.x_wt,
123139
seed=42
@@ -129,56 +145,66 @@ def test_hybrid_model_dca_llm():
129145
print('hm.y_dca_ridge_ttest:', spearmanr(hm.y_ttest, hm.y_dca_ridge_ttest), len(hm.y_ttest))
130146
print('hm.y_llm_ttest:', spearmanr(hm.y_ttest, hm.y_llm_ttest), len(hm.y_ttest))
131147
print('hm.y_llm_lora_ttest:', spearmanr(hm.y_ttest, hm.y_llm_lora_ttest), len(hm.y_ttest))
132-
print('Hybrid prediction:', spearmanr(test_ys, y_pred_test), len(test_ys))
148+
print('Hybrid prediction:', spearmanr(test_ys_aneh, y_pred_test), len(test_ys_aneh))
133149
np.testing.assert_almost_equal(
134150
spearmanr(hm.y_ttest, hm.y_dca_ttest)[0], -0.5342743713116743,
135-
decimal=5
151+
decimal=7
136152
)
137153
np.testing.assert_almost_equal(
138154
spearmanr(hm.y_ttest, hm.y_dca_ridge_ttest)[0], 0.717333573331078,
139-
decimal=5
155+
decimal=7
140156
)
141157
np.testing.assert_almost_equal(
142158
spearmanr(hm.y_ttest, hm.y_llm_ttest)[0],
143159
[-0.21761360470606333, -0.8330644449247571][i],
144-
decimal=5
160+
decimal=7
145161
)
146162
# Nondeterministic behavior (without setting seed), should be about ~0.7 to ~0.9,
147163
# but as sample size is so low the following is only checking if not NaN / >=-1.0 and <=1.0,
148164
# Torch reproducibility documentation: https://pytorch.org/docs/stable/notes/randomness.html
149165
assert -1.0 <= spearmanr(hm.y_ttest, hm.y_llm_lora_ttest)[0] <= 1.0
150-
assert -1.0 <= spearmanr(test_ys, y_pred_test)[0] <= 1.0
166+
assert -1.0 <= spearmanr(test_ys_aneh, y_pred_test)[0] <= 1.0
151167
# With seed 42 for numpy and torch for implemented LLM's:
152168
if setup == esm_setup:
153169
np.testing.assert_almost_equal(
154-
spearmanr(hm.y_ttest, hm.y_llm_lora_ttest)[0], 0.7772102863835341, decimal=5
170+
spearmanr(hm.y_ttest, hm.y_llm_lora_ttest)[0], 0.7772102863835341, decimal=7
155171
)
156172
np.testing.assert_almost_equal(
157-
spearmanr(test_ys, y_pred_test)[0], 0.8004896406836318, decimal=5
173+
spearmanr(test_ys_aneh, y_pred_test)[0], 0.8004896406836318, decimal=7
158174
)
159175
elif setup == prosst_setup:
176+
try:
177+
np.testing.assert_almost_equal(
178+
spearmanr(hm.y_ttest, hm.y_llm_lora_ttest)[0], 0.7770124558338013, decimal=7
179+
)
180+
except AssertionError as ae1:
181+
try:
182+
np.testing.assert_almost_equal( # Different values on different machines
183+
spearmanr(hm.y_ttest, hm.y_llm_lora_ttest)[0], 0.7239938685054149, decimal=7
184+
) # (TODO) has to be investigated
185+
except AssertionError as ae2:
186+
raise AssertionError(
187+
f"Neither condition passed:\nFirst comparison failed:\n{ae1}\n"
188+
f"Second comparison failed:\n{ae2}"
189+
)
160190
np.testing.assert_almost_equal(
161-
spearmanr(hm.y_ttest, hm.y_llm_lora_ttest)[0], 0.7770124558338013, decimal=5
191+
spearmanr(test_ys_aneh, y_pred_test)[0], 0.8291977762544377, decimal=7
162192
)
163-
np.testing.assert_almost_equal(
164-
spearmanr(test_ys, y_pred_test)[0], 0.8291977762544377, decimal=5
165-
)
166-
167193

168194

169195
def test_dataset_b_results():
170196
aaindex = "WOLR810101.txt"
171197
x_fft_train, _ = AAIndexEncoding(
172-
full_aaidx_txt_path(aaindex), train_seqs
198+
full_aaidx_txt_path(aaindex), train_seqs_aneh
173199
).collect_encoded_sequences()
174200
x_fft_test, _ = AAIndexEncoding(
175-
full_aaidx_txt_path(aaindex), test_seqs
201+
full_aaidx_txt_path(aaindex), test_seqs_aneh
176202
).collect_encoded_sequences()
177203
performances = get_regressor_performances(
178-
x_learn=x_fft_train,
179-
x_test=x_fft_test,
180-
y_learn=train_ys,
181-
y_test=test_ys,
204+
x_learn=x_fft_train,
205+
x_test=x_fft_test,
206+
y_learn=train_ys_aneh,
207+
y_test=test_ys_aneh,
182208
regressor='pls_loocv'
183209
)
184210
# Dataset B PLS_LOOCV results: R², RMSE, NRMSE, Pearson's r, Spearman's rho
@@ -191,7 +217,8 @@ def test_dataset_b_results():
191217

192218

193219
if __name__ == "__main__":
194-
test_gremlin()
220+
test_gremlin_aneh()
221+
test_gremlin_avgfp()
195222
test_hybrid_model_dca_llm()
196223
test_dataset_b_results()
197224

0 commit comments

Comments (0)