Skip to content

Commit c0cb6c9

Browse files
committed
fix collated labels none error
1 parent a20ce76 commit c0cb6c9

File tree

4 files changed

+18
-10
lines changed

4 files changed

+18
-10
lines changed

chebai/ensemble/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from ._base import EnsembleBase
12
from ._consolidator import WeightedMajorityVoting
23
from ._controller import NoActivationCondition
4+
from ._wrappers import NNWrapper
35

46

57
class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting):
@@ -8,4 +10,4 @@ class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting):
810
pass
911

1012

11-
__all__ = ["FullEnsembleWMV"]
13+
__all__ = ["FullEnsembleWMV", "NNWrapper"]

chebai/ensemble/_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ def run_ensemble(self) -> None:
205205
)
206206

207207
if self._operation == EVAL_OP:
208+
assert (
209+
self._collated_labels is not None
210+
), "Collated labels must be set for evaluation operation."
208211
print_metrics(
209212
final_preds,
210213
self._collated_labels,

chebai/ensemble/_controller.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torch import Tensor
88

99
from ._base import EnsembleBase
10-
from ._constants import EVAL_OP, PRED_OP, WRAPPER_CLS_PATH
10+
from ._constants import PRED_OP, WRAPPER_CLS_PATH
1111
from ._utils import load_class
1212
from ._wrappers import BaseWrapper
1313

@@ -35,7 +35,7 @@ def __init__(self, **kwargs: Any):
3535
# This is in order to avoid re-adding models that have already been processed
3636
self._model_key_set: set[str] = set(self._model_configs.keys())
3737

38-
# Labels from any processed data.pt file for any reader
38+
# Labels from any processed `data.pt` file of any reader
3939
self._collated_labels: torch.Tensor | None = None
4040

4141
def _controller(
@@ -56,6 +56,12 @@ def _controller(
5656
model_output, model_props = wrapped_model.predict(model_input)
5757
else:
5858
model_output, model_props = wrapped_model.evaluate(model_input)
59+
if (
60+
self._collated_labels is None
61+
and wrapped_model.collated_labels is not None
62+
):
63+
self._collated_labels = wrapped_model.collated_labels
64+
5965
del wrapped_model # Model can be huge to keep it in memory, delete asap as no longer needed
6066

6167
pred_conf_dict = self._get_pred_conf_from_model_output(
@@ -98,8 +104,6 @@ def _wrap_model(self, model_name: str) -> BaseWrapper:
98104
dm_labels=self._dm_labels,
99105
**self._kwargs
100106
)
101-
if self._collated_labels is not None and self._operation == EVAL_OP:
102-
self._collated_labels = wrapped_model.collated_labels
103107

104108
assert isinstance(wrapped_model, BaseWrapper), ""
105109
return wrapped_model

chebai/ensemble/_scripts/_ensemble_run_script.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
from typing import Any, Dict, Type
1+
from typing import Any, Dict
22

33
import yaml
44
from jsonargparse import ArgumentParser
55

6-
from chebai.ensemble._utils import _load_class
7-
8-
from .._base import EnsembleBase
6+
from chebai.ensemble._base import EnsembleBase
7+
from chebai.ensemble._utils import load_class
98

109

1110
def load_config_and_instantiate(config_path: str) -> EnsembleBase:
@@ -27,7 +26,7 @@ def load_config_and_instantiate(config_path: str) -> EnsembleBase:
2726
class_path: str = config["class_path"]
2827
init_args: Dict[str, Any] = config.get("init_args", {})
2928

30-
cls = _load_class(class_path)
29+
cls = load_class(class_path)
3130

3231
if not issubclass(cls, EnsembleBase):
3332
raise TypeError(f"{cls} must be subclass of EnsembleBase")

0 commit comments

Comments
 (0)