77
88import torch
99from lightning import LightningModule
10- from lightning_utilities .core .rank_zero import rank_zero_info
1110
1211from chebai .models import ChebaiBaseNet
1312from chebai .preprocessing .structures import XYData
1413from chebai .result .classification import print_metrics
1514
16- from ._constants import *
15+ from ._constants import (
16+ MODEL_CKPT_PATH ,
17+ MODEL_CLS_PATH ,
18+ MODEL_LBL_PATH ,
19+ READER_CLS_PATH ,
20+ WRAPPER_CLS_PATH ,
21+ )
22+ from ._utils import _load_class
23+ from ._wrappers import BaseWrapper
1724
1825
1926class EnsembleBase (ABC ):
@@ -41,18 +48,17 @@ def __init__(
4148 if bool (kwargs .get ("_validate_configs" , True )):
4249 self ._validate_model_configs (model_configs )
4350
44- self .model_configs : Dict [str , Dict [str , Any ]] = model_configs
45- self .data_processed_dir_main : str = data_processed_dir_main
46- self .input_dim : Optional [int ] = kwargs .get ("input_dim" , None )
51+ self ._model_configs : Dict [str , Dict [str , Any ]] = model_configs
52+ self ._data_processed_dir_main : str = data_processed_dir_main
53+ self ._input_dim : Optional [int ] = kwargs .get ("input_dim" , None )
54+ self ._total_data_size : int = len (self ._collated_data )
4755
4856 self ._device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
49- self ._num_of_labels : Optional [int ] = (
50- None # will be set by `_load_data_module_labels` method
51- )
57+
5258 self ._models : Dict [str , LightningModule ] = {}
53- self ._dm_labels : Dict [str , int ] = {}
59+ self ._dm_labels : Dict [str , int ] = self ._load_data_module_labels ()
60+ self ._num_of_labels : int = len (self ._dm_labels )
5461
55- self ._load_data_module_labels ()
5662 self ._num_models_per_label : torch .Tensor = torch .zeros (
5763 1 , self ._num_of_labels , device = self ._device
5864 )
@@ -72,13 +78,11 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No
7278 AttributeError: If any model config is missing required keys.
7379 ValueError: If duplicate paths are found for model checkpoint, class, or labels.
7480 """
75- path_set , class_set , labels_set = set (), set (), set ()
81+ class_set , labels_set = set (), set ()
7682 required_keys = {
77- MODEL_CKPT_PATH ,
7883 MODEL_CLS_PATH ,
7984 MODEL_LBL_PATH ,
8085 WRAPPER_CLS_PATH ,
81- READER_CLS_PATH ,
8286 }
8387
8488 for model_name , config in model_configs .items ():
@@ -88,44 +92,41 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict[str, Any]]) -> No
8892 f"Missing keys { missing_keys } in model '{ model_name } ' configuration."
8993 )
9094
91- model_ckpt_path , model_class_path , model_labels_path = (
92- config [MODEL_CKPT_PATH ],
95+ model_class_path , model_labels_path = (
9396 config [MODEL_CLS_PATH ],
9497 config [MODEL_LBL_PATH ],
9598 )
9699
97- if model_ckpt_path in path_set :
98- raise ValueError (f"Duplicate model path detected: '{ model_ckpt_path } '." )
99100 if model_class_path in class_set :
100101 raise ValueError (
101102 f"Duplicate class path detected: '{ model_class_path } '."
102103 )
103104 if model_labels_path in labels_set :
104105 raise ValueError (f"Duplicate labels path: { model_labels_path } ." )
105106
106- path_set .add (model_ckpt_path )
107107 class_set .add (model_class_path )
108108 labels_set .add (model_labels_path )
109109
110- def _load_data_module_labels (self ) -> None :
110+ def _load_data_module_labels (self ) -> dict [ str , int ] :
111111 """
112112 Loads class labels from the classes.txt file and sets internal label mapping.
113113
114114 Raises:
115115 FileNotFoundError: If the expected classes.txt file is not found.
116116 """
117- classes_txt_file = os .path .join (self .data_processed_dir_main , "classes.txt" )
118- rank_zero_info (f"Loading { classes_txt_file } ...." )
117+ classes_txt_file = os .path .join (self ._data_processed_dir_main , "classes.txt" )
118+ print (f"Loading { classes_txt_file } ...." )
119119
120120 if not os .path .exists (classes_txt_file ):
121121 raise FileNotFoundError (f"{ classes_txt_file } does not exist" )
122122
123+ dm_labels_dict = {}
123124 with open (classes_txt_file , "r" ) as f :
124125 for line in f :
125126 label = line .strip ()
126- if label not in self . _dm_labels :
127- self . _dm_labels [label ] = len (self . _dm_labels )
128- self . _num_of_labels = len ( self . _dm_labels )
127+ if label not in dm_labels_dict :
128+ dm_labels_dict [label ] = len (dm_labels_dict )
129+ return dm_labels_dict
129130
130131 def run_ensemble (self ) -> None :
131132 """
@@ -140,22 +141,20 @@ def run_ensemble(self) -> None:
140141
141142 while self ._model_queue :
142143 model_name = self ._model_queue .popleft ()
143- rank_zero_info (f"Processing model: { model_name } " )
144- model , model_props = self ._load_model_and_its_props (model_name )
144+ print (f"Processing model: { model_name } " )
145145
146- rank_zero_info ("\t Passing model to controller to generate predictions..." )
147- pred_conf_dict = self ._controller (model , model_props )
148- del model # Model can be huge to keep it in memory, delete as no longer needed
146+ print ("\t Passing model to controller to generate predictions..." )
147+ pred_conf_dict , model_props = self ._controller (model_name )
149148
150- rank_zero_info ("\t Passing predictions to consolidator for aggregation..." )
149+ print ("\t Passing predictions to consolidator for aggregation..." )
151150 self ._consolidator (
152151 pred_conf_dict ,
153152 model_props ,
154153 true_scores = true_scores ,
155154 false_scores = false_scores ,
156155 )
157156
158- rank_zero_info (f"Consolidating predictions for { self .__class__ .__name__ } " )
157+ print (f"Consolidating predictions for { self .__class__ .__name__ } " )
159158 final_preds = self ._consolidate_on_finish (
160159 true_scores = true_scores , false_scores = false_scores
161160 )
0 commit comments