Skip to content

Commit fd814e9

Browse files
committed
api support for ensemble
1 parent e6602ef commit fd814e9

File tree

9 files changed

+130
-242
lines changed

9 files changed

+130
-242
lines changed

api/api_registry.yml

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
electra:
2+
hugging_face:
3+
repo_id: aditya0by0/python-chebifier
4+
subfolder: electra
5+
files:
6+
ckpt: electra.ckpt
7+
labels: classes.txt
8+
package_name: chebai
9+
10+
resgated:
11+
hugging_face:
12+
repo_id: aditya0by0/python-chebifier
13+
subfolder: resgated
14+
files:
15+
ckpt: resgated.ckpt
16+
labels: classes.txt
17+
package_name: chebai-graph
18+
19+
chemlog:
20+
package_name: chemlog
21+
22+
23+
en_mv:
24+
ensemble_of: {electra, chemlog}

api/check_env.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import subprocess
2+
import sys
3+
4+
5+
def get_current_environment() -> str:
6+
"""
7+
Return the path of the Python executable for the current environment.
8+
"""
9+
return sys.executable
10+
11+
12+
def check_package_installed(package_name: str) -> None:
13+
"""
14+
Check if the given package is installed in the current Python environment.
15+
"""
16+
python_exec = get_current_environment()
17+
try:
18+
subprocess.check_output(
19+
[python_exec, "-m", "pip", "show", package_name], stderr=subprocess.DEVNULL
20+
)
21+
print(f"✅ Package '{package_name}' is already installed.")
22+
except subprocess.CalledProcessError:
23+
raise (
24+
f"❌ Please install '{package_name}' into your environment: {python_exec}"
25+
)
26+
27+
28+
if __name__ == "__main__":
29+
print(f"🔍 Using Python executable: {get_current_environment()}")
30+
check_package_installed("numpy") # Replace with your desired package

api/cli.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
1-
import importlib
21
from pathlib import Path
32

43
import click
54
import yaml
65

7-
from chebifier.prediction_models.base_predictor import BasePredictor
6+
from chebifier.model_registry import ENSEMBLES, MODEL_TYPES
87

8+
from .check_env import check_package_installed, get_current_environment
99
from .hugging_face import download_model_files
10-
from .setup_env import SetupEnvAndPackage
1110

12-
yaml_path = Path("api/registry.yml")
11+
yaml_path = Path("api/api_registry.yml")
1312
if yaml_path.exists():
1413
with yaml_path.open("r") as f:
15-
model_registry = yaml.safe_load(f)
14+
api_registry = yaml.safe_load(f)
1615
else:
1716
raise FileNotFoundError(f"{yaml_path} not found.")
1817

@@ -40,7 +39,7 @@ def cli():
4039
@click.option(
4140
"--model-type",
4241
"-m",
43-
type=click.Choice(model_registry.keys()),
42+
type=click.Choice(api_registry.keys()),
4443
default="mv",
4544
help="Type of model to use",
4645
)
@@ -60,29 +59,39 @@ def predict(smiles, smiles_file, output, model_type):
6059
click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.")
6160
return
6261

63-
model_config = model_registry[model_type]
64-
predictor_kwargs = {"model_name": model_type}
65-
66-
current_dir = Path(__file__).resolve().parent
67-
68-
if "hugging_face" in model_config:
69-
print(f"For model type `{model_type}` following files are used:")
70-
local_file_path = download_model_files(model_config["hugging_face"])
71-
predictor_kwargs["ckpt_path"] = local_file_path["ckpt"]
72-
predictor_kwargs["target_labels_path"] = local_file_path["labels"]
73-
74-
SetupEnvAndPackage().setup(
75-
repo_url=model_config["repo_url"],
76-
clone_dir=current_dir / ".cloned_repos",
77-
venv_dir=current_dir,
78-
)
79-
80-
model_cls_path = model_config["wrapper"]
81-
module_path, class_name = model_cls_path.rsplit(".", 1)
82-
module = importlib.import_module(module_path)
83-
model_cls: type = getattr(module, class_name)
84-
model_instance = model_cls(**predictor_kwargs)
85-
assert isinstance(model_instance, BasePredictor)
62+
print("Current working environment is:", get_current_environment())
63+
64+
def get_individual_model(model_config):
65+
predictor_kwargs = {}
66+
if "hugging_face" in model_config:
67+
predictor_kwargs = download_model_files(model_config["hugging_face"])
68+
check_package_installed(model_config["package_name"])
69+
return predictor_kwargs
70+
71+
if model_type in MODEL_TYPES:
72+
print(f"Predictor for Single/Individual Model: {model_type}")
73+
model_config = api_registry[model_type]
74+
predictor_kwargs = get_individual_model(model_config)
75+
predictor_kwargs["model_name"] = model_type
76+
model_instance = MODEL_TYPES[model_type](**predictor_kwargs)
77+
78+
elif model_type in ENSEMBLES:
79+
print(f"Predictor for Ensemble Model: {model_type}")
80+
ensemble_config = {}
81+
for i, en_comp in enumerate(api_registry[model_type]["ensemble_of"]):
82+
assert en_comp in MODEL_TYPES
83+
print(f"For ensemble component {en_comp}")
84+
predictor_kwargs = get_individual_model(api_registry[en_comp])
85+
model_key = f"model_{i + 1}"
86+
ensemble_config[model_key] = {
87+
"type": en_comp,
88+
"model_name": f"{en_comp}_{model_key}",
89+
**predictor_kwargs,
90+
}
91+
model_instance = ENSEMBLES[model_type](ensemble_config)
92+
93+
else:
94+
raise ValueError("")
8695

8796
# Make predictions
8897
predictions = model_instance.predict_smiles_list(smiles_list)

api/hugging_face.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,7 @@ def download_model_files(
4545
local_paths[file_type] = Path(downloaded_file_path)
4646
print(f"\t Using file `{filename}` from: {downloaded_file_path}")
4747

48-
return local_paths
48+
return {
49+
"ckpt_path": local_paths["ckpt"],
50+
"target_labels_path": local_paths["labels"],
51+
}

api/registry.yml

Lines changed: 0 additions & 23 deletions
This file was deleted.

api/setup_env.py

Lines changed: 0 additions & 165 deletions
This file was deleted.

chebifier/cli.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
import click
22
import yaml
33

4-
from chebifier.ensemble.base_ensemble import BaseEnsemble
5-
from chebifier.ensemble.weighted_majority_ensemble import (
6-
WMVwithF1Ensemble,
7-
WMVwithPPVNPVEnsemble,
8-
)
4+
from .model_registry import ENSEMBLES
95

106

117
@click.group()
@@ -14,13 +10,6 @@ def cli():
1410
pass
1511

1612

17-
ENSEMBLES = {
18-
"mv": BaseEnsemble,
19-
"wmv-ppvnpv": WMVwithPPVNPVEnsemble,
20-
"wmv-f1": WMVwithF1Ensemble,
21-
}
22-
23-
2413
@cli.command()
2514
@click.argument("config_file", type=click.Path(exists=True))
2615
@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict")

0 commit comments

Comments
 (0)