Skip to content

Commit 4c5c4db

Browse files
committed
feat: Platinum_run_inference
Runs inference on the Platinum dataset with each tuned model and saves the predicted pkd values as one CSV per model/fold.
1 parent bd62ef6 commit 4c5c4db

File tree

1 file changed

+79
-61
lines changed

1 file changed

+79
-61
lines changed

paper_figures.py

Lines changed: 79 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -265,70 +265,88 @@ def plot_Platinum_pkd_distribution():
265265
# FIG 3 - Platinum MODEL RESULTS
266266
###############################
267267
# - Run all 5 models through platinum and save predicted pkds as platinum_preds/<model_opt>_<fold>.csv
268-
# def Platinum_run_inference(model:str):
269-
import logging
270-
import os
271-
272-
import torch
273-
import pandas as pd
274-
from src.utils.loader import Loader
275-
from src import TUNED_MODEL_CONFIGS, cfg
276-
from collections import defaultdict
277-
from tqdm import tqdm
278-
logging.getLogger().setLevel(logging.INFO)
279-
280-
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
281-
MODEL, model_kwargs = Loader.load_tuned_model('davis_esm', fold=0, device=DEVICE)
282-
283-
284-
model_opts = ['davis_DG', 'davis_gvpl', 'davis_esm',
268+
def Platinum_run_inference():
269+
"""
270+
This script runs inference on platinum for the following models:
271+
['davis_DG', 'davis_gvpl', 'davis_esm',
285272
'kiba_DG', 'kiba_esm', 'kiba_gvpl',
286273
'PDBbind_DG', 'PDBbind_esm', 'PDBbind_gvpl',
287274
'PDBbind_gvpl_aflow']
288-
for model_opt in model_opts:
289-
loader = None
290-
for fold in range(5):
291-
print(f"{model_opt}-{fold}")
292-
out_csv = f"./results/platinum_predictions/{model_opt}_{fold}.csv"
293-
if os.path.exists(out_csv):
294-
print('\t Predictions already exists')
295-
continue
296-
297-
MODEL_PARAMS = TUNED_MODEL_CONFIGS[model_opt]
298-
try:
299-
MODEL, model_kwargs = Loader.load_tuned_model(model_opt, fold=fold, device=DEVICE)
300-
except AssertionError as e:
301-
print(e)
302-
continue
303-
MODEL.eval()
304-
print("\t Model loaded")
305-
306-
if loader is None: # caches loader if already created for this model_opt
307-
loader = Loader.load_DataLoaders(
308-
data=cfg.DATA_OPT.platinum,
309-
datasets=['full'],
310-
pro_feature=MODEL_PARAMS['feature_opt'],
311-
edge_opt=MODEL_PARAMS['edge_opt'],
312-
ligand_feature=MODEL_PARAMS['lig_feat_opt'],
313-
ligand_edge=MODEL_PARAMS['lig_edge_opt'],
314-
)['full']
315-
print("\t Dataset loaded")
316-
317-
318-
PREDICTIONS = defaultdict(list)
319-
for batch in tqdm(loader, desc="\t running inference"):
320-
PREDICTIONS['code'].extend(batch['code'])
321-
PREDICTIONS['y'].extend(batch['y'].tolist())
322-
y_pred = MODEL(batch['protein'].to(DEVICE), batch['ligand'].to(DEVICE))
323-
PREDICTIONS['y_pred'].extend(y_pred[:,0].tolist())
324-
325-
326-
327-
df = pd.DataFrame.from_dict(PREDICTIONS)
328-
df.set_index('code', inplace=True)
329-
df.sort_index(key = lambda x: x.str.split("_").str[0].astype(int), inplace=True)
330-
df.to_csv(out_csv)
275+
276+
It assumes that checkpoints for these models are already present in the CHECKPOINT_SAVE_DIR location
277+
as specified by the src/utils/config file.
278+
279+
Also for any aflow model it also requires that aflow predictions for structures have already been generated
280+
otherwise building that dataset will be impossible.
281+
282+
If the dataset is already built for that model then no building of the dataset will be done so long as they
283+
are in the right location as specified by DATA_ROOT in the src/utils/config file.
284+
"""
285+
import logging
286+
import os
331287

332-
print("DONE!")
288+
import torch
289+
import pandas as pd
290+
from src.utils.loader import Loader
291+
from src import TUNED_MODEL_CONFIGS, cfg
292+
from collections import defaultdict
293+
from tqdm import tqdm
294+
logging.getLogger().setLevel(logging.INFO)
295+
296+
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
297+
MODEL, model_kwargs = Loader.load_tuned_model('davis_esm', fold=0, device=DEVICE)
298+
299+
300+
model_opts = ['davis_DG', 'davis_gvpl', 'davis_esm',
301+
'kiba_DG', 'kiba_esm', 'kiba_gvpl',
302+
'PDBbind_DG', 'PDBbind_esm', 'PDBbind_gvpl',
303+
'PDBbind_gvpl_aflow']
304+
for model_opt in model_opts:
305+
loader = None
306+
for fold in range(5):
307+
print(f"{model_opt}-{fold}")
308+
out_csv = f"./results/platinum_predictions/{model_opt}_{fold}.csv"
309+
if os.path.exists(out_csv):
310+
print('\t Predictions already exists')
311+
continue
312+
313+
MODEL_PARAMS = TUNED_MODEL_CONFIGS[model_opt]
314+
try:
315+
MODEL, model_kwargs = Loader.load_tuned_model(model_opt, fold=fold, device=DEVICE)
316+
except AssertionError as e:
317+
print(e)
318+
continue
319+
MODEL.eval()
320+
print("\t Model loaded")
321+
322+
if loader is None: # caches loader if already created for this model_opt
323+
loader = Loader.load_DataLoaders(
324+
data=cfg.DATA_OPT.platinum,
325+
datasets=['full'],
326+
pro_feature=MODEL_PARAMS['feature_opt'],
327+
edge_opt=MODEL_PARAMS['edge_opt'],
328+
ligand_feature=MODEL_PARAMS['lig_feat_opt'],
329+
ligand_edge=MODEL_PARAMS['lig_edge_opt'],
330+
)['full']
331+
print("\t Dataset loaded")
332+
333+
334+
PREDICTIONS = defaultdict(list)
335+
for batch in tqdm(loader, desc="\t running inference", ncols=100):
336+
PREDICTIONS['code'].extend(batch['code'])
337+
PREDICTIONS['y'].extend(batch['y'].tolist())
338+
y_pred = MODEL(batch['protein'].to(DEVICE), batch['ligand'].to(DEVICE))
339+
PREDICTIONS['y_pred'].extend(y_pred[:,0].tolist())
340+
341+
342+
343+
df = pd.DataFrame.from_dict(PREDICTIONS)
344+
df.set_index('code', inplace=True)
345+
df.sort_index(key = lambda x: x.str.split("_").str[0].astype(int), inplace=True)
346+
df.to_csv(out_csv)
347+
348+
print("DONE!")
349+
350+
333351
def platinum_model_results_raw():
    """Placeholder: report raw Platinum model results (not yet implemented)."""
    pass

0 commit comments

Comments
 (0)