Skip to content

Commit 7aabd79

Browse files
committed
Skip dataset if RuntimeErrors occurred for both LLMs
RuntimeError: CUDA error: device-side assert triggered. CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1. Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
1 parent 87a0c9e commit 7aabd79

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

scripts/ProteinGym_runs/run_performance_tests_proteingym_hybrid_dca_llm.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
6565
plt.figure(figsize=(40, 12))
6666
numbers_of_datasets = [i + 1 for i in range(len(mut_data.keys()))]
6767
for i, (dset_key, dset_paths) in enumerate(mut_data.items()):
68-
if i >= start_i and i not in already_tested_is: # i > 3 and i <21: #i == 18 - 1:
68+
if i >= start_i and i not in already_tested_is and i != 19: # i > 3 and i <21: #i == 18 - 1:
69+
# Skipping dataset 20 (zero-based index 19), BRCA1_HUMAN_Findlay_2018, due to LLM RuntimeErrors
6970
start_time = time.time()
7071
print(f'\n{i+1}/{len(mut_data.items())}\n'
7172
f'===============================================================')
@@ -131,21 +132,27 @@ def compute_performances(mut_data, mut_sep=':', start_i: int = 0, already_tested
131132
print('DCA:', spearmanr(fitnesses, y_pred_dca), len(fitnesses))
132133
dca_unopt_perf = spearmanr(fitnesses, y_pred_dca)[0]
133134

134-
x_esm, esm_attention_mask = esm_tokenize_sequences(sequences, esm_tokenizer, max_length=len(wt_seq))
135-
y_esm = esm_infer(get_batches(x_esm, dtype=float, batch_size=1), esm_attention_mask, esm_base_model)
136-
print('ESM1v:', spearmanr(fitnesses, y_esm.cpu()))
135+
try:
136+
x_esm, esm_attention_mask = esm_tokenize_sequences(sequences, esm_tokenizer, max_length=len(wt_seq))
137+
y_esm = esm_infer(get_batches(x_esm, dtype=float, batch_size=1), esm_attention_mask, esm_base_model)
138+
print('ESM1v:', spearmanr(fitnesses, y_esm.cpu()))
139+
esm_unopt_perf = spearmanr(fitnesses, y_esm.cpu())[0]
140+
except RuntimeError:
141+
esm_unopt_perf = np.nan
137142

138-
input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(pdb, prosst_tokenizer, wt_seq)
139-
x_prosst = prosst_tokenize_sequences(sequences=sequences, vocab=prosst_vocab)
140143
try:
144+
input_ids, prosst_attention_mask, structure_input_ids = get_structure_quantizied(pdb, prosst_tokenizer, wt_seq)
145+
x_prosst = prosst_tokenize_sequences(sequences=sequences, vocab=prosst_vocab)
141146
y_prosst = get_logits_from_full_seqs(
142-
x_prosst, prosst_base_model, input_ids, prosst_attention_mask, structure_input_ids, train=False)
147+
x_prosst, prosst_base_model, input_ids, prosst_attention_mask, structure_input_ids, train=False)
143148
print('ProSST:', spearmanr(fitnesses, y_prosst.cpu()))
144149
prosst_unopt_perf = spearmanr(fitnesses, y_prosst.cpu())[0]
145150
except RuntimeError:
146151
prosst_unopt_perf = np.nan
147152

148-
esm_unopt_perf = spearmanr(fitnesses, y_esm.cpu())[0]
153+
if np.isnan(esm_unopt_perf) and np.isnan(prosst_unopt_perf):
154+
print('Both LLM\'s had RunTimeErrors, skipping dataset...')
155+
continue
149156

150157
ns_y_test = [len(variants)]
151158
for i_t, train_size in enumerate([100, 200, 1000]):
@@ -431,7 +438,7 @@ def plot_csv_data(csv, plot_name):
431438
r'$\overline{|\rho|}=$' + f'{np.nanmean(dset_hybrid_perfs_dca_prosst_1000):.2f}'
432439
][n]
433440
)
434-
plt.text( # N_Y_test,N_Y_test_100,N_Y_test_200,N_Y_test_1000
441+
plt.text(
435442
n + 0.15, -0.05,
436443
r'$\overline{N_{Y_\mathrm{test}}}=$' + f'{int(np.nanmean(np.array(dset_ns_y_test)[n]))}'
437444
)
@@ -496,11 +503,11 @@ def plot_csv_data(csv, plot_name):
496503
already_tested_is = []
497504

498505

499-
#compute_performances(
500-
# mut_data=combined_mut_data,
501-
# start_i=start_i,
502-
# already_tested_is=already_tested_is
503-
#)
506+
compute_performances(
507+
mut_data=combined_mut_data,
508+
start_i=start_i,
509+
already_tested_is=already_tested_is
510+
)
504511

505512

506513
with open(out_results_csv, 'r') as fh:

0 commit comments

Comments
 (0)