Skip to content

Commit 87fb66a

Browse files
committed
move files from api to chebifier, add files to huggingface
1 parent e4f1c54 commit 87fb66a

File tree

10 files changed

+130
-189
lines changed

10 files changed

+130
-189
lines changed

api/__init__.py

Whitespace-only changes.

api/__main__.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

api/api_registry.yml

Lines changed: 0 additions & 24 deletions
This file was deleted.

api/cli.py

Lines changed: 0 additions & 121 deletions
This file was deleted.
File renamed without changes.

chebifier/cli.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import importlib.resources
12
import os
23

34
import click
@@ -14,9 +15,10 @@ def cli():
1415

1516
@cli.command()
1617
@click.option(
17-
"--config_file",
18+
"--ensemble-config",
19+
"-e",
1820
type=click.Path(exists=True),
19-
default=os.path.join("configs", "huggingface_config.yml"),
21+
default=None,
2022
help="Configuration file for ensemble models",
2123
)
2224
@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict")
@@ -34,10 +36,10 @@ def cli():
3436
)
3537
@click.option(
3638
"--ensemble-type",
37-
"-e",
39+
"-t",
3840
type=click.Choice(ENSEMBLES.keys()),
39-
default="mv",
40-
help="Type of ensemble to use (default: Majority Voting)",
41+
default="wmv-f1",
42+
help="Type of ensemble to use (default: Weighted Majority Voting)",
4143
)
4244
@click.option(
4345
"--chebi-version",
@@ -53,25 +55,53 @@ def cli():
5355
default=True,
5456
help="Weight predictions based on how 'confident' a model is in its prediction (default: True)",
5557
)
58+
@click.option(
59+
"--resolve-inconsistencies",
60+
"-r",
61+
is_flag=True,
62+
default=True,
63+
help="Resolve inconsistencies in predictions automatically (default: True)",
64+
)
5665
def predict(
57-
config_file,
66+
ensemble_config,
5867
smiles,
5968
smiles_file,
6069
output,
6170
ensemble_type,
6271
chebi_version,
6372
use_confidence,
73+
resolve_inconsistencies=True,
6474
):
6575
"""Predict ChEBI classes for SMILES strings using an ensemble model.
66-
67-
CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
68-
"""
76+
"""
6977
# Load configuration from YAML file
70-
with open(config_file, "r") as f:
71-
config = yaml.safe_load(f)
78+
if not ensemble_config:
79+
print(f"Using default ensemble configuration")
80+
with importlib.resources.files("chebifier").joinpath("ensemble.yml").open("r") as f:
81+
config = yaml.safe_load(f)
82+
else:
83+
print(f"Loading ensemble configuration from {ensemble_config}")
84+
with open(ensemble_config, "r") as f:
85+
config = yaml.safe_load(f)
86+
87+
with importlib.resources.files("chebifier").joinpath("model_registry.yml").open("r") as f:
88+
model_registry = yaml.safe_load(f)
89+
90+
new_config = {}
91+
for model_name, entry in config.items():
92+
if "load_model" in entry:
93+
if entry["load_model"] not in model_registry:
94+
raise ValueError(
95+
f"Model {entry['load_model']} not found in model registry. "
96+
f"Available models are: {','.join(model_registry.keys())}."
97+
)
98+
new_config[model_name] = {**model_registry[entry["load_model"]], **entry}
99+
else:
100+
new_config[model_name] = entry
101+
config = new_config
72102

73103
# Instantiate ensemble model
74-
ensemble = ENSEMBLES[ensemble_type](config, chebi_version=chebi_version)
104+
ensemble = ENSEMBLES[ensemble_type](config, chebi_version=chebi_version, resolve_inconsistencies=resolve_inconsistencies)
75105

76106
# Collect SMILES strings from arguments and/or file
77107
smiles_list = list(smiles)

chebifier/ensemble.yml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
electra:
2+
load_model: electra_chebi50_v241
3+
resgated:
4+
load_model: resgated_chebi50_v241
5+
chemlog_peptides:
6+
type: chemlog_peptides
7+
model_weight: 100
8+
chemlog_element:
9+
type: chemlog_element
10+
model_weight: 100
11+
chemlog_organox:
12+
type: chemlog_organox
13+
model_weight: 100
14+
c3p:
15+
load_model: c3p_with_weights

chebifier/ensemble/base_ensemble.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,45 +6,63 @@
66
from chebai.preprocessing.datasets.chebi import ChEBIOver50
77
from chebai.result.analyse_sem import PredictionSmoother, get_chebi_graph
88

9+
from chebifier.check_env import check_package_installed
910
from chebifier.prediction_models.base_predictor import BasePredictor
10-
from functools import lru_cache
11+
1112

1213
class BaseEnsemble:
1314

14-
def __init__(self, model_configs: dict, chebi_version: int = 241):
15+
def __init__(self, model_configs: dict, chebi_version: int = 241, resolve_inconsistencies: bool = True):
1516
# Deferred Import: To avoid circular import error
1617
from chebifier.model_registry import MODEL_TYPES
1718

1819
self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version)
1920
self.chebi_dataset._download_required_data() # download chebi if not already downloaded
2021
self.chebi_graph = get_chebi_graph(self.chebi_dataset, None)
21-
self.disjoint_files = [
22+
local_disjoint_files = [
2223
os.path.join("data", "disjoint_chebi.csv"),
2324
os.path.join("data", "disjoint_additional.csv"),
2425
]
26+
self.disjoint_files = []
27+
for file in local_disjoint_files:
28+
if os.path.isfile(file):
29+
self.disjoint_files.append(file)
30+
else:
31+
print(f"Disjoint axiom file {file} not found. Loading from huggingface instead...")
32+
from chebifier.hugging_face import download_model_files
33+
self.disjoint_files.append(download_model_files({
34+
"repo_id": "chebai/chebifier",
35+
"repo_type": "dataset",
36+
"files": {"disjoint_file": os.path.basename(file)},
37+
})["disjoint_file"])
2538

2639
self.models = []
2740
self.positive_prediction_threshold = 0.5
2841
for model_name, model_config in model_configs.items():
2942
model_cls = MODEL_TYPES[model_config["type"]]
3043
if "hugging_face" in model_config:
31-
from api.hugging_face import download_model_files
44+
from chebifier.hugging_face import download_model_files
3245
hugging_face_kwargs = download_model_files(model_config["hugging_face"])
3346
else:
3447
hugging_face_kwargs = {}
48+
if "package_name" in model_config:
49+
check_package_installed(model_config["package_name"])
50+
3551
model_instance = model_cls(
3652
model_name, **model_config, **hugging_face_kwargs, chebi_graph=self.chebi_graph
3753
)
3854
assert isinstance(model_instance, BasePredictor)
3955
self.models.append(model_instance)
4056

4157

42-
43-
self.smoother = PredictionSmoother(
44-
self.chebi_dataset,
45-
label_names=None,
46-
disjoint_files=self.disjoint_files,
47-
)
58+
if resolve_inconsistencies:
59+
self.smoother = PredictionSmoother(
60+
self.chebi_dataset,
61+
label_names=None,
62+
disjoint_files=self.disjoint_files,
63+
)
64+
else:
65+
self.smoother = None
4866

4967
def gather_predictions(self, smiles_list):
5068
# get predictions from all models for the SMILES list
@@ -131,15 +149,15 @@ def consolidate_predictions(self, predictions, classwise_weights, predicted_clas
131149
# Smooth predictions
132150
start_time = time.perf_counter()
133151
class_names = list(predicted_classes.keys())
134-
self.smoother.set_label_names(class_names)
135-
smooth_net_score = self.smoother(net_score)
152+
if self.smoother is not None:
153+
self.smoother.set_label_names(class_names)
154+
smooth_net_score = self.smoother(net_score)
155+
class_decisions = (smooth_net_score > 0.5) & has_valid_predictions # Shape: (num_smiles, num_classes)
156+
else:
157+
class_decisions = (net_score > 0) & has_valid_predictions # Shape: (num_smiles, num_classes)
136158
end_time = time.perf_counter()
137159
print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")
138160

139-
class_decisions = (
140-
smooth_net_score > 0.5
141-
) & has_valid_predictions # Shape: (num_smiles, num_classes)
142-
143161
complete_failure = torch.all(~has_valid_predictions, dim=1)
144162
return class_decisions, complete_failure
145163

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,26 @@ def download_model_files(
2525
model_config (Dict[str, str | Dict[str, str]]): A dictionary containing:
2626
- 'repo_id' (str): The Hugging Face repository ID (e.g., 'username/modelname').
2727
- 'subfolder' (str): The subfolder within the repo where the files are located.
28-
- 'files' (Dict[str, str]): A mapping from file type (e.g., 'ckpt', 'labels') to
28+
- 'files' (Dict[str, str]): A mapping from file type (e.g., 'ckpt_path', 'target_labels_path') to
2929
actual file names (e.g., 'electra.ckpt', 'classes.txt').
3030
3131
Returns:
3232
Dict[str, Path]: A dictionary mapping each file type to the local Path of the downloaded file.
3333
"""
3434
repo_id = model_config["repo_id"]
35-
subfolder = model_config["subfolder"]
35+
subfolder = model_config.get("subfolder", None)
36+
repo_type = model_config.get("repo_type", "model")
3637
filenames = model_config["files"]
3738

3839
local_paths: dict[str, Path] = {}
3940
for file_type, filename in filenames.items():
4041
downloaded_file_path = hf_hub_download(
4142
repo_id=repo_id,
4243
filename=filename,
44+
repo_type=repo_type,
4345
subfolder=subfolder,
4446
)
4547
local_paths[file_type] = Path(downloaded_file_path)
4648
print(f"\t Using file `{filename}` from: {downloaded_file_path}")
4749

48-
return {
49-
"ckpt_path": local_paths["ckpt"],
50-
"target_labels_path": local_paths["labels"],
51-
}
50+
return local_paths

0 commit comments

Comments
 (0)