@@ -61,7 +61,9 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs):
     def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
         path_set = set()
         class_set = set()
+        labels_set = set()

+        sets_ = {"path": path_set, "class": class_set, "labels": labels_set}
         required_keys = {"class_path", "ckpt_path"}

         for model_name, config in model_configs.items():
@@ -88,10 +90,12 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
             path_set.add(model_path)
             class_set.add(class_path)

-            cls._extra_validation(model_name, config)
+            cls._extra_validation(model_name, config, sets_)

     @classmethod
-    def _extra_validation(cls, model_name: str, config: Dict[str, Any]):
+    def _extra_validation(
+        cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set]
+    ):
         pass

     @abstractmethod
@@ -110,9 +114,23 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs):

         # Add a dummy trainable parameter
         self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True))
+        self._num_models_per_label: Optional[torch.Tensor] = None
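+        # Build per-model label masks from each model's labels file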
+        self._generate_model_label_mask()

     @classmethod
-    def _extra_validation(cls, model_name: str, config: Dict[str, Any]):
+    def _extra_validation(
+        cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set]
+    ):
+
+        if "labels_path" not in config:
+            raise AttributeError("Missing 'labels_path' key in config!")
+
+        labels_path = config["labels_path"]
+        # if labels_path not in sets_["labels"]:
+        #     raise ValueError(
+        #         f"Duplicate labels path detected: '{labels_path}'. Each model must have a unique path."
+        #     )
+        sets_["labels"].add(labels_path)

         if "TPV" not in config.keys() or "FPV" not in config.keys():
             raise AttributeError(
@@ -132,19 +150,62 @@ def _extra_validation(cls, model_name: str, config: Dict[str, Any]):
132150 f"'{ key } ' in model '{ model_name } ' must be a float or convertible to float, but got { config [key ]} ."
133151 )
134152
+    def _generate_model_label_mask(self):
+        labels_dict = {}
+        num_models_per_label = torch.zeros(1, self.out_dim, device=self.device)
+        for model_name, model_config in self.model_configs.items():
+            labels_path = model_config["labels_path"]
+            if not os.path.exists(labels_path):
+                raise FileNotFoundError(f"Labels path '{labels_path}' does not exist.")
+
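+            # Each labels file is expected to hold one integer class id per line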
+            with open(labels_path, "r") as f:
+                labels_list = [int(line.strip()) for line in f]
+
+            model_label_indices = []
+            for label in labels_list:
+                if label not in labels_dict:
+                    labels_dict[label] = len(labels_dict)
+
+                model_label_indices.append(labels_dict[label])
+
+            # Create masks to apply predictions only to known classes
+            mask = torch.zeros(self.out_dim, device=self.device, dtype=torch.bool)
+            mask[
+                torch.tensor(model_label_indices, dtype=torch.int, device=self.device)
+            ] = True
+
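+            # Cache the mask per model and count how many models cover each label (used later to normalize aggregated scores)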
+            self.model_configs[model_name]["labels_mask"] = mask
+            num_models_per_label += mask
+
+        self._num_models_per_label = num_models_per_label
+
     def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
         predictions = {}
         confidences = {}
+
+        assert data["labels"].shape[1] == self.out_dim
+
+        # Initialize total_logits with zeros
         total_logits = torch.zeros(
-            data["labels"].shape[0], data["labels"].shape[1], device=self.device
+            data["labels"].shape[0], self.out_dim, device=self.device
         )

         for name, model in self.models.items():
             output = model(data)
+            mask = self.model_configs[name]["labels_mask"]
+
+            # Consider logits and confidence only for valid classes
             sigmoid_logits = torch.sigmoid(output["logits"])
-            confidences[name] = sigmoid_logits
-            predictions[name] = (sigmoid_logits > 0.5).long()
-            total_logits += output["logits"]
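+            # Allocate full-width tensors and fill only the classes covered by this model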
+            prediction = torch.full_like(total_logits, -1, dtype=torch.bool)
+            confidence = torch.full_like(total_logits, -1, dtype=torch.float)
+            prediction[:, mask] = sigmoid_logits > 0.5
+            confidence[:, mask] = sigmoid_logits
+
+            predictions[name] = prediction
+            confidences[name] = confidence
+            total_logits += output[
+                "logits"
+            ]  # Not used for the ensemble decision; kept so the Lightning flow stays intact

         return {
             "logits": total_logits,
@@ -250,15 +311,25 @@ def aggregate_predictions(self, predictions, confidences):
         true_scores = torch.zeros(batch_size, num_classes, device=self.device)
         false_scores = torch.zeros(batch_size, num_classes, device=self.device)

-        for model, preds in predictions.items():
+        for model, conf in confidences.items():
             tpv = float(self.model_configs[model]["TPV"])
             npv = float(self.model_configs[model]["FPV"])
-            weight = confidences[model] * (tpv * preds + npv * (1 - preds))

-            true_scores += weight * preds
-            false_scores += weight * (1 - preds)
+            # Determine which classes the model provides predictions for
+            mask = self.model_configs[model]["labels_mask"]
+            weight = conf * (tpv * conf + npv * (1 - conf))
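+            # Weight each model's vote by its confidence, scaled by the TPV/FPV reliability values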
+
+            # Apply mask: only update scores for classes this model covers
+            true_scores += weight * conf * mask
+            false_scores += weight * (1 - conf) * mask
+
+        # Avoid division by zero: clamp valid_counts to at least 1
+        valid_counts = self._num_models_per_label.clamp(min=1)
+
+        # Normalize by the number of contributing models to avoid bias; this step can be optional depending on the scenario
+        final_preds = (true_scores / valid_counts) > (false_scores / valid_counts)

-        return (true_scores > false_scores).long()
+        return final_preds

     def _process_for_loss(
         self,