Skip to content

Commit bc6e131

Browse files
committed
add consolidator
1 parent 37d46f7 commit bc6e131

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

chebai/ensemble/consolidator.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from abc import ABC
2+
3+
from chebai.ensemble.base import EnsembleBase
4+
5+
6+
class WeightedMajorityVoting(EnsembleBase, ABC):
7+
def _consolidator(
8+
self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs
9+
):
10+
tpv = model_props["tpv_tensor"]
11+
npv = model_props["fpv_tensor"]
12+
conf = pred_conf_dict["confidence"]
13+
14+
# Determine which classes the model provides predictions for
15+
mask = model_props["mask"]
16+
weight = conf * (tpv * conf + npv * (1 - conf))
17+
18+
# Apply mask: Only update scores for valid classes
19+
true_scores += weight * conf * mask
20+
false_scores += weight * (1 - conf) * mask
21+
22+
def _consolidate_on_finish(self, *, true_scores, false_scores):
23+
# Avoid division by zero: Set valid_counts to 1 where it's zero
24+
valid_counts = self._num_models_per_label.clamp(min=1)
25+
26+
# Normalize by valid contributions to prevent bias
27+
final_preds = (true_scores / valid_counts) > (false_scores / valid_counts)
28+
return final_preds
29+
30+
31+
class MajorityVoting(EnsembleBase, ABC):
32+
def _consolidator(
33+
self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs
34+
):
35+
conf = pred_conf_dict["confidence"]
36+
37+
# Determine which classes the model provides predictions for
38+
mask = model_props["mask"]
39+
# Apply mask: Only update scores for valid classes
40+
true_scores += conf * mask
41+
false_scores += (1 - conf) * mask
42+
43+
def _consolidate_on_finish(self, *, true_scores, false_scores):
44+
return true_scores > false_scores

0 commit comments

Comments
 (0)