Skip to content

Commit 273db3f

Browse files
committed
minor plotting updates
1 parent 5d5a09c commit 273db3f

File tree

1 file changed

+8
-33
lines changed

1 file changed

+8
-33
lines changed

scripts/ProteinGym_runs/run_performance_tests_proteingym_hybrid_dca_llm.py

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -44,26 +44,6 @@ def get_vram(verbose: bool = True):
4444
return free, total
4545

4646

47-
def read_pdb(pdbfile):
48-
from Bio import PDB
49-
50-
pdb_io = PDB.PDBIO()
51-
pdb_parser = PDB.PDBParser()
52-
structure = pdb_parser.get_structure('ppp', pdbfile)
53-
54-
new_resnums = [i + 200 for i in range(135)]
55-
56-
print(structure)
57-
print(pdbfile)
58-
59-
for model in structure:
60-
for chain in model:
61-
for i, residue in enumerate(chain.get_residues()):
62-
res_id = list(residue.id)
63-
#res_id[1] = new_resnums[i]
64-
#residue.id = tuple(res_id)
65-
66-
6747
def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested_is: list = []):
6848
# Get cpu, gpu or mps device for training.
6949
device = (
@@ -81,14 +61,13 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
8161
esm_base_model, esm_lora_model, esm_tokenizer, esm_optimizer = get_esm_models()
8262
esm_base_model = esm_base_model.to(device)
8363
MAX_WT_SEQUENCE_LENGTH = 2000
84-
N_EPOCHS = 5
8564
get_vram()
8665
hybrid_perfs = []
8766
plt.figure(figsize=(40, 12))
8867
numbers_of_datasets = [i + 1 for i in range(len(mut_data.keys()))]
8968
delta_times = []
9069
for i, (dset_key, dset_paths) in enumerate(mut_data.items()):
91-
if i >= start_i and i not in already_tested_is and i < 21: # i > 3 and i <21: #i == 18 - 1:
70+
if i >= start_i and i not in already_tested_is: # i > 3 and i <21: #i == 18 - 1:
9271
start_time = time.time()
9372
print(f'\n{i+1}/{len(mut_data.items())}\n'
9473
f'===============================================================')
@@ -103,7 +82,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
10382
print('MSA path:', msa_path)
10483
print('MSA start:', msa_start, '- MSA end:', msa_end)
10584
print('WT sequence (trimmed from MSA start to MSA end):\n' + wt_seq)
106-
read_pdb(pdb)
10785
#if msa_start != 1:
10886
# print('Continuing (TODO: requires cut of PDB input struture residues)...')
10987
# continue
@@ -152,8 +130,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
152130
y_pred_dca = get_delta_e_statistical_model(x_dca, x_wt)
153131
print('DCA:', spearmanr(fitnesses, y_pred_dca), len(fitnesses))
154132
dca_unopt_perf = spearmanr(fitnesses, y_pred_dca)[0]
155-
# TF 10,000: DCA: SignificanceResult(statistic=np.float64(0.6486616550552755), pvalue=np.float64(3.647740047145113e-119)) 989
156-
# Torch 10,000: DCA: SignificanceResult(statistic=np.float64(0.6799982280150232), pvalue=np.float64(3.583110693136881e-135)) 989
157133

158134
x_esm, esm_attention_mask = esm_tokenize_sequences(sequences, esm_tokenizer, max_length=len(wt_seq))
159135
y_esm = esm_infer(get_batches(x_esm, dtype=float, batch_size=1), esm_attention_mask, esm_base_model)
@@ -248,7 +224,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
248224
llm_model_input=method,
249225
x_wt=x_wt
250226
)
251-
252227
y_test_pred = hm.hybrid_prediction(
253228
x_dca=np.array(x_dca_test),
254229
x_llm=[
@@ -257,7 +232,6 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
257232
np.asarray(x_llm_test_prosst)
258233
][i_m]
259234
)
260-
261235
print(f'Hybrid perf.: {spearmanr(y_test, y_test_pred)[0]}')
262236
hybrid_perfs.append(spearmanr(y_test, y_test_pred)[0])
263237
except RuntimeError: # modeling_prosst.py, line 920, in forward
@@ -336,7 +310,6 @@ def plot_csv_data(csv, plot_name):
336310
train_test_size_texts.append(plt.text(len(tested_dsets), np.nanmean(dset_hybrid_perfs_dca_1000), f'{np.nanmean(dset_hybrid_perfs_dca_1000):.2f}', color='blueviolet'))
337311

338312

339-
340313
plt.plot(range(len(tested_dsets)), dset_esm_perfs, 'o--', markersize=8, color='tab:green', label='ESM (0)')
341314
plt.plot(range(len(tested_dsets) + 1), np.full(len(tested_dsets) + 1, np.nanmean(dset_esm_perfs)), color='tab:green', linestyle='--')
342315
for i, (p, n_test) in enumerate(zip(dset_esm_perfs, df['N_Y_test'].astype('Int64').to_list())):
@@ -362,8 +335,6 @@ def plot_csv_data(csv, plot_name):
362335
train_test_size_texts.append(plt.text(len(tested_dsets), np.nanmean(dset_hybrid_perfs_dca_esm_1000), f'{np.nanmean(dset_hybrid_perfs_dca_esm_1000):.2f}', color='turquoise'))
363336

364337

365-
366-
367338
plt.plot(range(len(tested_dsets)), dset_prosst_perfs, 'o--', markersize=8, color='tab:red', label='ProSST (0)')
368339
plt.plot(range(len(tested_dsets) + 1), np.full(len(tested_dsets) + 1, np.nanmean(dset_prosst_perfs)), color='tab:red', linestyle='--')
369340
for i, (p, n_test) in enumerate(zip(dset_prosst_perfs, df['N_Y_test'].astype('Int64').to_list())):
@@ -389,9 +360,6 @@ def plot_csv_data(csv, plot_name):
389360
train_test_size_texts.append(plt.text(len(tested_dsets), np.nanmean(dset_hybrid_perfs_dca_prosst_1000), f'{np.nanmean(dset_hybrid_perfs_dca_prosst_1000):.2f}', color='darkred'))
390361

391362

392-
393-
394-
395363
plt.grid(zorder=-1)
396364
plt.xticks(range(len(tested_dsets)), tested_dsets, rotation=45, ha='right')
397365
plt.margins(0.01)
@@ -433,6 +401,13 @@ def plot_csv_data(csv, plot_name):
433401
print(df.columns)
434402
dset_ns_y_test = [
435403
df['N_Y_test'].to_list(),
404+
df['N_Y_test_100'].to_list(),
405+
df['N_Y_test_200'].to_list(),
406+
df['N_Y_test_1000'].to_list(),
407+
df['N_Y_test'].to_list(),
408+
df['N_Y_test_100'].to_list(),
409+
df['N_Y_test_200'].to_list(),
410+
df['N_Y_test_1000'].to_list(),
436411
df['N_Y_test'].to_list(),
437412
df['N_Y_test_100'].to_list(),
438413
df['N_Y_test_200'].to_list(),

0 commit comments

Comments
 (0)