1- import importlib
21from pathlib import Path
32
43import click
54import yaml
65
7- from chebifier .prediction_models . base_predictor import BasePredictor
6+ from chebifier .model_registry import ENSEMBLES , MODEL_TYPES
87
8+ from .check_env import check_package_installed , get_current_environment
99from .hugging_face import download_model_files
10- from .setup_env import SetupEnvAndPackage
1110
# Load the API registry once at import time. The CLI option declarations
# further down read its keys, so a missing or malformed file must fail here
# with a clear message rather than later with an unrelated error.
# NOTE(review): the path is relative, so it is resolved against the current
# working directory, not against this module's location -- confirm intended.
yaml_path = Path("api/api_registry.yml")
try:
    with yaml_path.open("r") as f:
        api_registry = yaml.safe_load(f)
except FileNotFoundError:
    # Keep the original, shorter error message for a missing registry file.
    raise FileNotFoundError(f"{yaml_path} not found.") from None

# yaml.safe_load returns None for an empty document and may return a list or
# scalar for other content; the code below indexes the registry by model name
# and calls .keys() on it, so anything but a mapping is unusable.
if not isinstance(api_registry, dict):
    raise ValueError(
        f"{yaml_path} must contain a YAML mapping of model names to "
        f"configurations, got {type(api_registry).__name__}."
    )
1817
@@ -40,7 +39,7 @@ def cli():
4039@click .option (
4140 "--model-type" ,
4241 "-m" ,
43- type = click .Choice (model_registry .keys ()),
42+ type = click .Choice (api_registry .keys ()),
4443 default = "mv" ,
4544 help = "Type of model to use" ,
4645)
@@ -60,29 +59,39 @@ def predict(smiles, smiles_file, output, model_type):
6059 click .echo ("No SMILES strings provided. Use --smiles or --smiles-file options." )
6160 return
6261
63- model_config = model_registry [model_type ]
64- predictor_kwargs = {"model_name" : model_type }
65-
66- current_dir = Path (__file__ ).resolve ().parent
67-
68- if "hugging_face" in model_config :
69- print (f"For model type `{ model_type } ` following files are used:" )
70- local_file_path = download_model_files (model_config ["hugging_face" ])
71- predictor_kwargs ["ckpt_path" ] = local_file_path ["ckpt" ]
72- predictor_kwargs ["target_labels_path" ] = local_file_path ["labels" ]
73-
74- SetupEnvAndPackage ().setup (
75- repo_url = model_config ["repo_url" ],
76- clone_dir = current_dir / ".cloned_repos" ,
77- venv_dir = current_dir ,
78- )
79-
80- model_cls_path = model_config ["wrapper" ]
81- module_path , class_name = model_cls_path .rsplit ("." , 1 )
82- module = importlib .import_module (module_path )
83- model_cls : type = getattr (module , class_name )
84- model_instance = model_cls (** predictor_kwargs )
85- assert isinstance (model_instance , BasePredictor )
62+ print ("Current working environment is:" , get_current_environment ())
63+
def get_individual_model(model_config):
    """Build the predictor keyword arguments for one registry entry.

    When the entry has a ``"hugging_face"`` key, its value is handed to
    ``download_model_files`` and the result is used as the kwargs;
    otherwise the kwargs start out empty. The entry's ``"package_name"``
    is passed to ``check_package_installed`` before returning.

    NOTE(review): nesting was reconstructed from flattened text; the
    package check is taken to run unconditionally -- confirm.
    """
    has_hf_files = "hugging_face" in model_config
    kwargs = (
        download_model_files(model_config["hugging_face"]) if has_hf_files else {}
    )
    check_package_installed(model_config["package_name"])
    return kwargs
70+
# Resolve `model_type` to a predictor instance: either a single registered
# model, or an ensemble assembled from registered component models.
if model_type in MODEL_TYPES:
    print(f"Predictor for Single/Individual Model: {model_type}")
    model_config = api_registry[model_type]
    predictor_kwargs = get_individual_model(model_config)
    predictor_kwargs["model_name"] = model_type
    model_instance = MODEL_TYPES[model_type](**predictor_kwargs)

elif model_type in ENSEMBLES:
    print(f"Predictor for Ensemble Model: {model_type}")
    ensemble_config = {}
    for i, en_comp in enumerate(api_registry[model_type]["ensemble_of"]):
        # Validate with an explicit raise: an `assert` is stripped when
        # Python runs with -O, which would let a bad registry entry through.
        if en_comp not in MODEL_TYPES:
            raise ValueError(
                f"Ensemble {model_type!r} lists component {en_comp!r}, "
                f"which is not a registered model type "
                f"(known: {list(MODEL_TYPES)})."
            )
        print(f"For ensemble component {en_comp}")
        predictor_kwargs = get_individual_model(api_registry[en_comp])
        # Components are keyed model_1, model_2, ... in registry order.
        model_key = f"model_{i + 1}"
        ensemble_config[model_key] = {
            "type": en_comp,
            "model_name": f"{en_comp}_{model_key}",
            **predictor_kwargs,
        }
    model_instance = ENSEMBLES[model_type](ensemble_config)

else:
    # The original raised ValueError with an empty message, which gave the
    # user nothing to act on.
    raise ValueError(
        f"Unknown model type {model_type!r}: it is neither a registered "
        f"model type ({list(MODEL_TYPES)}) nor a registered ensemble "
        f"({list(ENSEMBLES)})."
    )
8695
8796 # Make predictions
8897 predictions = model_instance .predict_smiles_list (smiles_list )
0 commit comments