@@ -61,12 +61,9 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
6161 AttributeError: If required keys are missing in the configuration.
6262 ValueError: If there are duplicate model paths or class paths.
6363 """
64- path_set = set ()
65- class_set = set ()
66- labels_set = set ()
64+ path_set , class_set , labels_set = set (), set (), set ()
6765
68- sets_ = {"path" : path_set , "class" : class_set , "labels" : labels_set }
69- required_keys = {"class_path" , "ckpt_path" }
66+ required_keys = {"class_path" , "ckpt_path" , "labels_path" }
7067
7168 for model_name , config in model_configs .items ():
7269 missing_keys = required_keys - config .keys ()
@@ -78,37 +75,26 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
7875
7976 model_path = config ["ckpt_path" ]
8077 class_path = config ["class_path" ]
78+ labels_path = config ["labels_path" ]
8179
8280 if model_path in path_set :
8381 raise ValueError (
84- f"Duplicate model path detected: '{ model_path } '. Each model must have a unique path."
82+ f"Duplicate model path detected: '{ model_path } '. Each model must have a unique model-checkpoint path."
8583 )
8684
87- if class_path not in class_set :
85+ if class_path in class_set :
8886 raise ValueError (
89- f"Duplicate class path detected: '{ class_path } '. Each model must have a unique path."
87+ f"Duplicate class path detected: '{ class_path } '. Each model must have a unique class path."
88+ )
89+
90+ if labels_path in labels_set :
91+ raise ValueError (
92+ f"Duplicate labels path: { labels_path } . Each model must have unique labels path."
9093 )
9194
9295 path_set .add (model_path )
9396 class_set .add (class_path )
94-
95- cls ._extra_validation (model_name , config , sets_ )
96-
97- @classmethod
98- def _extra_validation (
99- cls , model_name : str , config : Dict [str , Any ], sets_ : Dict [str , set ]
100- ):
101- """
102- Perform extra validation on the model configuration, if necessary.
103-
104- This method can be extended by subclasses to add additional validation logic.
105-
106- Args:
107- model_name (str): The name of the model.
108- config (Dict[str, Any]): The configuration dictionary for the model.
109- sets_ (Dict[str, set]): A dictionary of sets to track model paths, class paths, and labels.
110- """
111- pass
97+ labels_path .add (labels_path )
11298
11399 def _load_ensemble_models (self ):
114100 """
@@ -122,6 +108,7 @@ def _load_ensemble_models(self):
122108 for model_name in self .model_configs :
123109 model_ckpt_path = self .model_configs [model_name ]["ckpt_path" ]
124110 model_class_path = self .model_configs [model_name ]["class_path" ]
111+ model_labels_path = self .model_configs [model_name ]["labels_path" ]
125112 if not os .path .exists (model_ckpt_path ):
126113 raise FileNotFoundError (
127114 f"Model path '{ model_ckpt_path } ' for '{ model_name } ' does not exist."
@@ -134,10 +121,15 @@ def _load_ensemble_models(self):
134121 module = importlib .import_module (module_path )
135122 lightning_cls : LightningModule = getattr (module , class_name )
136123
137- model = lightning_cls .load_from_checkpoint (model_ckpt_path )
124+ model = lightning_cls .load_from_checkpoint (
125+ model_ckpt_path , input_dim = self .input_dim
126+ )
138127 model .eval ()
139128 model .freeze ()
140129 self .models [model_name ] = model
130+ self .models_configs [model_name ]["labels" ] = self ._load_model_labels (
131+ model_labels_path
132+ )
141133
142134 except ModuleNotFoundError :
143135 print (f"Module '{ module_path } ' not found!" )
@@ -149,21 +141,37 @@ def _load_ensemble_models(self):
149141 f"Failed to load model '{ model_name } ' from { model_ckpt_path } : \n { e } "
150142 )
151143
152- def _load_data_module_labels (self ):
153- """
154- Loads the label mapping from the classes.txt file for loaded data.
144+ @staticmethod
145+ def _load_model_labels (labels_path : str ) -> Dict [str , float ]:
146+ if not os .path .exists (labels_path ):
147+ raise FileNotFoundError (f"{ labels_path } does not exist." )
155148
156- Raises:
157- FileNotFoundError: If the classes.txt file does not exist.
158- """
159- classes_txt_file = os .path .join (self .data_processed_dir_main , "classes.txt" )
160- if not os .path .exists (classes_txt_file ):
161- raise FileNotFoundError (f"{ classes_txt_file } does not exist" )
162- else :
163- with open (classes_txt_file , "r" ) as f :
164- for line in f :
165- if line .strip () not in self .dm_labels :
166- self .dm_labels [line .strip ()] = len (self .dm_labels )
149+ if not labels_path .endswith (".json" ):
150+ raise TypeError (f"{ labels_path } is not a JSON file." )
151+
152+ with open (labels_path , "r" ) as f :
153+ model_labels = json .load (f )
154+
155+ labels_dict = {}
156+ for label , label_dict in model_labels .items ():
157+ msg = f"for label { label } "
158+ if "TPV" not in label_dict .keys () or "FPV" not in label_dict .keys ():
159+ raise AttributeError (f"Missing keys 'TPV' and/or 'FPV' { msg } " )
160+
161+ # Validate 'tpv' and 'fpv' are either floats or convertible to float
162+ for key in ["TPV" , "FPV" ]:
163+ try :
164+ value = float (label_dict [key ])
165+ if value < 0 :
166+ raise ValueError (
167+ f"'{ key } ' must be non-negative but got { value } { msg } "
168+ )
169+ except (TypeError , ValueError ):
170+ raise ValueError (
171+ f"'{ key } ' must be a float or convertible to float, but got { label_dict [key ]} { msg } "
172+ )
173+ labels_dict [label ][key ] = value
174+ return labels_dict
167175
168176 @abstractmethod
169177 def _get_prediction_and_labels (
@@ -182,6 +190,12 @@ def _get_prediction_and_labels(
182190 """
183191 pass
184192
193+ def controller (self ):
194+ pass
195+
196+ def consolidator (self ):
197+ pass
198+
185199
186200class ChebiEnsemble (_EnsembleBase ):
187201 """
@@ -212,56 +226,6 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs):
212226 self ._num_models_per_label : Optional [torch .Tensor ] = None
213227 self ._generate_model_label_mask ()
214228
215- @classmethod
216- def _extra_validation (
217- cls , model_name : str , config : Dict [str , Any ], sets_ : Dict [str , set ]
218- ):
219- """
220- Additional validation for the ensemble model configuration.
221-
222- Args:
223- model_name (str): The model name.
224- config (Dict[str, Any]): The configuration dictionary.
225- sets_ (Dict[str, set]): The set of paths for labels.
226-
227- Raises:
228- AttributeError: If the 'labels_path' key is missing.
229- ValueError: If the 'labels_path' contains duplicate entries or certain are not convertible to float.
230- """
231- if "labels_path" not in config :
232- raise AttributeError ("Missing 'labels_path' key in config!" )
233-
234- labels_path = config ["labels_path" ]
235- if labels_path not in sets_ ["labels" ]:
236- raise ValueError (
237- f"Duplicate labels path detected: '{ labels_path } '. Each model must have a unique path."
238- )
239-
240- sets_ ["labels" ].add (labels_path )
241-
242- with open (labels_path , "r" ) as f :
243- model_labels = json .load (f )
244-
245- for label , label_dict in model_labels .items ():
246-
247- if "TPV" not in label_dict .keys () or "FPV" not in label_dict .keys ():
248- raise AttributeError (
249- f"Missing keys 'TPV' and/or 'FPV' in model '{ model_name } ' configuration."
250- )
251-
252- # Validate 'tpv' and 'fpv' are either floats or convertible to float
253- for key in ["TPV" , "FPV" ]:
254- try :
255- value = float (label_dict [key ])
256- if value < 0 :
257- raise ValueError (
258- f"'{ key } ' in model '{ model_name } ' and label '{ label } ' must be non-negative, but got { value } ."
259- )
260- except (TypeError , ValueError ):
261- raise ValueError (
262- f"'{ key } ' in model '{ model_name } ' and label '{ label } ' must be a float or convertible to float, but got { label_dict [key ]} ."
263- )
264-
265229 def _generate_model_label_mask (self ):
266230 """
267231 Generates a mask indicating the labels handled by each model, and retrieves corresponding the TPV and FPV values
0 commit comments