|
| 1 | +from abc import ABC |
| 2 | + |
| 3 | +from chebai.ensemble.base import EnsembleBase |
| 4 | + |
| 5 | + |
| 6 | +class WeightedMajorityVoting(EnsembleBase, ABC): |
| 7 | + def _consolidator( |
| 8 | + self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs |
| 9 | + ): |
| 10 | + tpv = model_props["tpv_tensor"] |
| 11 | + npv = model_props["fpv_tensor"] |
| 12 | + conf = pred_conf_dict["confidence"] |
| 13 | + |
| 14 | + # Determine which classes the model provides predictions for |
| 15 | + mask = model_props["mask"] |
| 16 | + weight = conf * (tpv * conf + npv * (1 - conf)) |
| 17 | + |
| 18 | + # Apply mask: Only update scores for valid classes |
| 19 | + true_scores += weight * conf * mask |
| 20 | + false_scores += weight * (1 - conf) * mask |
| 21 | + |
| 22 | + def _consolidate_on_finish(self, *, true_scores, false_scores): |
| 23 | + # Avoid division by zero: Set valid_counts to 1 where it's zero |
| 24 | + valid_counts = self._num_models_per_label.clamp(min=1) |
| 25 | + |
| 26 | + # Normalize by valid contributions to prevent bias |
| 27 | + final_preds = (true_scores / valid_counts) > (false_scores / valid_counts) |
| 28 | + return final_preds |
| 29 | + |
| 30 | + |
| 31 | +class MajorityVoting(EnsembleBase, ABC): |
| 32 | + def _consolidator( |
| 33 | + self, pred_conf_dict, model_props, *, true_scores, false_scores, **kwargs |
| 34 | + ): |
| 35 | + conf = pred_conf_dict["confidence"] |
| 36 | + |
| 37 | + # Determine which classes the model provides predictions for |
| 38 | + mask = model_props["mask"] |
| 39 | + # Apply mask: Only update scores for valid classes |
| 40 | + true_scores += conf * mask |
| 41 | + false_scores += (1 - conf) * mask |
| 42 | + |
| 43 | + def _consolidate_on_finish(self, *, true_scores, false_scores): |
| 44 | + return true_scores > false_scores |
0 commit comments