
Commit b939639

Adding ESM1v-finetuned+DCA-finetuned hybrid PGym tests
1 parent 36a89d3 commit b939639

File tree

6 files changed (+1529 / -246 lines)


.gitignore

Lines changed: 10 additions & 0 deletions
@@ -402,3 +402,13 @@ scripts/Runtime_tests/runtimes.png
 datasets/AVGFP/Recomb_Double_Split/Predictions_Hybrid_TopRecomb_Double_Split.txt
 scripts/ProteinGym_runs/single_point_mut_performance_violin.png
 scripts/ProteinGym_runs/multi_point_mut_performance_violin.png
+scripts/ESM_finetuning/DMS_msa_files/
+scripts/ESM_finetuning/DMS_ProteinGym_substitutions/
+scripts/ESM_finetuning/ProteinGym_AF2_structures/
+
+scripts/ESM_finetuning/higher_point_dms_mut_data.json
+scripts/ESM_finetuning/single_point_dms_mut_data.json
+scripts/ESM_finetuning/results/dca_esm_and_hybrid_opt_results_clean.csv
+scripts/ESM_finetuning/results/dca_esm_and_hybrid_opt_results.csv
+scripts/ESM_finetuning/mut_performance.png
+scripts/ESM_finetuning/_Description_DMS_substitutions_data.csv
Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
import os
import urllib.request
import zipfile
import json

import pandas as pd

# To use unverified SSL you can add this to your code, taken from:
# https://stackoverflow.com/questions/50236117/scraping-ssl-certificate-verify-failed-error-for-http-en-wikipedia-org
#import ssl
#ssl._create_default_https_context = ssl._create_unverified_context


def download_proteingym_data(version: str = '1.1'):
    """Download the ProteinGym DMS description CSV, the DMS substitution datasets, the MSAs, and the AF2 structures next to this script."""
    url = f'https://marks.hms.harvard.edu/proteingym/ProteinGym_v{version}/DMS_substitutions.csv'
    print(f'Getting {url}...')
    urllib.request.urlretrieve(url, os.path.join(os.path.dirname(__file__), '_Description_DMS_substitutions_data.csv'))

    url = f'https://marks.hms.harvard.edu/proteingym/ProteinGym_v{version}/DMS_ProteinGym_substitutions.zip'
    print(f'Getting {url}...')
    urllib.request.urlretrieve(url, os.path.join(os.path.dirname(__file__), 'DMS_ProteinGym_substitutions.zip'))
    with zipfile.ZipFile(os.path.join(os.path.dirname(__file__), 'DMS_ProteinGym_substitutions.zip'), "r") as zip_ref:
        # The archive contains its own top-level folder, so extract into the script directory
        zip_ref.extractall(os.path.join(os.path.dirname(__file__), 'DMS_ProteinGym_substitutions', '..'))
    os.remove(os.path.join(os.path.dirname(__file__), 'DMS_ProteinGym_substitutions.zip'))

    url = f'https://marks.hms.harvard.edu/proteingym/ProteinGym_v{version}/DMS_msa_files.zip'
    print(f'Getting {url}...')
    urllib.request.urlretrieve(url, os.path.join(os.path.dirname(__file__), 'DMS_msa_files.zip'))
    with zipfile.ZipFile(os.path.join(os.path.dirname(__file__), 'DMS_msa_files.zip'), "r") as zip_ref:
        zip_ref.extractall(os.path.join(os.path.dirname(__file__), 'DMS_msa_files', '..'))
    os.remove(os.path.join(os.path.dirname(__file__), 'DMS_msa_files.zip'))

    url = f'https://marks.hms.harvard.edu/proteingym/ProteinGym_v{version}/ProteinGym_AF2_structures.zip'
    print(f'Getting {url}...')
    urllib.request.urlretrieve(url, os.path.join(os.path.dirname(__file__), 'ProteinGym_AF2_structures.zip'))
    with zipfile.ZipFile(os.path.join(os.path.dirname(__file__), 'ProteinGym_AF2_structures.zip'), "r") as zip_ref:
        zip_ref.extractall(os.path.join(os.path.dirname(__file__), 'ProteinGym_AF2_structures', '..'))
    os.remove(os.path.join(os.path.dirname(__file__), 'ProteinGym_AF2_structures.zip'))


