Skip to content

Commit 4f35007

Browse files
committed
utility for loading class
1 parent 00bd478 commit 4f35007

File tree

4 files changed

+13
-32
lines changed

4 files changed

+13
-32
lines changed

chebai/ensemble/_scripts/_ensemble_run_script.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,11 @@
1-
import importlib
21
from typing import Any, Dict, Type
32

43
import yaml
54
from jsonargparse import ArgumentParser
65

7-
from ._base import EnsembleBase
6+
from chebai.ensemble._utils import _load_class
87

9-
10-
def load_class(class_path: str) -> Type[EnsembleBase]:
11-
"""
12-
Dynamically imports and returns a class from a full dotted path.
13-
14-
Args:
15-
class_path (str): Full module path to the class (e.g., 'my_package.module.MyClass').
16-
17-
Returns:
18-
Type[EnsembleBase]: The imported class object.
19-
20-
Raises:
21-
ModuleNotFoundError, AttributeError: If module or class cannot be loaded.
22-
"""
23-
module_path, class_name = class_path.rsplit(".", 1)
24-
module = importlib.import_module(module_path)
25-
return getattr(module, class_name)
8+
from .._base import EnsembleBase
269

2710

2811
def load_config_and_instantiate(config_path: str) -> EnsembleBase:
@@ -44,7 +27,7 @@ def load_config_and_instantiate(config_path: str) -> EnsembleBase:
4427
class_path: str = config["class_path"]
4528
init_args: Dict[str, Any] = config.get("init_args", {})
4629

47-
cls = load_class(class_path)
30+
cls = _load_class(class_path)
4831

4932
if not issubclass(cls, EnsembleBase):
5033
raise TypeError(f"{cls} must be subclass of EnsembleBase")

chebai/ensemble/_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import importlib
2+
3+
4+
def _load_class(class_path):
5+
module_path, class_name = class_path.rsplit(".", 1)
6+
module = importlib.import_module(module_path)
7+
return getattr(module, class_name)

chebai/ensemble/_wrappers/_base.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,3 @@ def _predict_from_list_of_smiles(self, smiles_list: list) -> dict:
136136
@abstractmethod
137137
def _predict_from_data_file(self, data_file_path: str) -> dict:
138138
pass
139-
140-
@staticmethod
141-
def _load_class(class_path):
142-
class_name = class_path.split(".")[-1]
143-
module_path = ".".join(class_path.split(".")[:-1])
144-
module = importlib.import_module(module_path)
145-
return getattr(module, class_name)

chebai/ensemble/_wrappers/_neural_network.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,13 @@
88
from chebai.preprocessing.reader import DataReader
99

1010
from .._constants import MODEL_CKPT_PATH, READER_CLS_PATH, READER_KWARGS
11+
from .._utils import _load_class
1112
from ._base import BaseWrapper
1213

1314

1415
class NNWrapper(BaseWrapper):
1516

16-
def __init__(
17-
self,
18-
**kwargs,
19-
):
17+
def __init__(self, **kwargs):
2018
self._validate_model_configs(**kwargs)
2119
super().__init__(**kwargs)
2220

@@ -28,7 +26,7 @@ def __init__(
2826
else dict()
2927
)
3028

31-
reader_cls: Type[DataReader] = self._load_class(self._reader_class_path)
29+
reader_cls: Type[DataReader] = _load_class(self._reader_class_path)
3230
assert issubclass(reader_cls, DataReader), ""
3331
self._reader = reader_cls(**self._reader_kwargs)
3432
self._collator = reader_cls.COLLATOR()

0 commit comments

Comments
 (0)