Skip to content

Commit 5d5a09c

Browse files
committed
Append np.nan performance if ProSST PDB
doesn't match WT sequence
1 parent 36c865b commit 5d5a09c

File tree

1 file changed

+55
-25
lines changed

1 file changed

+55
-25
lines changed

scripts/ProteinGym_runs/run_performance_tests_proteingym_hybrid_dca_llm.py

Lines changed: 55 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,26 @@ def get_vram(verbose: bool = True):
4444
return free, total
4545

4646

47+
def read_pdb(pdbfile):
    """Parse a PDB file with Biopython, print it for inspection, and return it.

    The parsed structure and the path it was read from are printed to stdout
    as a debugging aid.

    Args:
        pdbfile: Path (or handle) of the PDB file, passed unchanged to
            ``PDBParser.get_structure``.

    Returns:
        The structure object produced by ``PDBParser.get_structure``, parsed
        under the structure id ``'ppp'``. Callers that only want the printed
        output can ignore the return value.
    """
    # Imported locally so Biopython is only required when this helper runs.
    from Bio import PDB

    structure = PDB.PDBParser().get_structure('ppp', pdbfile)

    print(structure)
    print(pdbfile)

    # NOTE(review): an earlier draft walked every residue and carried
    # commented-out code that reassigned residue ids from a hard-coded range
    # (200..334). That loop had no effect and was removed together with the
    # unused PDBIO instance; reintroduce renumbering explicitly if the PDB
    # has to be aligned to a trimmed wild-type sequence -- TODO confirm.
    return structure
65+
66+
4767
def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested_is: list = []):
4868
# Get cpu, gpu or mps device for training.
4969
device = (
@@ -83,6 +103,10 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
83103
print('MSA path:', msa_path)
84104
print('MSA start:', msa_start, '- MSA end:', msa_end)
85105
print('WT sequence (trimmed from MSA start to MSA end):\n' + wt_seq)
106+
read_pdb(pdb)
107+
#if msa_start != 1:
108+
# print('Continuing (TODO: requires cut of PDB input struture residues)...')
109+
# continue
86110
# Getting % usage of virtual_memory (3rd field)
87111
import psutil;print(f'RAM used: {round(psutil.virtual_memory()[3]/1E9, 3)} '
88112
f'GB ({psutil.virtual_memory()[2]} %)')
@@ -92,9 +116,9 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
92116
# print('More than 400000 variant-fitness pairs which represents a '
93117
# 'potential out-of-memory risk, skipping dataset...')
94118
# continue
95-
variants = variant_fitness_data['mutant'].to_numpy() # [400:700]
119+
variants = variant_fitness_data['mutant'].to_numpy()
96120
variants_orig = variants
97-
fitnesses = variant_fitness_data['DMS_score'].to_numpy() # [400:700]
121+
fitnesses = variant_fitness_data['DMS_score'].to_numpy()
98122
if len(fitnesses) <= 50: # and len(fitnesses) >= 500: # TODO: RESET TO 50
99123
print('Number of available variants <= 50, skipping dataset...')
100124
continue
@@ -137,11 +161,14 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
137161

138162
input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(pdb, prosst_tokenizer, wt_seq)
139163
x_prosst = prosst_tokenize_sequences(sequences=sequences, vocab=prosst_vocab)
140-
y_prosst = get_logits_from_full_seqs(
141-
x_prosst, prosst_base_model, input_ids, prosst_attention_mask, structure_input_ids, train=False)
142-
print('ProSST:', spearmanr(fitnesses, y_prosst.cpu()))
143-
144-
prosst_unopt_perf = spearmanr(fitnesses, y_prosst.cpu())[0]
164+
try:
165+
y_prosst = get_logits_from_full_seqs(
166+
x_prosst, prosst_base_model, input_ids, prosst_attention_mask, structure_input_ids, train=False)
167+
print('ProSST:', spearmanr(fitnesses, y_prosst.cpu()))
168+
prosst_unopt_perf = spearmanr(fitnesses, y_prosst.cpu())[0]
169+
except RuntimeError:
170+
prosst_unopt_perf = np.nan
171+
145172
esm_unopt_perf = spearmanr(fitnesses, y_esm.cpu())[0]
146173

147174
ns_y_test = [len(variants)]
@@ -214,24 +241,27 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
214241
get_vram()
215242
for i_m, method in enumerate([None, llm_dict_esm, llm_dict_prosst]):
216243
print('\n~~~ ' + ['DCA hybrid', 'DCA+ESM1v hybrid', 'DCA+ProSST hybrid'][i_m] + ' ~~~')
217-
hm = DCALLMHybridModel(
218-
x_train_dca=np.array(x_dca_train),
219-
y_train=y_train,
220-
llm_model_input=method,
221-
x_wt=x_wt
222-
)
223-
224-
y_test_pred = hm.hybrid_prediction(
225-
x_dca=np.array(x_dca_test),
226-
x_llm=[
227-
None,
228-
np.asarray(x_llm_test_esm),
229-
np.asarray(x_llm_test_prosst)
230-
][i_m]
231-
)
232-
233-
print(f'Hybrid perf.: {spearmanr(y_test, y_test_pred)[0]}')
234-
hybrid_perfs.append(spearmanr(y_test, y_test_pred)[0])
244+
try:
245+
hm = DCALLMHybridModel(
246+
x_train_dca=np.array(x_dca_train),
247+
y_train=y_train,
248+
llm_model_input=method,
249+
x_wt=x_wt
250+
)
251+
252+
y_test_pred = hm.hybrid_prediction(
253+
x_dca=np.array(x_dca_test),
254+
x_llm=[
255+
None,
256+
np.asarray(x_llm_test_esm),
257+
np.asarray(x_llm_test_prosst)
258+
][i_m]
259+
)
260+
261+
print(f'Hybrid perf.: {spearmanr(y_test, y_test_pred)[0]}')
262+
hybrid_perfs.append(spearmanr(y_test, y_test_pred)[0])
263+
except RuntimeError: # modeling_prosst.py, line 920, in forward
264+
hybrid_perfs.append(np.nan)
235265
ns_y_test.append(len(y_test_pred))
236266
except ValueError as e:
237267
print(f'Only {len(fitnesses)} variant-fitness pairs in total, cannot split the data '

0 commit comments

Comments
 (0)