Skip to content

Commit 825916e

Browse files
committed
add script for running ensemble
1 parent 4d6856d commit 825916e

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import importlib
2+
3+
import yaml
4+
from jsonargparse import ArgumentParser
5+
6+
from chebai.ensemble.base import EnsembleBase
7+
8+
9+
def load_class(class_path: str):
10+
"""Dynamically import a class from a full dotted path."""
11+
module_path, class_name = class_path.rsplit(".", 1)
12+
module = importlib.import_module(module_path)
13+
return getattr(module, class_name)
14+
15+
16+
def load_config_and_instantiate(config_path: str):
17+
with open(config_path, "r") as f:
18+
config = yaml.safe_load(f)
19+
20+
class_path = config["class_path"]
21+
init_args = config.get("init_args", {})
22+
23+
cls = load_class(class_path)
24+
if not issubclass(cls, EnsembleBase):
25+
raise TypeError(f"{cls} must be subclass of EnsembleBase")
26+
return cls(**init_args)
27+
28+
29+
if __name__ == "__main__":
30+
parser = ArgumentParser()
31+
parser.add_argument("--config", type=str, help="Path to the YAML config file")
32+
33+
args = parser.parse_args()
34+
ensemble = load_config_and_instantiate(args.config)
35+
36+
if not isinstance(ensemble, EnsembleBase):
37+
raise TypeError("Object must be an instance of `EnsembleBase`")
38+
39+
ensemble.run_ensemble()

0 commit comments

Comments
 (0)