diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py
index 3e41f565879..f9e0ab658a4 100644
--- a/src/torchmetrics/classification/__init__.py
+++ b/src/torchmetrics/classification/__init__.py
@@ -19,6 +19,12 @@
     MulticlassAveragePrecision,
     MultilabelAveragePrecision,
 )
+from torchmetrics.classification.brier import (
+    BinaryBrier,
+    Brier,
+    MulticlassBrier,
+    MultilabelBrier,
+)
 from torchmetrics.classification.calibration_error import (
     BinaryCalibrationError,
     CalibrationError,
diff --git a/src/torchmetrics/classification/brier.py b/src/torchmetrics/classification/brier.py
new file mode 100644
index 00000000000..ff31c46f5ba
--- /dev/null
+++ b/src/torchmetrics/classification/brier.py
@@ -0,0 +1,565 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Optional
+
+import torch
+from torch import Tensor
+from typing_extensions import Literal
+
+from torchmetrics.classification.base import _ClassificationTaskWrapper
+from torchmetrics.functional.classification.brier import (
+    _binary_brier_format,
+    _mean_brier_score_and_decomposition,
+    _multiclass_brier_format,
+)
+from torchmetrics.functional.classification.confusion_matrix import (
+    _binary_confusion_matrix_arg_validation,
+    _binary_confusion_matrix_format,
+    _binary_confusion_matrix_tensor_validation,
+    _binary_confusion_matrix_update,
+    _multiclass_confusion_matrix_arg_validation,
+    _multiclass_confusion_matrix_format,
+    _multiclass_confusion_matrix_tensor_validation,
+    _multiclass_confusion_matrix_update,
+    _multilabel_confusion_matrix_arg_validation,
+    _multilabel_confusion_matrix_format,
+    _multilabel_confusion_matrix_tensor_validation,
+    _multilabel_confusion_matrix_update,
+)
+from torchmetrics.metric import Metric
+from torchmetrics.utilities.compute import normalize_logits_if_needed
+from torchmetrics.utilities.data import dim_zero_cat
+from torchmetrics.utilities.enums import ClassificationTask
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _CMAP_TYPE, _PLOT_OUT_TYPE, plot_confusion_matrix
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = [
+        "BinaryBrier.plot",
+        "MulticlassBrier.plot",
+        "MultilabelBrier.plot",
+    ]
+
+
+class BinaryBrier(Metric):
+    r"""Compute the Brier score and its calibration decomposition for binary tasks.
+
+    The Brier score is a strictly proper scoring rule for probabilistic predictions. With predicted class
+    probabilities :math:`p_{i,k}` and one-hot encoded targets :math:`y_{i,k}`, it is defined as:
+
+    .. math::
+        \text{BS} = \frac{1}{N} \sum_{i=1}^{N} \sum_{k} \left(p_{i,k} - y_{i,k}\right)^2
+
+    For binary inputs the two-column probabilities :math:`[1 - p, p]` are constructed internally, so each
+    sample contributes :math:`2 (p_i - y_i)^2`. In addition to the mean score, the metric reports the
+    decomposition ``Score = Uncertainty - Resolution + Reliability``, which is estimated from a confusion
+    matrix accumulated alongside the predictions.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating
+      point tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply
+      sigmoid per element.
+    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
+
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``brier`` (:class:`dict`): A dictionary with the scalar tensors ``MeanBrier``, ``Uncertainty``,
+      ``Resolution`` and ``Reliability``.
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        threshold: Threshold for transforming probabilities to binary (0,1) predictions for the internal
+            confusion matrix
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        normalize: Normalization mode for the internal confusion matrix. Choose from:
+
+            - ``None`` or ``'none'``: no normalization (default)
+            - ``'true'``: normalization over the targets (most commonly used)
+            - ``'pred'``: normalization over the predictions
+            - ``'all'``: normalization over the whole matrix
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example (preds is float tensor):
+        >>> from torch import tensor
+        >>> from torchmetrics.classification import BinaryBrier
+        >>> target = tensor([1, 1, 0, 0])
+        >>> preds = tensor([0.35, 0.85, 0.48, 0.01])
+        >>> metric = BinaryBrier()
+        >>> sorted(metric(preds, target))
+        ['MeanBrier', 'Reliability', 'Resolution', 'Uncertainty']
+
+    """
+
+    is_differentiable: bool = False
+    higher_is_better: Optional[bool] = None
+    full_state_update: bool = False
+
+    confmat: Tensor
+
+    def __init__(
+        self,
+        threshold: float = 0.5,
+        ignore_index: Optional[int] = None,
+        normalize: Optional[Literal["true", "pred", "all", "none"]] = None,
+        validate_args: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        if validate_args:
+            _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize)
+        self.threshold = threshold
+        self.ignore_index = ignore_index
+        self.normalize = normalize
+        self.validate_args = validate_args
+
+        self.add_state("confmat", torch.zeros(2, 2, dtype=torch.long), dist_reduce_fx="sum")
+        self.add_state("preds", default=[], dist_reduce_fx="cat")
+        self.add_state("target", default=[], dist_reduce_fx="cat")
+
+    def update(self, preds: Tensor, target: Tensor) -> None:
+        """Update state with predictions and targets."""
+        if self.validate_args:
+            _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index)
+
+        preds_conf, target_conf = _binary_confusion_matrix_format(preds, target, self.threshold, self.ignore_index)
+        confmat = _binary_confusion_matrix_update(preds_conf, target_conf)
+
+        self.confmat += confmat
+
+        preds_br, target_br = _binary_brier_format(preds, target, self.ignore_index)
+        self.preds.append(preds_br)
+        self.target.append(target_br)
+
+    def compute(self) -> dict[str, Tensor]:
+        """Compute the mean Brier score and its uncertainty/resolution/reliability decomposition."""
+        return _mean_brier_score_and_decomposition(
+            dim_zero_cat(self.target), dim_zero_cat(self.preds), self.confmat.float()
+        )
+
+    def plot(
+        self,
+        val: Optional[Tensor] = None,
+        ax: Optional[_AX_TYPE] = None,
+        add_text: bool = True,
+        labels: Optional[list[str]] = None,
+        cmap: Optional[_CMAP_TYPE] = None,
+    ) -> _PLOT_OUT_TYPE:
+        """Plot the confusion matrix accumulated internally for the Brier decomposition.
+
+        Args:
+            val: A single confusion-matrix tensor to plot. If no value is provided, the internally
+                accumulated confusion matrix is plotted.
+            ax: A matplotlib axis object. If provided, will add plot to that axis
+            add_text: if the value of each cell should be added to the plot
+            labels: a list of strings, if provided will be added to the plot to indicate the different classes
+            cmap: matplotlib colormap to use for the confusion matrix
+                https://matplotlib.org/stable/users/explain/colors/colormaps.html
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> from torch import rand, randint
+            >>> from torchmetrics.classification import BinaryBrier
+            >>> metric = BinaryBrier()
+            >>> metric.update(rand(20), randint(2, (20,)))
+            >>> fig_, ax_ = metric.plot()
+
+        """
+        # ``compute`` returns a dict, so the plot defaults to the accumulated confusion matrix instead.
+        val = val if val is not None else self.confmat
+        if not isinstance(val, Tensor):
+            raise TypeError(f"Expected val to be a single tensor but got {val}")
+        fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
+        return fig, ax
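Since the doctests above only surface the dictionary keys, a quick manual check of the two-column binary convention may help reviewers. A minimal sketch (illustrative only, not part of the diff), assuming the ``MeanBrier`` entry returned by ``compute``:

```python
import torch
from torchmetrics.classification import BinaryBrier

target = torch.tensor([0, 1, 0, 1])
preds = torch.tensor([0.10, 0.80, 0.45, 0.95])

result = BinaryBrier()(preds, target)

# With probabilities stacked as [1 - p, p], each sample contributes
# (p0 - (1 - y))^2 + (p1 - y)^2 = 2 * (p - y)^2.
manual = (2 * (preds - target.float()) ** 2).mean()
torch.testing.assert_close(result["MeanBrier"], manual)
```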
+
+
+class MulticlassBrier(Metric):
+    r"""Compute the Brier score and its calibration decomposition for multiclass tasks.
+
+    The Brier score is a strictly proper scoring rule for probabilistic predictions. With predicted class
+    probabilities :math:`p_{i,k}` and one-hot encoded targets :math:`y_{i,k}`, it is defined as:
+
+    .. math::
+        \text{BS} = \frac{1}{N} \sum_{i=1}^{N} \sum_{k=1}^{C} \left(p_{i,k} - y_{i,k}\right)^2
+
+    In addition to the mean score, the metric reports the decomposition
+    ``Score = Uncertainty - Resolution + Reliability``, which is estimated from a confusion matrix
+    accumulated alongside the predictions.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` with probabilities or logits,
+      or an int tensor of shape ``(N, ...)`` with predicted labels. If preds is a floating point tensor with
+      values outside the [0,1] range, we consider the input to be logits and will auto apply softmax per
+      sample. Integer label predictions are converted to one-hot probabilities.
+    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``.
+
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``brier`` (:class:`dict`): A dictionary with the scalar tensors ``MeanBrier``, ``Uncertainty``,
+      ``Resolution`` and ``Reliability``.
+
+    Args:
+        num_classes: Integer specifying the number of classes
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        normalize: Normalization mode for the internal confusion matrix. Choose from:
+
+            - ``None`` or ``'none'``: no normalization (default)
+            - ``'true'``: normalization over the targets (most commonly used)
+            - ``'pred'``: normalization over the predictions
+            - ``'all'``: normalization over the whole matrix
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example (preds is float tensor):
+        >>> from torch import tensor
+        >>> from torchmetrics.classification import MulticlassBrier
+        >>> target = tensor([2, 1, 0, 0])
+        >>> preds = tensor([[0.16, 0.26, 0.58],
+        ...                 [0.22, 0.61, 0.17],
+        ...                 [0.71, 0.09, 0.20],
+        ...                 [0.05, 0.82, 0.13]])
+        >>> metric = MulticlassBrier(num_classes=3)
+        >>> sorted(metric(preds, target))
+        ['MeanBrier', 'Reliability', 'Resolution', 'Uncertainty']
+
+    """
+
+    is_differentiable: bool = False
+    higher_is_better: Optional[bool] = None
+    full_state_update: bool = False
+
+    confmat: Tensor
+
+    def __init__(
+        self,
+        num_classes: int,
+        ignore_index: Optional[int] = None,
+        normalize: Optional[Literal["none", "true", "pred", "all"]] = None,
+        validate_args: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        if validate_args:
+            _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize)
+        self.num_classes = num_classes
+        self.ignore_index = ignore_index
+        self.normalize = normalize
+        self.validate_args = validate_args
+
+        self.add_state("confmat", torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum")
+        self.add_state("preds", default=[], dist_reduce_fx="cat")
+        self.add_state("target", default=[], dist_reduce_fx="cat")
+
+    def update(self, preds: Tensor, target: Tensor) -> None:
+        """Update state with predictions and targets."""
+        if self.validate_args:
+            _multiclass_confusion_matrix_tensor_validation(preds, target, self.num_classes, self.ignore_index)
+        preds_conf, target_conf = _multiclass_confusion_matrix_format(preds, target, self.ignore_index)
+        confmat = _multiclass_confusion_matrix_update(preds_conf, target_conf, self.num_classes)
+        self.confmat += confmat
+
+        preds_br, target_br = _multiclass_brier_format(preds, target, self.num_classes, self.ignore_index)
+        self.preds.append(preds_br)
+        self.target.append(target_br)
+
+    def compute(self) -> dict[str, Tensor]:
+        """Compute the mean Brier score and its uncertainty/resolution/reliability decomposition."""
+        return _mean_brier_score_and_decomposition(
+            dim_zero_cat(self.target), dim_zero_cat(self.preds), self.confmat.float()
+        )
+
+    def plot(
+        self,
+        val: Optional[Tensor] = None,
+        ax: Optional[_AX_TYPE] = None,
+        add_text: bool = True,
+        labels: Optional[list[str]] = None,
+        cmap: Optional[_CMAP_TYPE] = None,
+    ) -> _PLOT_OUT_TYPE:
+        """Plot the confusion matrix accumulated internally for the Brier decomposition.
+
+        Args:
+            val: A single confusion-matrix tensor to plot. If no value is provided, the internally
+                accumulated confusion matrix is plotted.
+            ax: A matplotlib axis object. If provided, will add plot to that axis
+            add_text: if the value of each cell should be added to the plot
+            labels: a list of strings, if provided will be added to the plot to indicate the different classes
+            cmap: matplotlib colormap to use for the confusion matrix
+                https://matplotlib.org/stable/users/explain/colors/colormaps.html
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> from torch import randint, randn
+            >>> from torchmetrics.classification import MulticlassBrier
+            >>> metric = MulticlassBrier(num_classes=5)
+            >>> metric.update(randn(20, 5).softmax(dim=-1), randint(5, (20,)))
+            >>> fig_, ax_ = metric.plot()
+
+        """
+        # ``compute`` returns a dict, so the plot defaults to the accumulated confusion matrix instead.
+        val = val if val is not None else self.confmat
+        if not isinstance(val, Tensor):
+            raise TypeError(f"Expected val to be a single tensor but got {val}")
+        fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
+        return fig, ax
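For the multiclass variant the same manual cross-check can be written directly against the one-hot definition of the score. A minimal sketch (illustrative only, not part of the diff):

```python
import torch
import torch.nn.functional as F
from torchmetrics.classification import MulticlassBrier

target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([[0.16, 0.26, 0.58],
                      [0.22, 0.61, 0.17],
                      [0.71, 0.09, 0.20],
                      [0.05, 0.82, 0.13]])

result = MulticlassBrier(num_classes=3)(preds, target)

# Mean over samples of the squared distance between the predicted
# distribution and the one-hot encoded target.
manual = ((preds - F.one_hot(target, num_classes=3).float()) ** 2).sum(dim=-1).mean()
torch.testing.assert_close(result["MeanBrier"], manual)
```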
+
+
+class MultilabelBrier(Metric):
+    r"""Compute the Brier score for multilabel tasks.
+
+    Each label is treated as an independent binary problem: per label the two-column probabilities
+    ``[1 - p, p]`` are formed and the resulting binary Brier scores are averaged over samples and labels. A
+    per-label confusion matrix of shape ``(num_labels, 2, 2)`` is accumulated alongside for inspection via
+    ``plot``; the uncertainty/resolution/reliability decomposition is currently only reported by the binary
+    and multiclass variants.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values
+      outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid per element.
+    - ``target`` (int tensor): ``(N, C, ...)``
+
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``brier`` (:class:`dict`): A dictionary with the scalar tensor ``MeanBrier``.
+
+    Args:
+        num_labels: Integer specifying the number of labels
+        threshold: Threshold for transforming probabilities to binary (0,1) predictions for the internal
+            confusion matrix
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        normalize: Normalization mode for the internal confusion matrix. Choose from:
+
+            - ``None`` or ``'none'``: no normalization (default)
+            - ``'true'``: normalization over the targets (most commonly used)
+            - ``'pred'``: normalization over the predictions
+            - ``'all'``: normalization over the whole matrix
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example (preds is float tensor):
+        >>> from torch import tensor
+        >>> from torchmetrics.classification import MultilabelBrier
+        >>> target = tensor([[0, 1, 0], [1, 0, 1]])
+        >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
+        >>> metric = MultilabelBrier(num_labels=3)
+        >>> sorted(metric(preds, target))
+        ['MeanBrier']
+
+    """
+
+    is_differentiable: bool = False
+    higher_is_better: Optional[bool] = None
+    full_state_update: bool = False
+
+    confmat: Tensor
+
+    def __init__(
+        self,
+        num_labels: int,
+        threshold: float = 0.5,
+        ignore_index: Optional[int] = None,
+        normalize: Optional[Literal["none", "true", "pred", "all"]] = None,
+        validate_args: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        if validate_args:
+            _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize)
+        self.num_labels = num_labels
+        self.threshold = threshold
+        self.ignore_index = ignore_index
+        self.normalize = normalize
+        self.validate_args = validate_args
+
+        self.add_state("confmat", torch.zeros(num_labels, 2, 2, dtype=torch.long), dist_reduce_fx="sum")
+        self.add_state("preds", default=[], dist_reduce_fx="cat")
+        self.add_state("target", default=[], dist_reduce_fx="cat")
+
+    def update(self, preds: Tensor, target: Tensor) -> None:
+        """Update state with predictions and targets."""
+        if self.validate_args:
+            _multilabel_confusion_matrix_tensor_validation(preds, target, self.num_labels, self.ignore_index)
+        preds_conf, target_conf = _multilabel_confusion_matrix_format(
+            preds, target, self.num_labels, self.threshold, self.ignore_index
+        )
+        confmat = _multilabel_confusion_matrix_update(preds_conf, target_conf, self.num_labels)
+        self.confmat += confmat
+
+        # Keep the (sigmoid-normalized) probabilities around for the Brier score itself; the confusion-matrix
+        # format above thresholds the predictions, so it cannot be reused here.
+        probs = normalize_logits_if_needed(preds.float(), "sigmoid")
+        self.preds.append(probs.movedim(1, -1).reshape(-1, self.num_labels))
+        self.target.append(target.movedim(1, -1).reshape(-1, self.num_labels))
+
+    def compute(self) -> dict[str, Tensor]:
+        """Compute the mean multilabel Brier score (average per-label binary Brier score)."""
+        preds = dim_zero_cat(self.preds)
+        target = dim_zero_cat(self.target)
+        # With the two-column convention [1 - p, p], each label cell contributes 2 * (p - y) ** 2.
+        squared = 2 * (preds - target.float()) ** 2
+        if self.ignore_index is not None:
+            squared = squared[target != self.ignore_index]
+        return {"MeanBrier": squared.mean()}
+
+    def plot(
+        self,
+        val: Optional[Tensor] = None,
+        ax: Optional[_AX_TYPE] = None,
+        add_text: bool = True,
+        labels: Optional[list[str]] = None,
+        cmap: Optional[_CMAP_TYPE] = None,
+    ) -> _PLOT_OUT_TYPE:
+        """Plot the per-label confusion matrices accumulated internally.
+
+        Args:
+            val: A single confusion-matrix tensor to plot. If no value is provided, the internally
+                accumulated confusion matrices are plotted.
+            ax: A matplotlib axis object. If provided, will add plot to that axis
+            add_text: if the value of each cell should be added to the plot
+            labels: a list of strings, if provided will be added to the plot to indicate the different classes
+            cmap: matplotlib colormap to use for the confusion matrix
+                https://matplotlib.org/stable/users/explain/colors/colormaps.html
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> from torch import rand, randint
+            >>> from torchmetrics.classification import MultilabelBrier
+            >>> metric = MultilabelBrier(num_labels=3)
+            >>> metric.update(rand(20, 3), randint(2, (20, 3)))
+            >>> fig_, ax_ = metric.plot()
+
+        """
+        # ``compute`` returns a dict, so the plot defaults to the accumulated confusion matrices instead.
+        val = val if val is not None else self.confmat
+        if not isinstance(val, Tensor):
+            raise TypeError(f"Expected val to be a single tensor but got {val}")
+        fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
+        return fig, ax
+
+
+class Brier(_ClassificationTaskWrapper):
+    r"""Compute the Brier score.
+
+    This class is a simple wrapper to get the task specific versions of this metric, which is done by setting
+    the ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation
+    of :class:`~torchmetrics.classification.BinaryBrier`,
+    :class:`~torchmetrics.classification.MulticlassBrier` and
+    :class:`~torchmetrics.classification.MultilabelBrier` for the specific details of how each argument
+    influences the metric, and for examples. The binary and multiclass variants return the full
+    uncertainty/resolution/reliability decomposition; the multilabel variant returns ``MeanBrier`` only.
+
+    Legacy Example:
+        >>> from torch import tensor
+        >>> target = tensor([1, 1, 0, 0])
+        >>> preds = tensor([0.35, 0.85, 0.48, 0.01])
+        >>> brier = Brier(task="binary")
+        >>> sorted(brier(preds, target))
+        ['MeanBrier', 'Reliability', 'Resolution', 'Uncertainty']
+
+        >>> target = tensor([2, 1, 0, 0])
+        >>> preds = tensor([[0.16, 0.26, 0.58],
+        ...                 [0.22, 0.61, 0.17],
+        ...                 [0.71, 0.09, 0.20],
+        ...                 [0.05, 0.82, 0.13]])
+        >>> brier = Brier(task="multiclass", num_classes=3)
+        >>> sorted(brier(preds, target))
+        ['MeanBrier', 'Reliability', 'Resolution', 'Uncertainty']
+
+        >>> target = tensor([[0, 1, 0], [1, 0, 1]])
+        >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
+        >>> brier = Brier(task="multilabel", num_labels=3)
+        >>> sorted(brier(preds, target))
+        ['MeanBrier']
+
+    """
+
+    def __new__(  # type: ignore[misc]
+        cls: type["Brier"],
+        task: Literal["binary", "multiclass", "multilabel"],
+        threshold: float = 0.5,
+        num_classes: Optional[int] = None,
+        num_labels: Optional[int] = None,
+        normalize: Optional[Literal["true", "pred", "all", "none"]] = None,
+        ignore_index: Optional[int] = None,
+        validate_args: bool = True,
+        **kwargs: Any,
+    ) -> Metric:
+        """Initialize task metric."""
+        task = ClassificationTask.from_str(task)
+        kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args})
+        if task == ClassificationTask.BINARY:
+            return BinaryBrier(threshold, **kwargs)
+        if task == ClassificationTask.MULTICLASS:
+            if not isinstance(num_classes, int):
+                raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)}` was passed.")
+            return MulticlassBrier(num_classes, **kwargs)
+        if task == ClassificationTask.MULTILABEL:
+            if not isinstance(num_labels, int):
+                raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)}` was passed.")
+            return MultilabelBrier(num_labels, threshold, **kwargs)
+        raise ValueError(f"Task {task} not supported!")
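Before moving to the functional implementations, it is worth noting how the returned components relate to the score itself. Because the decomposition bins predictions by their argmax class, the identity ``MeanBrier = Uncertainty - Resolution + Reliability`` holds only approximately on finite samples, so the sketch below (illustrative only) prints the gap rather than asserting equality:

```python
import torch
from torchmetrics.classification import Brier

torch.manual_seed(0)
target = torch.randint(3, (200,))
preds = torch.randn(200, 3).softmax(dim=-1)

result = Brier(task="multiclass", num_classes=3)(preds, target)
# Reconstruct the score from its components and compare against the direct mean.
reconstructed = result["Uncertainty"] - result["Resolution"] + result["Reliability"]
print(f"mean score: {result['MeanBrier']:.4f}, reconstructed: {reconstructed:.4f}")
```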
diff --git a/src/torchmetrics/functional/classification/brier.py b/src/torchmetrics/functional/classification/brier.py
new file mode 100644
index 00000000000..3b965c83071
--- /dev/null
+++ b/src/torchmetrics/functional/classification/brier.py
@@ -0,0 +1,202 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+from torch import Tensor
+from typing_extensions import Literal
+
+from torchmetrics.utilities.compute import normalize_logits_if_needed
+
+
+def _brier_decomposition(
+    probabilities: Tensor,
+    confusion_matrix: Tensor,
+) -> tuple[Tensor, Tensor, Tensor]:
+    r"""Decompose the Brier score into uncertainty, resolution, and reliability.
+
+    Proper scoring rules measure the quality of probabilistic predictions, and any proper scoring rule admits
+    a unique decomposition ``Score = Uncertainty - Resolution + Reliability``, where:
+
+    * ``Uncertainty`` is a generalized entropy of the average predictive distribution.
+    * ``Resolution`` is a generalized variance of the individual predictive distributions; it is always
+      non-negative. Differences in predictions reveal information, which is why a larger resolution improves
+      the predictive score.
+    * ``Reliability`` measures the calibration of the predictions against the true frequency of events. It is
+      always non-negative, and a lower value indicates better calibration.
+
+    Args:
+        probabilities: Tensor of shape ``(n, nlabels)`` with the predictive distribution for each of the
+            ``n`` instances.
+        confusion_matrix: Tensor of shape ``(nlabels, nlabels)`` with the confusion matrix accumulated over
+            the same instances, indexed as ``[true, pred]``.
+
+    Returns:
+        A tuple of three scalar tensors with the uncertainty, resolution, and reliability components of the
+        decomposition.
+
+    """
+    n, nlabels = probabilities.shape  # Implicit rank check.
+
+    # The decomposition groups instances by their predicted class, so re-index the matrix as [pred, true].
+    confusion_matrix = confusion_matrix.T
+
+    # Compute pbar, the average label distribution, and the per-bin weights.
+    pred_class = torch.argmax(probabilities, dim=1)
+    dist_weights = confusion_matrix.sum(dim=1)
+    dist_weights /= dist_weights.sum()
+    pbar = confusion_matrix.sum(dim=0)
+    pbar /= pbar.sum()
+
+    # dist_mean[k, :] contains the empirical label distribution for the set M_k of instances predicted as k.
+    # Some outcomes may not realize, corresponding to dist_weights[k] = 0.
+    dist_mean = confusion_matrix / (confusion_matrix.sum(dim=1, keepdim=True) + 1.0e-7)
+
+    # Uncertainty: quadratic entropy of the average label distribution.
+    uncertainty = torch.sum(pbar - pbar**2)
+
+    # Resolution: expected quadratic divergence of the per-bin mean distributions from the average.
+    resolution = (pbar.unsqueeze(1) - dist_mean) ** 2
+    resolution = torch.sum(dist_weights * resolution.sum(dim=1))
+
+    # Reliability: expected quadratic divergence of the predictive distributions from the per-bin truth.
+    prob_true = dist_mean[pred_class]
+    reliability = torch.sum((prob_true - probabilities) ** 2, dim=1)
+    reliability = torch.mean(reliability)
+
+    return uncertainty, resolution, reliability
+
+
+def _mean_brier_score(labels: Tensor, probabilities: Tensor) -> Tensor:
+    """Compute the mean Brier score.
+
+    The Brier score is a proper scoring rule that measures the accuracy of probabilistic predictions. It is
+    the squared distance between the predicted probability distribution and the one-hot encoded outcome,
+    averaged over all elements.
+
+    Args:
+        labels: Tensor of integer labels with shape ``[N1, N2, ...]``.
+        probabilities: Tensor of categorical probabilities with shape ``[N1, N2, ..., M]``.
+
+    Returns:
+        Scalar tensor with the Brier score averaged over all elements.
+
+    """
+    nlabels = probabilities.shape[-1]
+    flat_probabilities = probabilities.view(-1, nlabels)
+    flat_labels = labels.view(-1)
+
+    # Gather the probability assigned to the true label of each element. Since the one-hot target y satisfies
+    # sum(y**2) == 1, the squared distance sum((p - y)**2) expands to sum(p**2) - 2 * p[label] + 1.
+    index = torch.arange(len(flat_labels), device=flat_labels.device)
+    plabel = flat_probabilities[index, flat_labels]
+    out = torch.sum(flat_probabilities**2, dim=-1) - 2 * plabel + 1
+
+    return out.view(labels.shape).mean()
+
+
+def _mean_brier_score_and_decomposition(
+    labels: Tensor,
+    probabilities: Tensor,
+    confusion_matrix: Tensor,
+) -> dict[str, Tensor]:
+    """Compute the mean Brier score together with its uncertainty/resolution/reliability decomposition."""
+    mean_brier = _mean_brier_score(labels, probabilities)
+    uncertainty, resolution, reliability = _brier_decomposition(probabilities, confusion_matrix)
+
+    return {
+        "MeanBrier": mean_brier,
+        "Uncertainty": uncertainty,
+        "Resolution": resolution,
+        "Reliability": reliability,
+    }
+
+
+def _binary_brier_format(
+    preds: Tensor,
+    target: Tensor,
+    ignore_index: Optional[int] = None,
+    normalization: Optional[Literal["sigmoid", "softmax"]] = "sigmoid",
+) -> tuple[Tensor, Tensor]:
+    """Convert all input to the right format.
+
+    - flattens additional dimensions
+    - removes all datapoints that should be ignored
+    - applies sigmoid if the pred tensor is not in the [0,1] range
+    - stacks the predictions into a two-column ``[1 - p, p]`` probability tensor
+
+    """
+    preds = preds.flatten()
+    target = target.flatten()
+    if ignore_index is not None:
+        idx = target != ignore_index
+        preds = preds[idx]
+        target = target[idx]
+
+    preds = normalize_logits_if_needed(preds.float(), normalization)
+
+    # Use ``1.0 - preds`` rather than ``torch.ones(...)`` so the result stays on the input's device.
+    probs_zero_class = 1.0 - preds
+    preds = torch.stack([probs_zero_class, preds], dim=-1)
+
+    return preds, target
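The closed form used in `_mean_brier_score` can be checked against the definition: since the one-hot target satisfies `sum(y**2) == 1`, expanding `sum((p - y)**2)` gives `sum(p**2) - 2 * p[label] + 1`. A quick numerical confirmation (illustrative only):

```python
import torch
import torch.nn.functional as F

probs = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.6, 0.3]])
labels = torch.tensor([0, 2])

# Closed form used by _mean_brier_score, per element.
closed_form = probs.pow(2).sum(dim=-1) - 2 * probs[torch.arange(2), labels] + 1
# Direct one-hot definition of the per-element Brier score.
one_hot_form = ((probs - F.one_hot(labels, num_classes=3).float()) ** 2).sum(dim=-1)
torch.testing.assert_close(closed_form, one_hot_form)
```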
+
+
+def _multiclass_brier_format(
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    ignore_index: Optional[int] = None,
+) -> tuple[Tensor, Tensor]:
+    """Convert all input to the right format.
+
+    - one-hot encodes integer label predictions
+    - flattens additional dimensions
+    - removes all datapoints that should be ignored
+    - applies softmax if the pred tensor is not in the [0,1] range
+
+    """
+    if preds.ndim == target.ndim:
+        # Integer label predictions carry no probability information, so treat them as a degenerate
+        # one-hot distribution (added so that label inputs do not crash the reshape below).
+        preds = torch.nn.functional.one_hot(preds, num_classes)
+        if preds.ndim > 2:
+            preds = preds.movedim(-1, 1)
+    preds = preds.transpose(0, 1).reshape(num_classes, -1).T
+    target = target.flatten()
+
+    if ignore_index is not None:
+        idx = target != ignore_index
+        preds = preds[idx]
+        target = target[idx]
+
+    preds = normalize_logits_if_needed(preds.float(), "softmax")
+
+    return preds, target
diff --git a/tests/unittests/classification/test_brier.py b/tests/unittests/classification/test_brier.py
new file mode 100644
index 00000000000..3e2ab351cba
--- /dev/null
+++ b/tests/unittests/classification/test_brier.py
@@ -0,0 +1,80 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import tensor
+
+from torchmetrics.classification.brier import BinaryBrier, MulticlassBrier
+from unittests._helpers import seed_all
+from unittests._helpers.testers import MetricTester
+
+seed_all(42)
+
+
+class TestBinaryBrier(MetricTester):
+    """Test class for `BinaryBrier` metric."""
+
+    def test_binary_brier(self):
+        """Test class implementation of metric against a manual two-column computation."""
+        target = tensor([0, 1, 0, 1, 0, 1])
+        preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
+        result = BinaryBrier()(preds, target)
+
+        assert set(result) == {"MeanBrier", "Uncertainty", "Resolution", "Reliability"}
+        # With probabilities stacked as [1 - p, p], each sample contributes 2 * (p - y) ** 2.
+        expected = (2 * (preds - target.float()) ** 2).mean()
+        torch.testing.assert_close(result["MeanBrier"], expected)
+
+
+class TestMulticlassBrier(MetricTester):
+    """Test class for `MulticlassBrier` metric."""
+
+    def test_multiclass_brier(self):
+        """Test class implementation of metric against the one-hot definition."""
+        target = tensor([1, 4, 3, 0, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0])
+        preds = tensor([
+            [0.9, 0.05, 0.02, 0.01, 0.02],
+            [0.01, 0.01, 0.01, 0.01, 0.96],
+            [0.85, 0.05, 0.05, 0.02, 0.03],
+            [0.01, 0.01, 0.9, 0.05, 0.03],
+            [0.8, 0.1, 0.05, 0.03, 0.02],
+            [0.01, 0.01, 0.01, 0.9, 0.07],
+            [0.05, 0.1, 0.8, 0.03, 0.02],
+            [0.02, 0.02, 0.02, 0.9, 0.04],
+            [0.01, 0.01, 0.01, 0.02, 0.95],
+            [0.7, 0.2, 0.05, 0.03, 0.02],
+            [0.02, 0.9, 0.05, 0.02, 0.01],
+            [0.1, 0.2, 0.6, 0.05, 0.05],
+            [0.03, 0.03, 0.03, 0.9, 0.01],
+            [0.02, 0.02, 0.02, 0.02, 0.92],
+            [0.75, 0.15, 0.05, 0.03, 0.02],
+            [0.02, 0.85, 0.1, 0.02, 0.01],
+            [0.1, 0.15, 0.7, 0.03, 0.02],
+            [0.02, 0.02, 0.02, 0.92, 0.02],
+            [0.01, 0.01, 0.01, 0.03, 0.94],
+            [0.8, 0.1, 0.05, 0.03, 0.02],
+        ])
+        result = MulticlassBrier(num_classes=5)(preds, target)
+
+        assert set(result) == {"MeanBrier", "Uncertainty", "Resolution", "Reliability"}
+        # Mean squared distance between the predicted distribution and the one-hot target.
+        expected = ((preds - torch.nn.functional.one_hot(target, 5).float()) ** 2).sum(dim=-1).mean()
+        torch.testing.assert_close(result["MeanBrier"], expected)
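A natural follow-up for these tests is a reference comparison against scikit-learn, whose `brier_score_loss` is binary-only and uses the single-probability convention, i.e. half of the two-column score computed here. A possible check, sketched under those assumptions (illustrative only, not part of the diff):

```python
import torch
from sklearn.metrics import brier_score_loss
from torchmetrics.classification import BinaryBrier

target = torch.tensor([0, 1, 0, 1, 0, 1])
preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])

result = BinaryBrier()(preds, target)
sk_score = brier_score_loss(target.numpy(), preds.numpy())

# The two-column convention doubles sklearn's mean((p - y) ** 2); loose tolerances
# absorb the float32 vs float64 difference between the two implementations.
torch.testing.assert_close(
    result["MeanBrier"].double(), 2 * torch.tensor(sk_score), rtol=1e-5, atol=1e-5
)
```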