@@ -265,70 +265,88 @@ def plot_Platinum_pkd_distribution():
265265# FIG 3 - Platinum MODEL RESULTS
266266###############################
267267# - Run all 5 models through platinum and save predicted pkds as platinum_preds/<model_opt>_<fold>.csv
268- # def Platinum_run_inference(model:str):
269- import logging
270- import os
271-
272- import torch
273- import pandas as pd
274- from src .utils .loader import Loader
275- from src import TUNED_MODEL_CONFIGS , cfg
276- from collections import defaultdict
277- from tqdm import tqdm
278- logging .getLogger ().setLevel (logging .INFO )
279-
280- DEVICE = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
281- MODEL , model_kwargs = Loader .load_tuned_model ('davis_esm' , fold = 0 , device = DEVICE )
282-
283-
284- model_opts = ['davis_DG' , 'davis_gvpl' , 'davis_esm' ,
268+ def Platinum_run_inference ():
269+ """
270+ This script runs inference on platinum for the following models:
271+ ['davis_DG', 'davis_gvpl', 'davis_esm',
285272 'kiba_DG', 'kiba_esm', 'kiba_gvpl',
286273 'PDBbind_DG', 'PDBbind_esm', 'PDBbind_gvpl',
287274 'PDBbind_gvpl_aflow']
288- for model_opt in model_opts :
289- loader = None
290- for fold in range (5 ):
291- print (f"{ model_opt } -{ fold } " )
292- out_csv = f"./results/platinum_predictions/{ model_opt } _{ fold } .csv"
293- if os .path .exists (out_csv ):
294- print ('\t Predictions already exists' )
295- continue
296-
297- MODEL_PARAMS = TUNED_MODEL_CONFIGS [model_opt ]
298- try :
299- MODEL , model_kwargs = Loader .load_tuned_model (model_opt , fold = fold , device = DEVICE )
300- except AssertionError as e :
301- print (e )
302- continue
303- MODEL .eval ()
304- print ("\t Model loaded" )
305-
306- if loader is None : # caches loader if already created for this model_opt
307- loader = Loader .load_DataLoaders (
308- data = cfg .DATA_OPT .platinum ,
309- datasets = ['full' ],
310- pro_feature = MODEL_PARAMS ['feature_opt' ],
311- edge_opt = MODEL_PARAMS ['edge_opt' ],
312- ligand_feature = MODEL_PARAMS ['lig_feat_opt' ],
313- ligand_edge = MODEL_PARAMS ['lig_edge_opt' ],
314- )['full' ]
315- print ("\t Dataset loaded" )
316-
317-
318- PREDICTIONS = defaultdict (list )
319- for batch in tqdm (loader , desc = "\t running inference" ):
320- PREDICTIONS ['code' ].extend (batch ['code' ])
321- PREDICTIONS ['y' ].extend (batch ['y' ].tolist ())
322- y_pred = MODEL (batch ['protein' ].to (DEVICE ), batch ['ligand' ].to (DEVICE ))
323- PREDICTIONS ['y_pred' ].extend (y_pred [:,0 ].tolist ())
324-
325-
326-
327- df = pd .DataFrame .from_dict (PREDICTIONS )
328- df .set_index ('code' , inplace = True )
329- df .sort_index (key = lambda x : x .str .split ("_" ).str [0 ].astype (int ), inplace = True )
330- df .to_csv (out_csv )
275+
276+ It assumes that checkpoints for these models are already present in the CHECKPOINT_SAVE_DIR location
277+ as specified by the src/utils/config file.
278+
279+ Also for any aflow model it also requires that aflow predictions for structures have already been generated
280+ otherwise building that dataset will be impossible.
281+
282+ If the dataset is already built for that model then no building of the dataset will be done so long as they
283+ are in the right location as specified by DATA_ROOT in the src/utils/config file.
284+ """
285+ import logging
286+ import os
331287
332- print ("DONE!" )
288+ import torch
289+ import pandas as pd
290+ from src .utils .loader import Loader
291+ from src import TUNED_MODEL_CONFIGS , cfg
292+ from collections import defaultdict
293+ from tqdm import tqdm
294+ logging .getLogger ().setLevel (logging .INFO )
295+
296+ DEVICE = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
297+ MODEL , model_kwargs = Loader .load_tuned_model ('davis_esm' , fold = 0 , device = DEVICE )
298+
299+
300+ model_opts = ['davis_DG' , 'davis_gvpl' , 'davis_esm' ,
301+ 'kiba_DG' , 'kiba_esm' , 'kiba_gvpl' ,
302+ 'PDBbind_DG' , 'PDBbind_esm' , 'PDBbind_gvpl' ,
303+ 'PDBbind_gvpl_aflow' ]
304+ for model_opt in model_opts :
305+ loader = None
306+ for fold in range (5 ):
307+ print (f"{ model_opt } -{ fold } " )
308+ out_csv = f"./results/platinum_predictions/{ model_opt } _{ fold } .csv"
309+ if os .path .exists (out_csv ):
310+ print ('\t Predictions already exists' )
311+ continue
312+
313+ MODEL_PARAMS = TUNED_MODEL_CONFIGS [model_opt ]
314+ try :
315+ MODEL , model_kwargs = Loader .load_tuned_model (model_opt , fold = fold , device = DEVICE )
316+ except AssertionError as e :
317+ print (e )
318+ continue
319+ MODEL .eval ()
320+ print ("\t Model loaded" )
321+
322+ if loader is None : # caches loader if already created for this model_opt
323+ loader = Loader .load_DataLoaders (
324+ data = cfg .DATA_OPT .platinum ,
325+ datasets = ['full' ],
326+ pro_feature = MODEL_PARAMS ['feature_opt' ],
327+ edge_opt = MODEL_PARAMS ['edge_opt' ],
328+ ligand_feature = MODEL_PARAMS ['lig_feat_opt' ],
329+ ligand_edge = MODEL_PARAMS ['lig_edge_opt' ],
330+ )['full' ]
331+ print ("\t Dataset loaded" )
332+
333+
334+ PREDICTIONS = defaultdict (list )
335+ for batch in tqdm (loader , desc = "\t running inference" , ncols = 100 ):
336+ PREDICTIONS ['code' ].extend (batch ['code' ])
337+ PREDICTIONS ['y' ].extend (batch ['y' ].tolist ())
338+ y_pred = MODEL (batch ['protein' ].to (DEVICE ), batch ['ligand' ].to (DEVICE ))
339+ PREDICTIONS ['y_pred' ].extend (y_pred [:,0 ].tolist ())
340+
341+
342+
343+ df = pd .DataFrame .from_dict (PREDICTIONS )
344+ df .set_index ('code' , inplace = True )
345+ df .sort_index (key = lambda x : x .str .split ("_" ).str [0 ].astype (int ), inplace = True )
346+ df .to_csv (out_csv )
347+
348+ print ("DONE!" )
349+
350+
333351def platinum_model_results_raw ():
334352 pass
0 commit comments