@@ -19,13 +19,19 @@ def __init__(self, model_configs: Dict[str, ModelConfig], **kwargs):
1919
2020 for model_name in self .model_configs :
2121 model_path = self .model_configs [model_name ]["path" ]
22- if os .path .exists (model_path ):
22+ if not os .path .exists (model_path ):
23+ raise FileNotFoundError (
24+ f"Model path '{ model_path } ' for '{ model_name } ' does not exist."
25+ )
26+
27+ # Attempt to load the model to check validity
28+ try :
2329 self .models [model_name ] = Electra .load_from_checkpoint (
24- model_path , map_location = "cpu"
30+ model_path , map_location = self . device
2531 )
26- else :
27- raise FileNotFoundError (
28- f"Model { model_name } does not exist in the given path { model_path } "
32+ except Exception as e :
33+ raise RuntimeError (
34+ f"Failed to load model ' { model_name } ' from { model_path } : { e } "
2935 )
3036
3137 for model in self .models .values ():
@@ -70,10 +76,6 @@ def _validate_model_configs(cls, model_configs: Dict[str, ModelConfig]):
7076 )
7177
7278 model_path = config ["path" ]
73- if not os .path .exists (model_path ):
74- raise FileNotFoundError (
75- f"Model path '{ model_path } ' for '{ model_name } ' does not exist."
76- )
7779
7880 # if model_path in path_set:
7981 # raise ValueError(
@@ -100,14 +102,13 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
100102 confidences = {}
101103 total_logits = torch .zeros (
102104 data ["labels" ].shape [0 ], data ["labels" ].shape [1 ], device = self .device
103- ). to ( self . device )
105+ )
104106
105107 for name , model in self .models .items ():
106108 output = model (data )
107- confidences [name ] = torch .sigmoid (output ["logits" ])
108- predictions [name ] = (
109- torch .sigmoid (output ["logits" ]) > 0.5
110- ).long () # Multi-label classification
109+ sigmoid_logits = torch .sigmoid (output ["logits" ])
110+ confidences [name ] = sigmoid_logits
111+ predictions [name ] = (sigmoid_logits > 0.5 ).long ()
111112 total_logits += output ["logits" ]
112113
113114 return {
@@ -211,21 +212,18 @@ def _execute(
211212 def aggregate_predictions (self , predictions , confidences ):
212213 """Implements weighted voting based on trustworthiness."""
213214 batch_size , num_classes = list (predictions .values ())[0 ].shape
214-
215215 true_scores = torch .zeros (batch_size , num_classes , device = self .device )
216216 false_scores = torch .zeros (batch_size , num_classes , device = self .device )
217217
218218 for model , preds in predictions .items ():
219219 tpv = float (self .model_configs [model ]["TPV" ])
220220 npv = float (self .model_configs [model ]["FPV" ])
221-
222- confidence = confidences [model ]
223- weight = confidence * (tpv * preds + npv * (1 - preds ))
221+ weight = confidences [model ] * (tpv * preds + npv * (1 - preds ))
224222
225223 true_scores += weight * preds
226224 false_scores += weight * (1 - preds )
227225
228- return (true_scores > false_scores ).long () # Final class decision
226+ return (true_scores > false_scores ).long ()
229227
230228 def _process_for_loss (
231229 self ,
@@ -264,11 +262,7 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs):
264262 self .ffn : FFN = FFN (** ffn_kwargs )
265263
266264 def forward (self , data : Dict [str , Tensor ], ** kwargs : Any ) -> Dict [str , Any ]:
267- logits_list = []
268- for name , model in self .models .items ():
269- output = model (data )
270- logits_list .append (output ["logits" ])
271-
265+ logits_list = [model (data )["logits" ] for model in self .models .values ()]
272266 return self .ffn ({"features" : torch .cat (logits_list , dim = 1 )})
273267
274268 def _get_prediction_and_labels (
0 commit comments