
Commit 82a96dc

ensemble: changes for out of scope labels for certain models

1 parent: 72a6b37

File tree: 1 file changed (+83 −12 lines)

chebai/models/ensemble.py

Lines changed: 83 additions & 12 deletions
@@ -61,7 +61,9 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs):
     def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
         path_set = set()
         class_set = set()
+        labels_set = set()

+        sets_ = {"path": path_set, "class": class_set, "labels": labels_set}
         required_keys = {"class_path", "ckpt_path"}

         for model_name, config in model_configs.items():
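
Note: the `sets_` registry introduced above gives the subclass validation hooks a shared place to enforce cross-model uniqueness. A minimal standalone sketch of that pattern (the helper name is illustrative, not from the repository):

def register_unique(seen: set, value: str, kind: str) -> None:
    # Add value to the registry, raising if another model already claimed it.
    if value in seen:
        raise ValueError(f"Duplicate {kind} detected: '{value}'.")
    seen.add(value)

sets_ = {"path": set(), "class": set(), "labels": set()}
register_unique(sets_["labels"], "labels_a.txt", "labels path")
# register_unique(sets_["labels"], "labels_a.txt", "labels path")  # would raise ValueError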
@@ -88,10 +90,12 @@ def _validate_model_configs(cls, model_configs: Dict[str, Dict]):
             path_set.add(model_path)
             class_set.add(class_path)

-            cls._extra_validation(model_name, config)
+            cls._extra_validation(model_name, config, sets_)

     @classmethod
-    def _extra_validation(cls, model_name: str, config: Dict[str, Any]):
+    def _extra_validation(
+        cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set]
+    ):
         pass

     @abstractmethod
@@ -110,9 +114,23 @@ def __init__(self, model_configs: Dict[str, Dict], **kwargs):

         # Add a dummy trainable parameter
         self.dummy_param = torch.nn.Parameter(torch.randn(1, requires_grad=True))
+        self._num_models_per_label: Optional[torch.Tensor] = None
+        self._generate_model_label_mask()

     @classmethod
-    def _extra_validation(cls, model_name: str, config: Dict[str, Any]):
+    def _extra_validation(
+        cls, model_name: str, config: Dict[str, Any], sets_: Dict[str, set]
+    ):
+
+        if "labels_path" not in config:
+            raise AttributeError("Missing 'labels_path' key in config!")
+
+        labels_path = config["labels_path"]
+        # if labels_path in sets_["labels"]:
+        #     raise ValueError(
+        #         f"Duplicate labels path detected: '{labels_path}'. Each model must have a unique path."
+        #     )
+        sets_["labels"].add(labels_path)

         if "TPV" not in config.keys() or "FPV" not in config.keys():
             raise AttributeError(
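
For reference, a per-model entry that passes these checks might look like the following; every path and value here is illustrative, not taken from the repository. The labels file is read later as one integer label per line.

model_configs = {
    "model_a": {
        "class_path": "chebai.models.SomeModel",   # hypothetical class path
        "ckpt_path": "checkpoints/model_a.ckpt",   # hypothetical checkpoint
        "labels_path": "data/model_a_labels.txt",  # one integer label per line
        "TPV": 0.9,  # must be a float or convertible to float
        "FPV": 0.1,  # must be a float or convertible to float
    },
}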
@@ -132,19 +150,62 @@ def _extra_validation(cls, model_name: str, config: Dict[str, Any]):
                     f"'{key}' in model '{model_name}' must be a float or convertible to float, but got {config[key]}."
                 )

+    def _generate_model_label_mask(self):
+        labels_dict = {}
+        num_models_per_label = torch.zeros(1, self.out_dim, device=self.device)
+        for model_name, model_config in self.model_configs.items():
+            labels_path = model_config["labels_path"]
+            if not os.path.exists(labels_path):
+                raise FileNotFoundError(f"Labels path '{labels_path}' does not exist.")
+
+            with open(labels_path, "r") as f:
+                labels_list = [int(line.strip()) for line in f]
+
+            model_label_indices = []
+            for label in labels_list:
+                if label not in labels_dict:
+                    labels_dict[label] = len(labels_dict)
+
+                model_label_indices.append(labels_dict[label])
+
+            # Create masks to apply predictions only to known classes
+            mask = torch.zeros(self.out_dim, device=self.device, dtype=torch.bool)
+            mask[
+                torch.tensor(model_label_indices, dtype=torch.int, device=self.device)
+            ] = True
+
+            self.model_configs[model_name]["labels_mask"] = mask
+            num_models_per_label += mask
+
+        self._num_models_per_label = num_models_per_label
+
     def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
         predictions = {}
         confidences = {}
+
+        assert data["labels"].shape[1] == self.out_dim
+
+        # Initialize total_logits with zeros
         total_logits = torch.zeros(
-            data["labels"].shape[0], data["labels"].shape[1], device=self.device
+            data["labels"].shape[0], self.out_dim, device=self.device
         )

         for name, model in self.models.items():
             output = model(data)
+            mask = self.model_configs[name]["labels_mask"]
+
+            # Consider logits and confidence only for valid classes
             sigmoid_logits = torch.sigmoid(output["logits"])
-            confidences[name] = sigmoid_logits
-            predictions[name] = (sigmoid_logits > 0.5).long()
-            total_logits += output["logits"]
+            prediction = torch.full_like(total_logits, -1, dtype=torch.bool)
+            confidence = torch.full_like(total_logits, -1, dtype=torch.float)
+            prediction[:, mask] = sigmoid_logits > 0.5
+            confidence[:, mask] = sigmoid_logits
+
+            predictions[name] = prediction
+            confidences[name] = confidence
+            total_logits += output[
+                "logits"
+            ]  # Doesn't play a role here; kept for Lightning flow completeness

         return {
             "logits": total_logits,
@@ -250,15 +311,25 @@ def aggregate_predictions(self, predictions, confidences):
         true_scores = torch.zeros(batch_size, num_classes, device=self.device)
         false_scores = torch.zeros(batch_size, num_classes, device=self.device)

-        for model, preds in predictions.items():
+        for model, conf in confidences.items():
             tpv = float(self.model_configs[model]["TPV"])
             npv = float(self.model_configs[model]["FPV"])
-            weight = confidences[model] * (tpv * preds + npv * (1 - preds))

-            true_scores += weight * preds
-            false_scores += weight * (1 - preds)
+            # Determine which classes the model provides predictions for
+            mask = self.model_configs[model]["labels_mask"]
+            weight = conf * (tpv * conf + npv * (1 - conf))
+
+            # Apply mask: only update scores for valid classes
+            true_scores += weight * conf * mask
+            false_scores += weight * (1 - conf) * mask
+
+        # Avoid division by zero: set valid_counts to 1 where it is zero
+        valid_counts = self._num_models_per_label.clamp(min=1)
+
+        # Normalize by valid contributions to prevent bias; this step can be optional depending on the scenario
+        final_preds = (true_scores / valid_counts) > (false_scores / valid_counts)

-        return (true_scores > false_scores).long()
+        return final_preds

     def _process_for_loss(
         self,
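
To see the masked vote end to end, here is a toy run of the aggregation above with two models over three classes; all confidences, masks, and TPV/FPV values are made up:

import torch

# Per-model sigmoid confidences (batch of 1); -1.0 marks uncovered classes.
conf = {
    "model_a": torch.tensor([[0.9, 0.2, -1.0]]),
    "model_b": torch.tensor([[-1.0, 0.3, 0.6]]),
}
masks = {
    "model_a": torch.tensor([True, True, False]),
    "model_b": torch.tensor([False, True, True]),
}
tpv_fpv = {"model_a": (0.9, 0.1), "model_b": (0.8, 0.2)}  # illustrative weights

true_scores = torch.zeros(1, 3)
false_scores = torch.zeros(1, 3)
for name, c in conf.items():
    tpv, fpv = tpv_fpv[name]
    weight = c * (tpv * c + fpv * (1 - c))
    true_scores += weight * c * masks[name]        # mask zeroes uncovered classes
    false_scores += weight * (1 - c) * masks[name]

valid_counts = torch.tensor([[1.0, 2.0, 1.0]])  # models covering each class
final_preds = (true_scores / valid_counts) > (false_scores / valid_counts)
print(final_preds)  # tensor([[ True, False,  True]])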
