ignite/contrib/metrics/ExpectedCalibrationError.py (57 additions, 0 deletions)
import torch
Collaborator:
First, let's put it into ignite/metrics/ExpectedCalibrationError.py instead of ignite/contrib/metrics/ExpectedCalibrationError.py


from ignite.exceptions import NotComputableError
from ignite.metrics import Metric


class ExpectedCalibrationError(Metric):
    def __init__(self, num_bins=10, device=None):
        super(ExpectedCalibrationError, self).__init__()
        self.num_bins = num_bins
        self.device = device
        self.reset()

    def reset(self):
        self.confidences = torch.tensor([], device=self.device)
        self.corrects = torch.tensor([], device=self.device)

    def update(self, output):
        y_pred, y = output
Collaborator:
We usually call .detach on both to stop grad computation like here:

y_pred, y = output[0].detach(), output[1].detach()


        assert y_pred.dim() == 2 and y_pred.shape[1] == 2, "This metric is for binary classification."
Collaborator:
Let's use the following way to raise errors instead of assert:

if not (y_pred.dim() == 2 and y_pred.shape[1] == 2):
    raise ValueError("This metric is for binary classification")

To check that the input is binary, we previously did something like this:

def _check_binary_multilabel_cases(self, output: Sequence[torch.Tensor]) -> None:
    y_pred, y = output
    if not torch.equal(y, y**2):
        raise ValueError("For binary cases, y must be comprised of 0's and 1's.")
    if not torch.equal(y_pred, y_pred**2):
        raise ValueError("For binary cases, y_pred must be comprised of 0's and 1's.")

def _check_type(self, output: Sequence[torch.Tensor]) -> None:
    y_pred, y = output
    if y.ndimension() + 1 == y_pred.ndimension():
        num_classes = y_pred.shape[1]
        if num_classes == 1:
            update_type = "binary"
            self._check_binary_multilabel_cases((y_pred, y))
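
One way that pattern could be applied to this metric (just a sketch; the helper name _check_output is illustrative and not part of the PR) is a raise-based check at the top of update:

import torch

def _check_output(y_pred: torch.Tensor, y: torch.Tensor) -> None:
    # Raise instead of assert so the check survives python -O.
    if not (y_pred.dim() == 2 and y_pred.shape[1] == 2):
        raise ValueError("This metric is for binary classification; expected y_pred of shape (N, 2).")
    if not torch.equal(y, y**2):
        raise ValueError("For binary cases, y must be comprised of 0's and 1's.")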


        softmax_probs = torch.softmax(y_pred, dim=1)
        max_probs, predicted_class = torch.max(softmax_probs, dim=1)

        self.confidences = torch.cat((self.confidences, max_probs))
        # Cast the boolean comparison to float so it can be concatenated with the
        # float buffer and later averaged with torch.mean.
        self.corrects = torch.cat((self.corrects, (predicted_class == y).float()))

    def compute(self):
        if self.confidences.numel() == 0:
            raise NotComputableError(
                "ExpectedCalibrationError must have at least one example before it can be computed."
            )

        bin_edges = torch.linspace(0, 1, self.num_bins + 1, device=self.device)

        # searchsorted returns indices in [1, num_bins] for confidences in (0, 1];
        # shift by one so bin i covers the interval (bin_edges[i], bin_edges[i + 1]].
        bin_indices = torch.searchsorted(bin_edges, self.confidences) - 1
        bin_indices = torch.clamp(bin_indices, min=0, max=self.num_bins - 1)

        ece = 0.0
        bin_sizes = torch.zeros(self.num_bins, device=self.device)
        bin_accuracies = torch.zeros(self.num_bins, device=self.device)

        for i in range(self.num_bins):
            mask = bin_indices == i
            bin_confidences = self.confidences[mask]
            bin_corrects = self.corrects[mask]

            bin_size = bin_confidences.numel()
            if bin_size == 0:
                # Empty bins contribute nothing; skipping them avoids NaNs from
                # taking the mean of an empty tensor.
                continue

            accuracy = torch.mean(bin_corrects)
            avg_confidence = torch.mean(bin_confidences)

            # Accumulate the |accuracy - confidence| gap, weighted by the bin's share of samples.
            ece += (bin_size / len(self.confidences)) * abs(accuracy - avg_confidence)
            bin_sizes[i] = bin_size
            bin_accuracies[i] = accuracy

        return ece
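
For reference, a minimal standalone usage sketch (assuming the file is importable from the path above; the tensors and shapes below are made up for illustration):

import torch

from ignite.contrib.metrics.ExpectedCalibrationError import ExpectedCalibrationError

ece_metric = ExpectedCalibrationError(num_bins=10)
ece_metric.reset()

# Fake batch: binary logits of shape (N, 2) and integer targets in {0, 1}.
logits = torch.randn(32, 2)
targets = torch.randint(0, 2, (32,))

ece_metric.update((logits, targets))
print(ece_metric.compute())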