def get_single_or_multi_point_mut_data(csv_description_path, datasets_path=None, msas_path=None, pdbs_path=None, single: bool = True):
    """
    Collect ProteinGym data for either the single- or the multi-point mutant datasets
    (all target datasets that have single- or multi-point mutated variants available).
    Reads the dataset description/overview CSV and searches for the corresponding files
    in the 'DMS_ProteinGym_substitutions' sub-directory.
    """
    type_str = 'single' if single else 'multi'
    file_dirname = os.path.abspath(os.path.dirname(__file__))
    if datasets_path is None:
        datasets_path = os.path.join(file_dirname, 'DMS_ProteinGym_substitutions')
    if msas_path is None:
        msas_path = os.path.join(file_dirname, 'DMS_msa_files')  # used to be DMS_msa_files/MSA_files/DMS
    msas = os.listdir(msas_path)
    if pdbs_path is None:
        pdbs_path = os.path.join(file_dirname, 'ProteinGym_AF2_structures')
    pdbs = os.listdir(pdbs_path)
    description_df = pd.read_csv(csv_description_path, sep=',')
    # Select datasets with (multi) or without (single) multiple-mutant variants
    i_s = []
    for i, n_mp in enumerate(description_df['DMS_number_multiple_mutants'].to_list()):
        if n_mp > 0:
            if not single:
                i_s.append(i)
        else:
            if single:
                i_s.append(i)
    target_description_df = description_df.iloc[i_s, :]
    target_filenames = target_description_df['DMS_filename'].to_list()
    target_wt_seqs = target_description_df['target_seq'].to_list()
    target_msa_starts = target_description_df['MSA_start'].to_list()
    target_msa_ends = target_description_df['MSA_end'].to_list()
    print(f'Searching for CSV files in {datasets_path}...')
    csv_paths = [os.path.join(datasets_path, target_filename) for target_filename in target_filenames]
    print(f'Found {len(csv_paths)} {type_str}-point datasets, will check if all are available in datasets folder...')
    avail_filenames, avail_csvs, avail_wt_seqs = [], [], []
    for i, csv_path in enumerate(csv_paths):
        if not os.path.isfile(csv_path):
            print(f"Did not find CSV file {csv_path} - will remove it from prediction process!")
        else:
            avail_csvs.append(csv_path)
            avail_wt_seqs.append(target_wt_seqs[i])
            avail_filenames.append(os.path.splitext(target_filenames[i])[0])
    assert len(avail_wt_seqs) == len(avail_csvs)
    print(f'Getting data from {len(avail_csvs)} {type_str}-point mutation DMS CSV files...')
    dms_mp_data = {}
    for i, csv_path in enumerate(avail_csvs):
        # Match MSA and PDB files by the first two underscore-separated tokens of the DMS filename
        begin = avail_filenames[i].split('_')[0] + '_' + avail_filenames[i].split('_')[1]
        msa_path, pdb_path = None, None
        for msa in msas:
            if msa.startswith(begin):
                msa_path = os.path.join(msas_path, msa)
        for pdb in pdbs:
            if pdb.startswith(begin):
                pdb_path = os.path.join(pdbs_path, pdb)
        if msa_path is None or pdb_path is None:
            print(f'Did not find an MSA or a PDB beginning with {begin}, continuing...')
            continue
        target_msa_start = target_msa_starts[i]
        target_msa_end = target_msa_ends[i]
        dms_mp_data.update({
            avail_filenames[i]: {
                'CSV_path': csv_path,
                'WT_sequence': avail_wt_seqs[i],
                'MSA_path': msa_path,
                'MSA_start': target_msa_start,
                'MSA_end': target_msa_end,
                'PDB_path': pdb_path
            }
        })
    return dms_mp_data


if __name__ == '__main__':
    download_proteingym_data()
    single_mut_data = get_single_or_multi_point_mut_data(os.path.join(os.path.dirname(__file__), '_Description_DMS_substitutions_data.csv'), single=True)
    higher_mut_data = get_single_or_multi_point_mut_data(os.path.join(os.path.dirname(__file__), '_Description_DMS_substitutions_data.csv'), single=False)
    json_output_file_single = os.path.abspath(os.path.join(os.path.dirname(__file__), "single_point_dms_mut_data.json"))
    json_output_file_higher = os.path.abspath(os.path.join(os.path.dirname(__file__), "higher_point_dms_mut_data.json"))
    with open(json_output_file_single, 'w') as fp:
        json.dump(single_mut_data, fp, indent=4)
    with open(json_output_file_higher, 'w') as fp:
        json.dump(higher_mut_data, fp, indent=4)
    print(f"Saved path data information as JSON files at {json_output_file_single} and {json_output_file_higher}.")
