44from typing import Any , Deque , Dict
55
66import torch
7+ from lightning import LightningModule
78from torch import Tensor
89
910from chebai .models import ChebaiBaseNet
1011from chebai .preprocessing .collate import RaggedCollator
1112
1213from ._base import EnsembleBase
14+ from ._constants import WRAPPER_CLS_PATH
15+ from ._utils import _load_class
16+ from ._wrappers import BaseWrapper
1317
1418
1519class _Controller (EnsembleBase , ABC ):
@@ -30,10 +34,7 @@ def __init__(self, **kwargs: Any):
3034 **kwargs (Any): Keyword arguments passed to the EnsembleBase initializer.
3135 """
3236 super ().__init__ (** kwargs )
33- self ._collator = RaggedCollator ()
34-
35- self .input_dim = len (self ._collated_data .x [0 ])
36- self ._total_data_size : int = len (self ._collated_data )
37+ self ._kwargs = kwargs
3738
3839 def _get_pred_conf_from_model_output (
3940 self , model_output : Dict [str , Tensor ], model_label_mask : Tensor
@@ -60,6 +61,20 @@ def _get_pred_conf_from_model_output(
6061 confidence [:, model_label_mask ] = 2 * torch .abs (sigmoid_logits - 0.5 )
6162 return {"prediction" : prediction , "confidence" : confidence }
6263
64+ def _wrap_model (self , model_name : str ) -> BaseWrapper :
65+ model_config = self ._model_configs [model_name ]
66+ wrp_cls = _load_class (model_config [WRAPPER_CLS_PATH ])
67+ assert issubclass (wrp_cls , BaseWrapper ), ""
68+ wrapped_model = wrp_cls (
69+ model_name = model_name ,
70+ model_config = model_config ,
71+ dm_labels = self ._dm_labels ,
72+ ** self ._kwargs
73+ )
74+ assert isinstance (wrapped_model , BaseWrapper ), ""
75+ # del wrapped_model # Model can be huge to keep it in memory, delete as no longer needed
76+ return wrapped_model
77+
6378
6479class NoActivationCondition (_Controller ):
6580 """
@@ -76,11 +91,9 @@ def __init__(self, **kwargs: Any):
7691 **kwargs (Any): Keyword arguments passed to the _Controller initializer.
7792 """
7893 super ().__init__ (** kwargs )
79- self ._model_queue : Deque [str ] = deque (list (self .model_configs .keys ()))
94+ self ._model_queue : Deque [str ] = deque (list (self ._model_configs .keys ()))
8095
81- def _controller (
82- self , model : ChebaiBaseNet , model_props : Dict [str , Tensor ], ** kwargs : Any
83- ) -> Dict [str , Tensor ]:
96+ def _controller (self , model_name , ** kwargs : Any ) -> Dict [str , Tensor ]:
8497 """
8598 Performs inference with the model and extracts predictions and confidence values.
8699
@@ -91,5 +104,5 @@ def _controller(
91104 Returns:
92105 Dict[str, Tensor]: Dictionary containing predictions and confidence scores.
93106 """
94- model_output = self ._forward_pass ( model )
107+ wrapped_model = self ._wrap_model ( model_name )
95108 return self ._get_pred_conf_from_model_output (model_output , model_props ["mask" ])
0 commit comments