11from abc import ABC , abstractmethod
22from collections import deque
33from pathlib import Path
4- from typing import Any , Deque , Dict , Optional
4+ from typing import Any , Deque , Dict , Literal , Optional
55
66import pandas as pd
77import torch
88from lightning import LightningModule
99
1010from chebai .result .classification import print_metrics
1111
12- from ._constants import MODEL_CLS_PATH , MODEL_LBL_PATH , WRAPPER_CLS_PATH
12+ from ._constants import (
13+ EVAL_OP ,
14+ MODEL_CLS_PATH ,
15+ MODEL_LBL_PATH ,
16+ PRED_OP ,
17+ WRAPPER_CLS_PATH ,
18+ )
1319
1420
1521class EnsembleBase (ABC ):
@@ -22,38 +28,40 @@ class EnsembleBase(ABC):
2228 def __init__ (
2329 self ,
2430 model_configs : Dict [str , Dict [str , Any ]],
25- data_file_path : str ,
26- classes_file_path : str ,
31+ data_processed_dir_main : str ,
32+ operation : str = EVAL_OP ,
2733 ** kwargs : Any ,
2834 ) -> None :
2935 """
3036 Initializes the ensemble model and loads configurations, labels, and sets up the environment.
3137
3238 Args:
3339 model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations.
34- data_file_path (str): Path to the processed data directory.
40+ data_processed_dir_main (str): Path to the processed data directory.
3541 reader_dir_name (str): Name of the directory used by the reader. Defaults to 'smiles_token'.
3642 **kwargs (Any): Additional arguments, such as 'input_dim' and '_validate_configs'.
3743 """
38- if bool (kwargs .get ("_validate_configs" , True )):
39- self ._validate_model_configs (model_configs )
44+ if bool (kwargs .get ("_perform_validation_checks" , True )):
45+ self ._perform_validation_checks (
46+ model_configs , operation = operation , ** kwargs
47+ )
4048
4149 self ._model_configs : Dict [str , Dict [str , Any ]] = model_configs
42- self ._data_file_path : str = data_file_path
43- self ._classes_file_path : str = classes_file_path
50+ self ._data_processed_dir_main : str = data_processed_dir_main
51+ self ._operation : str = operation
52+ print (f"Ensemble operation: { self ._operation } " )
53+
4454 self ._input_dim : Optional [int ] = kwargs .get ("input_dim" , None )
4555 self ._total_data_size : int = None
4656 self ._ensemble_input : list [str ] | Path = self ._process_input_to_ensemble (
47- data_file_path
57+ ** kwargs
4858 )
4959 print (f"Total data size (data.pkl) is { self ._total_data_size } " )
5060
5161 self ._device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
5262
5363 self ._models : Dict [str , LightningModule ] = {}
54- self ._dm_labels : Dict [str , int ] = self ._load_data_module_labels (
55- classes_file_path
56- )
64+ self ._dm_labels : Dict [str , int ] = self ._load_data_module_labels ()
5765 self ._num_of_labels : int = len (self ._dm_labels )
5866 print (f"Number of labes for this data is { self ._num_of_labels } " )
5967
@@ -63,7 +71,9 @@ def __init__(
6371 self ._model_queue : Deque [str ] = deque ()
6472
6573 @classmethod
66- def _validate_model_configs (cls , model_configs : Dict [str , Dict [str , Any ]]) -> None :
74+ def _perform_validation_checks (
75+ cls , model_configs : Dict [str , Dict [str , Any ]], operation , ** kwargs
76+ ) -> None :
6777 """
6878 Validates model configuration dictionary for required keys and uniqueness.
6979
@@ -74,6 +84,19 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No
7484 AttributeError: If any model config is missing required keys.
7585 ValueError: If duplicate paths are found for model checkpoint, class, or labels.
7686 """
87+ if operation not in ["evaluate" , "predict" ]:
88+ raise ValueError (
89+ f"Invalid operation '{ operation } '. Must be 'evaluate' or 'predict'."
90+ )
91+
92+ if operation == "predict" and not kwargs .get ("smiles_list_file_path" , None ):
93+ raise ValueError (
94+ "For 'predict' operation, 'smiles_list_file_path' must be provided."
95+ )
96+
97+ if not Path (kwargs .get ("smiles_list_file_path" )).exists ():
98+ raise FileNotFoundError (f"{ kwargs .get ('smiles_list_file_path' )} " )
99+
77100 class_set , labels_set = set (), set ()
78101 required_keys = {
79102 MODEL_CLS_PATH ,
@@ -103,9 +126,9 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No
103126 class_set .add (model_class_path )
104127 labels_set .add (model_labels_path )
105128
106- def _process_input_to_ensemble (self , path : str ) :
107- p = Path ( path )
108- if p . is_file ():
129+ def _process_input_to_ensemble (self , ** kwargs : any ) -> list [ str ] | Path :
130+ if self . _operation == PRED_OP :
131+ p = Path ( kwargs [ "smiles_list_file_path" ])
109132 smiles_list = []
110133 with open (p , "r" ) as f :
111134 for line in f :
@@ -116,24 +139,23 @@ def _process_input_to_ensemble(self, path: str):
116139 smiles_list .append (smiles )
117140 self ._total_data_size = len (smiles_list )
118141 return smiles_list
119- elif p . is_dir () :
120- data_pkl_path = p / "data.pkl"
142+ elif self . _operation == EVAL_OP :
143+ data_pkl_path = Path ( self . _data_processed_dir_main ) / "data.pkl"
121144 if not data_pkl_path .exists ():
122145 raise FileNotFoundError ()
123146 self ._total_data_size = len (pd .read_pickle (data_pkl_path ))
124147 return p
125148 else :
126- raise "Invalid path"
149+ raise ValueError ( "Invalid operation" )
127150
128- @staticmethod
129- def _load_data_module_labels (classes_file_path : str ) -> dict [str , int ]:
151+ def _load_data_module_labels (self ) -> dict [str , int ]:
130152 """
131153 Loads class labels from the classes.txt file and sets internal label mapping.
132154
133155 Raises:
134156 FileNotFoundError: If the expected classes.txt file is not found.
135157 """
136- classes_file_path = Path (classes_file_path )
158+ classes_file_path = Path (self . _data_processed_dir_main ) / "classes.txt"
137159 if not classes_file_path .exists ():
138160 raise FileNotFoundError (f"{ classes_file_path } does not exist" )
139161 print (f"Loading { classes_file_path } ...." )
@@ -197,14 +219,13 @@ def _controller(
197219 Returns:
198220 Dict[str, torch.Tensor]: Predictions or confidence scores.
199221 """
200- pass
201222
202223 @abstractmethod
203224 def _consolidator (
204225 self ,
226+ * ,
205227 pred_conf_dict : Dict [str , torch .Tensor ],
206228 model_props : Dict [str , torch .Tensor ],
207- * ,
208229 true_scores : torch .Tensor ,
209230 false_scores : torch .Tensor ,
210231 ** kwargs : Any ,
@@ -214,7 +235,6 @@ def _consolidator(
214235
215236 Should update the provided `true_scores` and `false_scores`.
216237 """
217- pass
218238
219239 @abstractmethod
220240 def _consolidate_on_finish (
@@ -226,4 +246,3 @@ def _consolidate_on_finish(
226246 Returns:
227247 torch.Tensor: Final aggregated predictions.
228248 """
229- pass
0 commit comments