Skip to content

Commit f812cd7

Browse files
committed
update controller for wrapper
1 parent a1a70eb commit f812cd7

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

chebai/ensemble/_controller.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44
from typing import Any, Deque, Dict
55

66
import torch
7+
from lightning import LightningModule
78
from torch import Tensor
89

910
from chebai.models import ChebaiBaseNet
1011
from chebai.preprocessing.collate import RaggedCollator
1112

1213
from ._base import EnsembleBase
14+
from ._constants import WRAPPER_CLS_PATH
15+
from ._utils import _load_class
16+
from ._wrappers import BaseWrapper
1317

1418

1519
class _Controller(EnsembleBase, ABC):
@@ -30,10 +34,7 @@ def __init__(self, **kwargs: Any):
3034
**kwargs (Any): Keyword arguments passed to the EnsembleBase initializer.
3135
"""
3236
super().__init__(**kwargs)
33-
self._collator = RaggedCollator()
34-
35-
self.input_dim = len(self._collated_data.x[0])
36-
self._total_data_size: int = len(self._collated_data)
37+
self._kwargs = kwargs
3738

3839
def _get_pred_conf_from_model_output(
3940
self, model_output: Dict[str, Tensor], model_label_mask: Tensor
@@ -60,6 +61,20 @@ def _get_pred_conf_from_model_output(
6061
confidence[:, model_label_mask] = 2 * torch.abs(sigmoid_logits - 0.5)
6162
return {"prediction": prediction, "confidence": confidence}
6263

64+
def _wrap_model(self, model_name: str) -> BaseWrapper:
65+
model_config = self._model_configs[model_name]
66+
wrp_cls = _load_class(model_config[WRAPPER_CLS_PATH])
67+
assert issubclass(wrp_cls, BaseWrapper), ""
68+
wrapped_model = wrp_cls(
69+
model_name=model_name,
70+
model_config=model_config,
71+
dm_labels=self._dm_labels,
72+
**self._kwargs
73+
)
74+
assert isinstance(wrapped_model, BaseWrapper), ""
75+
# del wrapped_model # Model can be huge to keep it in memory, delete as no longer needed
76+
return wrapped_model
77+
6378

6479
class NoActivationCondition(_Controller):
6580
"""
@@ -76,11 +91,9 @@ def __init__(self, **kwargs: Any):
7691
**kwargs (Any): Keyword arguments passed to the _Controller initializer.
7792
"""
7893
super().__init__(**kwargs)
79-
self._model_queue: Deque[str] = deque(list(self.model_configs.keys()))
94+
self._model_queue: Deque[str] = deque(list(self._model_configs.keys()))
8095

81-
def _controller(
82-
self, model: ChebaiBaseNet, model_props: Dict[str, Tensor], **kwargs: Any
83-
) -> Dict[str, Tensor]:
96+
def _controller(self, model_name, **kwargs: Any) -> Dict[str, Tensor]:
8497
"""
8598
Performs inference with the model and extracts predictions and confidence values.
8699
@@ -91,5 +104,5 @@ def _controller(
91104
Returns:
92105
Dict[str, Tensor]: Dictionary containing predictions and confidence scores.
93106
"""
94-
model_output = self._forward_pass(model)
107+
wrapped_model = self._wrap_model(model_name)
95108
return self._get_pred_conf_from_model_output(model_output, model_props["mask"])
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from ._base import BaseWrapper
12
from ._neural_network import NNWrapper
23

3-
__all__ = ["NNWrapper"]
4+
__all__ = ["NNWrapper", "BaseWrapper"]

0 commit comments

Comments
 (0)