Skip to content

Commit 187572b

Browse files
committed
Simplify getting started
This PR updates the BaseEnsemble constructor to allow the following: 1. Passing a string or path to the configuration 2. Not passing a configuration at all, which will automatically load the default configuration. This is now the default, since most users won't want to have to configure it (it should have reasonable defaults)
1 parent ccbeb74 commit 187572b

File tree

5 files changed

+35
-16
lines changed

5 files changed

+35
-16
lines changed

README.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,11 @@ python -m chebifier predict --help
7575
You can also use the package programmatically:
7676

7777
```python
78-
from chebifier.ensemble.base_ensemble import BaseEnsemble
79-
import yaml
78+
from chebifier import BaseEnsemble
8079
81-
# Load configuration from YAML file
82-
with open('configs/example_config.yml', 'r') as f:
83-
config = yaml.safe_load(f)
84-
85-
# Instantiate ensemble model
86-
ensemble = BaseEnsemble(config)
80+
# Instantiate ensemble model. If desired, can pass
81+
# a path to a configuration, like 'configs/example_config.yml'
82+
ensemble = BaseEnsemble()
8783
8884
# Make predictions
8985
smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"]

chebifier/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,10 @@
22
# even if multiple subpackages are imported later.
33

44
from ._custom_cache import PerSmilesPerModelLRUCache
5+
from chebifier.ensemble.base_ensemble import BaseEnsemble
6+
7+
__all__ = [
8+
"BaseEnsemble",
9+
]
510

611
modelwise_smiles_lru_cache = PerSmilesPerModelLRUCache(max_size=100)

chebifier/cli.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import yaml
55

66
from chebifier.model_registry import ENSEMBLES
7+
from chebifier.utils import get_default_configs
78

89

910
@click.group()
@@ -75,12 +76,7 @@ def predict(
7576
# Load configuration from YAML file
7677
if not ensemble_config:
7778
print("Using default ensemble configuration")
78-
with (
79-
importlib.resources.files("chebifier")
80-
.joinpath("ensemble.yml")
81-
.open("r") as f
82-
):
83-
config = yaml.safe_load(f)
79+
config = get_default_configs()
8480
else:
8581
print(f"Loading ensemble configuration from {ensemble_config}")
8682
with open(ensemble_config, "r") as f:

chebifier/ensemble/base_ensemble.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,33 @@
11
import os
22
import time
3+
from pathlib import Path
4+
from typing import Union
35

46
import torch
57
import tqdm
8+
import yaml
69

710
from chebifier.check_env import check_package_installed
811
from chebifier.hugging_face import download_model_files
912
from chebifier.inconsistency_resolution import PredictionSmoother
1013
from chebifier.prediction_models.base_predictor import BasePredictor
11-
from chebifier.utils import get_disjoint_files, load_chebi_graph
14+
from chebifier.utils import get_disjoint_files, load_chebi_graph, get_default_configs
1215

1316

1417
class BaseEnsemble:
1518
def __init__(
1619
self,
17-
model_configs: dict,
20+
model_configs: Union[str, Path, dict, None] = None,
1821
chebi_version: int = 241,
1922
resolve_inconsistencies: bool = True,
2023
):
24+
if model_configs is None:
25+
model_configs = get_default_configs()
26+
elif isinstance(model_configs, (str, Path)):
27+
# Load configuration from YAML file
28+
with open(model_configs) as file:
29+
model_configs = yaml.safe_load(file)
30+
2131
# Deferred Import: To avoid circular import error
2232
from chebifier.model_registry import MODEL_TYPES
2333

chebifier/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import importlib.resources
12
import os
23

34
import networkx as nx
45
import requests
56
import fastobo
7+
import yaml
8+
69
from chebifier.hugging_face import download_model_files
710
import pickle
811

@@ -129,3 +132,12 @@ def get_disjoint_files():
129132
# pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb"))
130133
chebi_graph = load_chebi_graph()
131134
print(chebi_graph)
135+
136+
137+
def get_default_configs():
138+
with (
139+
importlib.resources.files("chebifier")
140+
.joinpath("ensemble.yml")
141+
.open("r") as f
142+
):
143+
return yaml.safe_load(f)

0 commit comments

Comments
 (0)