Skip to content

Commit 584b6a6

Browse files
committed
api cli
1 parent 481a2eb commit 584b6a6

File tree

2 files changed

+124
-0
lines changed

2 files changed

+124
-0
lines changed

api/__main__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from .cli import cli
2+
3+
if __name__ == "__main__":
4+
"""
5+
Entry point for the CLI application.
6+
7+
This script calls the `cli` function from the `api.cli` module
8+
when executed as the main program.
9+
"""
10+
cli()

api/cli.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import importlib
2+
from pathlib import Path
3+
4+
import click
5+
import yaml
6+
7+
from chebifier.prediction_models.base_predictor import BasePredictor
8+
9+
from .hugging_face import download_model_files
10+
from .setup_env import SetupEnvAndPackage
11+
12+
yaml_path = Path("api/registry.yml")
13+
if yaml_path.exists():
14+
with yaml_path.open("r") as f:
15+
model_registry = yaml.safe_load(f)
16+
else:
17+
raise FileNotFoundError(f"{yaml_path} not found.")
18+
19+
20+
@click.group()
21+
def cli():
22+
"""Command line interface for Api-Chebifier."""
23+
pass
24+
25+
26+
@cli.command()
27+
@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict")
28+
@click.option(
29+
"--smiles-file",
30+
"-f",
31+
type=click.Path(exists=True),
32+
help="File containing SMILES strings (one per line)",
33+
)
34+
@click.option(
35+
"--output",
36+
"-o",
37+
type=click.Path(),
38+
help="Output file to save predictions (optional)",
39+
)
40+
@click.option(
41+
"--model-type",
42+
"-m",
43+
type=click.Choice(model_registry.keys()),
44+
default="mv",
45+
help="Type of model to use",
46+
)
47+
def predict(smiles, smiles_file, output, model_type):
48+
"""Predict ChEBI classes for SMILES strings using an ensemble model.
49+
50+
CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
51+
"""
52+
53+
# Collect SMILES strings from arguments and/or file
54+
smiles_list = list(smiles)
55+
if smiles_file:
56+
with open(smiles_file, "r") as f:
57+
smiles_list.extend([line.strip() for line in f if line.strip()])
58+
59+
if not smiles_list:
60+
click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.")
61+
return
62+
63+
model_config = model_registry[model_type]
64+
predictor_kwargs = {"model_name": model_type}
65+
66+
current_dir = Path(__file__).resolve().parent
67+
68+
if "hugging_face" in model_config:
69+
local_file_path = download_model_files(
70+
model_config["hugging_face"],
71+
current_dir / ".api_models" / model_type,
72+
)
73+
predictor_kwargs["ckpt_path"] = local_file_path["ckpt"]
74+
predictor_kwargs["target_labels_path"] = local_file_path["labels"]
75+
76+
SetupEnvAndPackage().setup(
77+
repo_url=model_config["repo_url"],
78+
clone_dir=current_dir / ".cloned_repos",
79+
venv_dir=current_dir,
80+
)
81+
82+
model_cls_path = model_config["wrapper"]
83+
module_path, class_name = model_cls_path.rsplit(".", 1)
84+
module = importlib.import_module(module_path)
85+
model_cls: type = getattr(module, class_name)
86+
model_instance = model_cls(**predictor_kwargs)
87+
assert isinstance(model_instance, BasePredictor)
88+
89+
# Make predictions
90+
predictions = model_instance.predict_smiles_list(smiles_list)
91+
92+
if output:
93+
# save as json
94+
import json
95+
96+
with open(output, "w") as f:
97+
json.dump(
98+
{smiles: pred for smiles, pred in zip(smiles_list, predictions)},
99+
f,
100+
indent=2,
101+
)
102+
103+
else:
104+
# Print results
105+
for i, (smiles, prediction) in enumerate(zip(smiles_list, predictions)):
106+
click.echo(f"Result for: {smiles}")
107+
if prediction:
108+
click.echo(f" Predicted classes: {', '.join(map(str, prediction))}")
109+
else:
110+
click.echo(" No predictions")
111+
112+
113+
if __name__ == "__main__":
114+
cli()

0 commit comments

Comments
 (0)