Skip to content

Commit ed08289

Browse files
committed
move config processing logic to base ensemble
1 parent e82f6d1 commit ed08289

File tree

3 files changed

+41
-49
lines changed

3 files changed

+41
-49
lines changed

chebifier/cli.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import yaml
55

66
from chebifier.model_registry import ENSEMBLES
7-
from chebifier.utils import get_default_configs
8-
97

108
@click.group()
119
def cli():
@@ -73,38 +71,10 @@ def predict(
7371
resolve_inconsistencies=True,
7472
):
7573
"""Predict ChEBI classes for SMILES strings using an ensemble model."""
76-
# Load configuration from YAML file
77-
if not ensemble_config:
78-
print("Using default ensemble configuration")
79-
config = get_default_configs()
80-
else:
81-
print(f"Loading ensemble configuration from {ensemble_config}")
82-
with open(ensemble_config, "r") as f:
83-
config = yaml.safe_load(f)
84-
85-
with (
86-
importlib.resources.files("chebifier")
87-
.joinpath("model_registry.yml")
88-
.open("r") as f
89-
):
90-
model_registry = yaml.safe_load(f)
91-
92-
new_config = {}
93-
for model_name, entry in config.items():
94-
if "load_model" in entry:
95-
if entry["load_model"] not in model_registry:
96-
raise ValueError(
97-
f"Model {entry['load_model']} not found in model registry. "
98-
f"Available models are: {','.join(model_registry.keys())}."
99-
)
100-
new_config[model_name] = {**model_registry[entry["load_model"]], **entry}
101-
else:
102-
new_config[model_name] = entry
103-
config = new_config
10474

10575
# Instantiate ensemble model
10676
ensemble = ENSEMBLES[ensemble_type](
107-
config,
77+
ensemble_config,
10878
chebi_version=chebi_version,
10979
resolve_inconsistencies=resolve_inconsistencies,
11080
)

chebifier/ensemble/base_ensemble.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
import torch
77
import tqdm
88
import yaml
9+
import importlib
910

1011
from chebifier.check_env import check_package_installed
1112
from chebifier.hugging_face import download_model_files
1213
from chebifier.inconsistency_resolution import PredictionSmoother
1314
from chebifier.prediction_models.base_predictor import BasePredictor
14-
from chebifier.utils import get_disjoint_files, load_chebi_graph, get_default_configs
15+
from chebifier.utils import get_disjoint_files, load_chebi_graph, get_default_configs, process_config
1516

1617

1718
class BaseEnsemble:
@@ -21,22 +22,34 @@ def __init__(
2122
chebi_version: int = 241,
2223
resolve_inconsistencies: bool = True,
2324
):
24-
if model_configs is None:
25-
model_configs = get_default_configs()
26-
elif isinstance(model_configs, (str, Path)):
27-
# Load configuration from YAML file
28-
with open(model_configs) as file:
29-
model_configs = yaml.safe_load(file)
30-
3125
# Deferred Import: To avoid circular import error
3226
from chebifier.model_registry import MODEL_TYPES
3327

28+
# Load configuration from YAML file
29+
if not model_configs:
30+
config = get_default_configs()
31+
elif isinstance(model_configs, dict):
32+
config = model_configs
33+
else:
34+
print(f"Loading ensemble configuration from {model_configs}")
35+
with open(model_configs, "r") as f:
36+
config = yaml.safe_load(f)
37+
38+
with (
39+
importlib.resources.files("chebifier")
40+
.joinpath("model_registry.yml")
41+
.open("r") as f
42+
):
43+
model_registry = yaml.safe_load(f)
44+
45+
processed_configs = process_config(config, model_registry)
46+
3447
self.chebi_graph = load_chebi_graph()
3548
self.disjoint_files = get_disjoint_files()
3649

3750
self.models = []
3851
self.positive_prediction_threshold = 0.5
39-
for model_name, model_config in model_configs.items():
52+
for model_name, model_config in processed_configs.items():
4053
model_cls = MODEL_TYPES[model_config["type"]]
4154
if "hugging_face" in model_config:
4255
hugging_face_kwargs = download_model_files(model_config["hugging_face"])

chebifier/utils.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,18 +126,27 @@ def get_disjoint_files():
126126
return disjoint_files
127127

128128

129-
if __name__ == "__main__":
130-
# chebi_graph = build_chebi_graph(chebi_version=241)
131-
# save the graph to a file
132-
# pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb"))
133-
chebi_graph = load_chebi_graph()
134-
print(chebi_graph)
135-
136-
137129
def get_default_configs():
    """Load the default ensemble configuration bundled with ``chebifier``.

    Prints a notice naming the default file, then parses that YAML resource
    from the ``chebifier`` package.

    Returns:
        Whatever ``yaml.safe_load`` produces for the default config file.
    """
    default_config_name = "ensemble.yml"
    print(f"Using default ensemble configuration from {default_config_name}")
    # Resolve the packaged resource first so the ``with`` statement stays flat.
    resource = importlib.resources.files("chebifier").joinpath(default_config_name)
    with resource.open("r") as stream:
        return yaml.safe_load(stream)
138+
139+
140+
def process_config(config, model_registry):
    """Resolve ``load_model`` references in an ensemble configuration.

    Args:
        config: Mapping of model name to that model's config entry.
        model_registry: Mapping of registry key to a base config mapping.

    Returns:
        A new dict keyed like ``config``. An entry with a ``load_model`` key
        becomes the referenced registry config overlaid with the entry's own
        keys (the entry wins on conflicts, and ``load_model`` itself is kept).
        An entry without ``load_model`` is passed through as the same object.

    Raises:
        ValueError: If an entry's ``load_model`` is not a key of
            ``model_registry``.
    """
    resolved = {}
    for model_name, entry in config.items():
        # Entries that do not reference the registry are used unchanged.
        if "load_model" not in entry:
            resolved[model_name] = entry
            continue

        registry_key = entry["load_model"]
        if registry_key not in model_registry:
            raise ValueError(
                f"Model {entry['load_model']} not found in model registry. "
                f"Available models are: {','.join(model_registry.keys())}."
            )
        # Registry values act as defaults; the entry's own keys override them.
        resolved[model_name] = {**model_registry[registry_key], **entry}
    return resolved

0 commit comments

Comments
 (0)