From 06b51e22a44768963de39dd98a8d44125a6b33d4 Mon Sep 17 00:00:00 2001 From: Konstantinos Pitas Date: Mon, 22 Sep 2025 16:06:49 +0200 Subject: [PATCH 1/6] BinaryBrier first version --- src/torchmetrics/classification/__init__.py | 6 + src/torchmetrics/classification/brier.py | 555 ++++++++++++++++++ .../functional/classification/brier.py | 175 ++++++ tests/unittests/classification/test_brier.py | 55 ++ 4 files changed, 791 insertions(+) create mode 100644 src/torchmetrics/classification/brier.py create mode 100644 src/torchmetrics/functional/classification/brier.py create mode 100644 tests/unittests/classification/test_brier.py diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 3e41f565879..3f68631fe11 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -19,6 +19,12 @@ MulticlassAveragePrecision, MultilabelAveragePrecision, ) +from torchmetrics.classification.brier import ( + Brier, + BinaryBrier, + MulticlassBrier, + MultilabelBrier, +) from torchmetrics.classification.calibration_error import ( BinaryCalibrationError, CalibrationError, diff --git a/src/torchmetrics/classification/brier.py b/src/torchmetrics/classification/brier.py new file mode 100644 index 00000000000..2985acff8d5 --- /dev/null +++ b/src/torchmetrics/classification/brier.py @@ -0,0 +1,555 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.classification.base import _ClassificationTaskWrapper +from torchmetrics.functional.classification.confusion_matrix import ( + _binary_confusion_matrix_arg_validation, + _binary_confusion_matrix_compute, + _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, + _binary_confusion_matrix_update, + _multiclass_confusion_matrix_arg_validation, + _multiclass_confusion_matrix_compute, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_update, + _multilabel_confusion_matrix_arg_validation, + _multilabel_confusion_matrix_compute, + _multilabel_confusion_matrix_format, + _multilabel_confusion_matrix_tensor_validation, + _multilabel_confusion_matrix_update, +) +from torchmetrics.functional.classification.brier import _mean_brier_score_and_decomposition, _brier_binary_format +from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _CMAP_TYPE, _PLOT_OUT_TYPE, plot_confusion_matrix +from torchmetrics.utilities.data import dim_zero_cat + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = [ + "BinaryBrier.plot", + "MulticlassBrier.plot", + "MultilabelBrier.plot", + ] + + +class BinaryBrier(Metric): + r"""Compute the `confusion matrix`_ for binary tasks. + + The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations + known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix + correspond to the true class labels and column indices correspond to the predicted class labels. + + For binary tasks, the confusion matrix is a 2x2 matrix with the following structure: + + - :math:`C_{0, 0}`: True negatives + - :math:`C_{0, 1}`: False positives + - :math:`C_{1, 0}`: False negatives + - :math:`C_{1, 1}`: True positives + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point + tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per + element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``confusion_matrix`` (:class:`~torch.Tensor`): A tensor containing a ``(2, 2)`` matrix + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics.classification import BinaryBrier + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> bcm = BinaryBrier() + >>> bcm(preds, target) + tensor([[2, 0], + [1, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import BinaryBrier + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> bcm = BinaryBrier() + >>> bcm(preds, target) + tensor([[2, 0], + [1, 1]]) + + """ + + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + confmat: Tensor + + def __init__( + self, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) + self.threshold = threshold + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(2, 2, dtype=torch.long), dist_reduce_fx="sum") + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + if self.validate_args: + _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index) + + preds_conf, target_conf = _binary_confusion_matrix_format(preds, target, self.threshold, self.ignore_index) + confmat = _binary_confusion_matrix_update(preds_conf, target_conf) + + self.confmat += confmat + + preds_br, target_br = _brier_binary_format(preds, target, self.ignore_index) + self.preds.append(preds_br) + self.target.append(target_br) + + def compute(self) -> tuple[Tensor, Tensor, Tensor, Tensor]: + """Compute confusion matrix.""" + return _mean_brier_score_and_decomposition( + dim_zero_cat(self.target), dim_zero_cat(self.preds), self.confmat.float() + ) + + def plot( + self, + val: Optional[Tensor] = None, + ax: Optional[_AX_TYPE] = None, + add_text: bool = True, + labels: Optional[list[str]] = None, + cmap: Optional[_CMAP_TYPE] = None, + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + add_text: if the value of each cell should be added to the plot + labels: a list of strings, if provided will be added to the plot to indicate the different classes + cmap: matplotlib colormap to use for the confusion matrix + https://matplotlib.org/stable/users/explain/colors/colormaps.html + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randint + >>> from torchmetrics.classification import MulticlassBrier + >>> metric = MulticlassBrier(num_classes=5) + >>> metric.update(randint(5, (20,)), randint(5, (20,))) + >>> fig_, ax_ = metric.plot() + + """ + val = val if val is not None else self.compute() + if not isinstance(val, Tensor): + raise TypeError(f"Expected val to be a single tensor but got {val}") + fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap) + return fig, ax + + +class MulticlassBrier(Metric): + r"""Compute the `confusion matrix`_ for multiclass tasks. + + The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations + known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix + correspond to the true class labels and column indices correspond to the predicted class labels. + + For multiclass tasks, the confusion matrix is a NxN matrix, where: + + - :math:`C_{i, i}` represents the number of true positives for class :math:`i` + - :math:`\sum_{j=1, j\neq i}^N C_{i, j}` represents the number of false negatives for class :math:`i` + - :math:`\sum_{j=1, j\neq i}^N C_{j, i}` represents the number of false positives for class :math:`i` + - the sum of the remaining cells in the matrix represents the number of true negatives for class :math:`i` + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point + tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per + element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``confusion_matrix``: [num_classes, num_classes] matrix + + Args: + num_classes: Integer specifying the number of classes + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (pred is integer tensor): + >>> from torch import tensor + >>> from torchmetrics.classification import MulticlassBrier + >>> target = tensor([2, 1, 0, 0]) + >>> preds = tensor([2, 1, 0, 1]) + >>> metric = MulticlassBrier(num_classes=3) + >>> metric(preds, target) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + + Example (pred is float tensor): + >>> from torchmetrics.classification import MulticlassBrier + >>> target = tensor([2, 1, 0, 0]) + >>> preds = tensor([[0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13]]) + >>> metric = MulticlassBrier(num_classes=3) + >>> metric(preds, target) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + + """ + + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + confmat: Tensor + + def __init__( + self, + num_classes: int, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["none", "true", "pred", "all"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) + self.num_classes = num_classes + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + if self.validate_args: + _multiclass_confusion_matrix_tensor_validation(preds, target, self.num_classes, self.ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index) + confmat = _multiclass_confusion_matrix_update(preds, target, self.num_classes) + self.confmat += confmat + + def compute(self) -> Tensor: + """Compute confusion matrix.""" + return _multiclass_confusion_matrix_compute(self.confmat, self.normalize) + + def plot( + self, + val: Optional[Tensor] = None, + ax: Optional[_AX_TYPE] = None, + add_text: bool = True, + labels: Optional[list[str]] = None, + cmap: Optional[_CMAP_TYPE] = None, + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + add_text: if the value of each cell should be added to the plot + labels: a list of strings, if provided will be added to the plot to indicate the different classes + cmap: matplotlib colormap to use for the confusion matrix + https://matplotlib.org/stable/users/explain/colors/colormaps.html + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randint + >>> from torchmetrics.classification import MulticlassBrier + >>> metric = MulticlassBrier(num_classes=5) + >>> metric.update(randint(5, (20,)), randint(5, (20,))) + >>> fig_, ax_ = metric.plot() + + """ + val = val if val is not None else self.compute() + if not isinstance(val, Tensor): + raise TypeError(f"Expected val to be a single tensor but got {val}") + fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap) + return fig, ax + + +class MultilabelBrier(Metric): + r"""Compute the `confusion matrix`_ for multilabel tasks. + + The confusion matrix :math:`C` is constructed such that :math:`C_{i, j}` is equal to the number of observations + known to be in class :math:`i` but predicted to be in class :math:`j`. Thus row indices of the confusion matrix + correspond to the true class labels and column indices correspond to the predicted class labels. + + For multilabel tasks, the confusion matrix is a Nx2x2 tensor, where each 2x2 matrix corresponds to the confusion + for that label. The structure of each 2x2 matrix is as follows: + + - :math:`C_{0, 0}`: True negatives + - :math:`C_{0, 1}`: False positives + - :math:`C_{1, 0}`: False negatives + - :math:`C_{1, 1}`: True positives + + As input to 'update' the metric accepts the following input: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + As output of 'compute' the metric returns the following output: + + - ``confusion matrix``: [num_labels,2,2] matrix + + Args: + num_classes: Integer specifying the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torch import tensor + >>> from torchmetrics.classification import MultilabelBrier + >>> target = tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelBrier(num_labels=3) + >>> metric(preds, target) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelBrier + >>> target = tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelBrier(num_labels=3) + >>> metric(preds, target) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + + """ + + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + confmat: Tensor + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["none", "true", "pred", "all"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize) + self.num_labels = num_labels + self.threshold = threshold + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(num_labels, 2, 2, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + if self.validate_args: + _multilabel_confusion_matrix_tensor_validation(preds, target, self.num_labels, self.ignore_index) + preds, target = _multilabel_confusion_matrix_format( + preds, target, self.num_labels, self.threshold, self.ignore_index + ) + confmat = _multilabel_confusion_matrix_update(preds, target, self.num_labels) + self.confmat += confmat + + def compute(self) -> Tensor: + """Compute confusion matrix.""" + return _multilabel_confusion_matrix_compute(self.confmat, self.normalize) + + def plot( + self, + val: Optional[Tensor] = None, + ax: Optional[_AX_TYPE] = None, + add_text: bool = True, + labels: Optional[list[str]] = None, + cmap: Optional[_CMAP_TYPE] = None, + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + add_text: if the value of each cell should be added to the plot + labels: a list of strings, if provided will be added to the plot to indicate the different classes + cmap: matplotlib colormap to use for the confusion matrix + https://matplotlib.org/stable/users/explain/colors/colormaps.html + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randint + >>> from torchmetrics.classification import MulticlassBrier + >>> metric = MulticlassBrier(num_classes=5) + >>> metric.update(randint(5, (20,)), randint(5, (20,))) + >>> fig_, ax_ = metric.plot() + + """ + val = val if val is not None else self.compute() + if not isinstance(val, Tensor): + raise TypeError(f"Expected val to be a single tensor but got {val}") + fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap) + return fig, ax + + +class Brier(_ClassificationTaskWrapper): + r"""Compute the `confusion matrix`_. + + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of + :class:`~torchmetrics.classification.BinaryBrier`, + :class:`~torchmetrics.classification.MulticlassBrier` and + :class:`~torchmetrics.classification.MultilabelBrier` for the specific details of each argument influence + and examples. + + Legacy Example: + >>> from torch import tensor + >>> target = tensor([1, 1, 0, 0]) + >>> preds = tensor([0, 1, 0, 0]) + >>> confmat = Brier(task="binary", num_classes=2) + >>> confmat(preds, target) + tensor([[2, 0], + [1, 1]]) + + >>> target = tensor([2, 1, 0, 0]) + >>> preds = tensor([2, 1, 0, 1]) + >>> confmat = Brier(task="multiclass", num_classes=3) + >>> confmat(preds, target) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + + >>> target = tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = tensor([[0, 0, 1], [1, 0, 1]]) + >>> confmat = Brier(task="multilabel", num_labels=3) + >>> confmat(preds, target) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + + """ + + def __new__( # type: ignore[misc] + cls: type["Brier"], + task: Literal["binary", "multiclass", "multilabel"], + threshold: float = 0.5, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + """Initialize task metric.""" + task = ClassificationTask.from_str(task) + kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args}) + if task == ClassificationTask.BINARY: + return BinaryBrier(threshold, **kwargs) + if task == ClassificationTask.MULTICLASS: + if not isinstance(num_classes, int): + raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") + return MulticlassBrier(num_classes, **kwargs) + if task == ClassificationTask.MULTILABEL: + if not isinstance(num_labels, int): + raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") + return MultilabelBrier(num_labels, threshold, **kwargs) + raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/functional/classification/brier.py b/src/torchmetrics/functional/classification/brier.py new file mode 100644 index 00000000000..3c5ae3710c5 --- /dev/null +++ b/src/torchmetrics/functional/classification/brier.py @@ -0,0 +1,175 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence +from typing import List, Optional, Union + +import torch +from torch import Tensor, tensor +from torch.nn import functional as F # noqa: N812 +from typing_extensions import Literal + +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.compute import _safe_divide, interp, normalize_logits_if_needed +from torchmetrics.utilities.data import _bincount, _cumsum +from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.prints import rank_zero_warn +import torch +import torch.nn.functional as F + + +def _brier_decomposition( + probabilities: torch.Tensor = None, + confusion_matrix: torch.Tensor = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + r"""Decompose the Brier score into uncertainty, resolution, and reliability. + + [Proper scoring rules][1] measure the quality of probabilistic predictions; + any proper scoring rule admits a [unique decomposition][2] as + `Score = Uncertainty - Resolution + Reliability`, where: + + * `Uncertainty`, is a generalized entropy of the average predictive + distribution; it can both be positive or negative. + * `Resolution`, is a generalized variance of individual predictive + distributions; it is always non-negative. Difference in predictions reveal + information, that is why a larger resolution improves the predictive score. + * `Reliability`, a measure of calibration of predictions against the true + frequency of events. It is always non-negative and a lower value here + indicates better calibration. + + Args: + labels: Tensor, (n,), with torch.int64 elements containing ground + truth class labels in the range [0, nlabels]. + logits: Tensor, (n, nlabels), with logits for n instances and nlabels. + probabilities: Tensor, (n, nlabels), with predictive probability + distribution (alternative to logits argument). + confusion_matrix: Tensor, (nlabels, nlabels), the confusion matrix. + + Returns: + uncertainty: Tensor, scalar, the uncertainty component of the + decomposition. + resolution: Tensor, scalar, the resolution component of the decomposition. + reliability: Tensor, scalar, the reliability component of the + decomposition. + """ + n, nlabels = probabilities.shape # Implicit rank check. + + # Compute pbar, the average distribution + pred_class = torch.argmax(probabilities, dim=1) + dist_weights = confusion_matrix.sum(dim=1) + dist_weights /= dist_weights.sum() + pbar = confusion_matrix.sum(dim=0) + pbar /= pbar.sum() + + # dist_mean[k,:] contains the empirical distribution for the set M_k + # Some outcomes may not realize, corresponding to dist_weights[k] = 0 + dist_mean = confusion_matrix / (confusion_matrix.sum(dim=1, keepdim=True) + 1.0e-7) + + # Uncertainty: quadratic entropy of the average label distribution + uncertainty = -torch.sum(pbar**2) + + # Resolution: expected quadratic divergence of predictive to mean + resolution = (pbar.unsqueeze(1) - dist_mean) ** 2 + resolution = torch.sum(dist_weights * resolution.sum(dim=1)) + + # Reliability: expected quadratic divergence of predictive to true + prob_true = dist_mean[pred_class] + reliability = torch.sum((prob_true - probabilities) ** 2, dim=1) + reliability = torch.mean(reliability) + + return uncertainty, resolution, reliability + + +def _mean_brier_score(labels: torch.Tensor, probabilities: torch.Tensor = None) -> torch.Tensor: + """Compute elementwise Brier score. + + The Brier score is a proper scoring rule that measures the accuracy of probabilistic predictions. + It is calculated as the squared difference between the predicted probability distribution and + the actual outcome. + + Args: + labels (torch.Tensor): Tensor of integer labels with shape [N1, N2, ...]. + probs (torch.Tensor, optional): Tensor of categorical probabilities with shape [N1, N2, ..., M]. + logits (torch.Tensor, optional): If `probs` is None, class probabilities are computed as a + softmax over these logits. This argument is ignored if `probs` is provided. + + Returns: + torch.Tensor: Tensor of shape [N1, N2, ...] consisting of the Brier score contribution + from each element. The full-dataset Brier score is the average of these values. + """ + + nlabels = probabilities.shape[-1] + flat_probabilities = probabilities.view(-1, nlabels) + flat_labels = labels.view(-1) + + # Gather the probabilities corresponding to the true labels + plabel = flat_probabilities[torch.arange(len(flat_labels)), flat_labels] + out = torch.sum(flat_probabilities**2, dim=-1) - 2 * plabel + + return out.view(labels.shape).mean() + + +def _mean_brier_score_and_decomposition( + labels: torch.Tensor, + probabilities: torch.Tensor = None, + confusion_matrix: torch.Tensor = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + mean_brier = _mean_brier_score(labels, probabilities) + uncertainty, resolution, reliability = _brier_decomposition(probabilities, confusion_matrix) + + return mean_brier, uncertainty, resolution, reliability + + +def _adjust_threshold_arg( + thresholds: Optional[Union[int, list[float], Tensor]] = None, device: Optional[torch.device] = None +) -> Optional[Tensor]: + """Convert threshold arg for list and int to tensor format.""" + if isinstance(thresholds, int): + return torch.linspace(0, 1, thresholds, device=device) + if isinstance(thresholds, list): + return torch.tensor(thresholds, device=device) + return thresholds + + +def _brier_binary_format( + preds: Tensor, + target: Tensor, + ignore_index: Optional[int] = None, + normalization: Optional[Literal["sigmoid", "softmax"]] = "sigmoid", +) -> tuple[Tensor, Tensor]: + """Convert all input to the right format. + + - flattens additional dimensions + - Remove all datapoints that should be ignored + - Applies sigmoid if pred tensor not in [0,1] range + - Format thresholds arg to be a tensor + + """ + preds = preds.flatten() + target = target.flatten() + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + preds = normalize_logits_if_needed(preds, normalization) + + probs_zero_class = torch.ones(preds.shape) - preds + preds = torch.cat([probs_zero_class.unsqueeze(dim=-1), preds.unsqueeze(dim=-1)], dim=-1) + + return preds, target + + +def _brier_binary_validation(): + return diff --git a/tests/unittests/classification/test_brier.py b/tests/unittests/classification/test_brier.py new file mode 100644 index 00000000000..f3239badf17 --- /dev/null +++ b/tests/unittests/classification/test_brier.py @@ -0,0 +1,55 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import numpy as np +import pytest +import torch +from scipy.special import expit as sigmoid + +# from sklearn.metrics import brier as sk_brier +from torch import tensor + +from torchmetrics.classification.brier import ( + BinaryBrier, + Brier, + MulticlassBrier, + MultilabelBrier, +) + +# from torchmetrics.functional.classification.brier import ( +# binary_brier, +# multiclass_brier, +# multilabel_brier, +# ) +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 +from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases + +seed_all(42) + + +class TestBinaryBrier(MetricTester): + """Test class for `BinaryBrier` metric.""" + + def test_binary_brier(self): + """Test class implementation of metric.""" + target = tensor([0, 1, 0, 1, 0, 1]) + preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + metric = BinaryBrier() + print("BinaryBrier: \n") + print(metric(preds, target)) From 1b18c98732bfbd2b14f086dda154df912da746fc Mon Sep 17 00:00:00 2001 From: Konstantinos Pitas Date: Mon, 22 Sep 2025 16:08:05 +0200 Subject: [PATCH 2/6] Fix bugs in metric and decomposition calculation The mean Brier score and it's decomposition should satisfy the following equation Brier = Uncertainty - Resolution + Reliability. After inspecting the results on a toy test case, the Uncertainty was estimated wrongly with a negative sign, also the Brier score formula was slightly wrong (f-1)^2=f^2-2f+1 not (f-1)^2=f^2-2f causing it to also be negative. With these fixes the decomposition equation is satisfied. --- src/torchmetrics/functional/classification/brier.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/classification/brier.py b/src/torchmetrics/functional/classification/brier.py index 3c5ae3710c5..4a6e968c22e 100644 --- a/src/torchmetrics/functional/classification/brier.py +++ b/src/torchmetrics/functional/classification/brier.py @@ -77,7 +77,7 @@ def _brier_decomposition( dist_mean = confusion_matrix / (confusion_matrix.sum(dim=1, keepdim=True) + 1.0e-7) # Uncertainty: quadratic entropy of the average label distribution - uncertainty = -torch.sum(pbar**2) + uncertainty = torch.sum(pbar**2) # Resolution: expected quadratic divergence of predictive to mean resolution = (pbar.unsqueeze(1) - dist_mean) ** 2 @@ -115,7 +115,7 @@ def _mean_brier_score(labels: torch.Tensor, probabilities: torch.Tensor = None) # Gather the probabilities corresponding to the true labels plabel = flat_probabilities[torch.arange(len(flat_labels)), flat_labels] - out = torch.sum(flat_probabilities**2, dim=-1) - 2 * plabel + out = torch.sum(flat_probabilities**2, dim=-1) - 2 * plabel + 1 return out.view(labels.shape).mean() From 9a2f5e4092255aadc5b8d1afcd36d7d900eb889a Mon Sep 17 00:00:00 2001 From: Konstantinos Pitas Date: Mon, 22 Sep 2025 17:09:07 +0200 Subject: [PATCH 3/6] Multiclass Brier first implementation --- src/torchmetrics/classification/brier.py | 24 +++++++++++---- .../functional/classification/brier.py | 29 ++++++++++++++++++- tests/unittests/classification/test_brier.py | 24 ++++++++++++++- 3 files changed, 69 insertions(+), 8 deletions(-) diff --git a/src/torchmetrics/classification/brier.py b/src/torchmetrics/classification/brier.py index 2985acff8d5..e6a15f3d6ea 100644 --- a/src/torchmetrics/classification/brier.py +++ b/src/torchmetrics/classification/brier.py @@ -35,7 +35,11 @@ _multilabel_confusion_matrix_tensor_validation, _multilabel_confusion_matrix_update, ) -from torchmetrics.functional.classification.brier import _mean_brier_score_and_decomposition, _brier_binary_format +from torchmetrics.functional.classification.brier import ( + _mean_brier_score_and_decomposition, + _binary_brier_format, + _multiclass_brier_format, +) from torchmetrics.metric import Metric from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE @@ -147,7 +151,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.confmat += confmat - preds_br, target_br = _brier_binary_format(preds, target, self.ignore_index) + preds_br, target_br = _binary_brier_format(preds, target, self.ignore_index) self.preds.append(preds_br) self.target.append(target_br) @@ -288,18 +292,26 @@ def __init__( self.validate_args = validate_args self.add_state("confmat", torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum") + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" if self.validate_args: _multiclass_confusion_matrix_tensor_validation(preds, target, self.num_classes, self.ignore_index) - preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index) - confmat = _multiclass_confusion_matrix_update(preds, target, self.num_classes) + preds_conf, target_conf = _multiclass_confusion_matrix_format(preds, target, self.ignore_index) + confmat = _multiclass_confusion_matrix_update(preds_conf, target_conf, self.num_classes) self.confmat += confmat - def compute(self) -> Tensor: + preds_br, target_br = _multiclass_brier_format(preds, target, self.num_classes, self.ignore_index) + self.preds.append(preds_br) + self.target.append(target_br) + + def compute(self) -> tuple[Tensor, Tensor, Tensor, Tensor]: """Compute confusion matrix.""" - return _multiclass_confusion_matrix_compute(self.confmat, self.normalize) + return _mean_brier_score_and_decomposition( + dim_zero_cat(self.target), dim_zero_cat(self.preds), self.confmat.float() + ) def plot( self, diff --git a/src/torchmetrics/functional/classification/brier.py b/src/torchmetrics/functional/classification/brier.py index 4a6e968c22e..7d6c8533ff0 100644 --- a/src/torchmetrics/functional/classification/brier.py +++ b/src/torchmetrics/functional/classification/brier.py @@ -142,7 +142,7 @@ def _adjust_threshold_arg( return thresholds -def _brier_binary_format( +def _binary_brier_format( preds: Tensor, target: Tensor, ignore_index: Optional[int] = None, @@ -171,5 +171,32 @@ def _brier_binary_format( return preds, target +def _multiclass_brier_format( + preds: Tensor, + target: Tensor, + num_classes: int, + ignore_index: Optional[int] = None, +) -> tuple[Tensor, Tensor]: + """Convert all input to the right format. + + - flattens additional dimensions + - Remove all datapoints that should be ignored + - Applies softmax if pred tensor not in [0,1] range + - Format thresholds arg to be a tensor + + """ + preds = preds.transpose(0, 1).reshape(num_classes, -1).T + target = target.flatten() + + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + preds = normalize_logits_if_needed(preds, "softmax") + + return preds, target + + def _brier_binary_validation(): return diff --git a/tests/unittests/classification/test_brier.py b/tests/unittests/classification/test_brier.py index f3239badf17..c48a4e05961 100644 --- a/tests/unittests/classification/test_brier.py +++ b/tests/unittests/classification/test_brier.py @@ -51,5 +51,27 @@ def test_binary_brier(self): target = tensor([0, 1, 0, 1, 0, 1]) preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) metric = BinaryBrier() - print("BinaryBrier: \n") + print("\n") + print("BinaryBrier:") print(metric(preds, target)) + print("\n") + + +class TestMulticlassBrier(MetricTester): + """Test class for `BinaryBrier` metric.""" + + def test_multiclass_brier(self): + """Test class implementation of metric.""" + target = tensor([1, 5, 3, 0, 0, 3]) + preds = tensor([ + [0.9, 0.05, 0.02, 0.01, 0.01, 0.01], + [0.01, 0.01, 0.01, 0.01, 0.01, 0.95], + [0.85, 0.05, 0.05, 0.02, 0.02, 0.01], + [0.01, 0.01, 0.9, 0.05, 0.02, 0.01], + [0.8, 0.1, 0.05, 0.03, 0.01, 0.01], + [0.01, 0.01, 0.01, 0.9, 0.05, 0.02], + ]) + metric = MulticlassBrier(num_classes=6) + print("MulticlassBrier:") + print(metric(preds, target)) + print("\n") From 0aee78912025a6eebd1ab1dc51edf902c01b6db3 Mon Sep 17 00:00:00 2001 From: Konstantinos Pitas Date: Mon, 22 Sep 2025 18:25:47 +0200 Subject: [PATCH 4/6] Some more fixes for the decomposition There were some more mistakes in estimating the Uncertainty. Also, the Confusion matrix had to be transposed such that the true labels are on the x axis. --- .../functional/classification/brier.py | 4 ++- tests/unittests/classification/test_brier.py | 30 ++++++++++++++----- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/functional/classification/brier.py b/src/torchmetrics/functional/classification/brier.py index 7d6c8533ff0..a0843eb6fcf 100644 --- a/src/torchmetrics/functional/classification/brier.py +++ b/src/torchmetrics/functional/classification/brier.py @@ -65,6 +65,8 @@ def _brier_decomposition( """ n, nlabels = probabilities.shape # Implicit rank check. + confusion_matrix = confusion_matrix.T + # Compute pbar, the average distribution pred_class = torch.argmax(probabilities, dim=1) dist_weights = confusion_matrix.sum(dim=1) @@ -77,7 +79,7 @@ def _brier_decomposition( dist_mean = confusion_matrix / (confusion_matrix.sum(dim=1, keepdim=True) + 1.0e-7) # Uncertainty: quadratic entropy of the average label distribution - uncertainty = torch.sum(pbar**2) + uncertainty = torch.sum(pbar - pbar**2) # Resolution: expected quadratic divergence of predictive to mean resolution = (pbar.unsqueeze(1) - dist_mean) ** 2 diff --git a/tests/unittests/classification/test_brier.py b/tests/unittests/classification/test_brier.py index c48a4e05961..b7aeb443112 100644 --- a/tests/unittests/classification/test_brier.py +++ b/tests/unittests/classification/test_brier.py @@ -62,16 +62,30 @@ class TestMulticlassBrier(MetricTester): def test_multiclass_brier(self): """Test class implementation of metric.""" - target = tensor([1, 5, 3, 0, 0, 3]) + target = tensor([1, 4, 3, 0, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0]) preds = tensor([ - [0.9, 0.05, 0.02, 0.01, 0.01, 0.01], - [0.01, 0.01, 0.01, 0.01, 0.01, 0.95], - [0.85, 0.05, 0.05, 0.02, 0.02, 0.01], - [0.01, 0.01, 0.9, 0.05, 0.02, 0.01], - [0.8, 0.1, 0.05, 0.03, 0.01, 0.01], - [0.01, 0.01, 0.01, 0.9, 0.05, 0.02], + [0.9, 0.05, 0.02, 0.01, 0.02], + [0.01, 0.01, 0.01, 0.01, 0.96], + [0.85, 0.05, 0.05, 0.02, 0.03], + [0.01, 0.01, 0.9, 0.05, 0.03], + [0.8, 0.1, 0.05, 0.03, 0.02], + [0.01, 0.01, 0.01, 0.9, 0.07], + [0.05, 0.1, 0.8, 0.03, 0.02], + [0.02, 0.02, 0.02, 0.9, 0.04], + [0.01, 0.01, 0.01, 0.02, 0.95], + [0.7, 0.2, 0.05, 0.03, 0.02], + [0.02, 0.9, 0.05, 0.02, 0.01], + [0.1, 0.2, 0.6, 0.05, 0.05], + [0.03, 0.03, 0.03, 0.9, 0.01], + [0.02, 0.02, 0.02, 0.02, 0.92], + [0.75, 0.15, 0.05, 0.03, 0.02], + [0.02, 0.85, 0.1, 0.02, 0.01], + [0.1, 0.15, 0.7, 0.03, 0.02], + [0.02, 0.02, 0.02, 0.92, 0.02], + [0.01, 0.01, 0.01, 0.03, 0.94], + [0.8, 0.1, 0.05, 0.03, 0.02], ]) - metric = MulticlassBrier(num_classes=6) + metric = MulticlassBrier(num_classes=5) print("MulticlassBrier:") print(metric(preds, target)) print("\n") From 563883cee7a986edf8ff2bc8acc95b71117fff67 Mon Sep 17 00:00:00 2001 From: Konstantinos Pitas Date: Mon, 22 Sep 2025 18:46:00 +0200 Subject: [PATCH 5/6] Change return type to dict --- src/torchmetrics/classification/brier.py | 4 ++-- src/torchmetrics/functional/classification/brier.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/classification/brier.py b/src/torchmetrics/classification/brier.py index e6a15f3d6ea..9717cd44c29 100644 --- a/src/torchmetrics/classification/brier.py +++ b/src/torchmetrics/classification/brier.py @@ -155,7 +155,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.preds.append(preds_br) self.target.append(target_br) - def compute(self) -> tuple[Tensor, Tensor, Tensor, Tensor]: + def compute(self) -> dict[str, torch.Tensor]: """Compute confusion matrix.""" return _mean_brier_score_and_decomposition( dim_zero_cat(self.target), dim_zero_cat(self.preds), self.confmat.float() @@ -307,7 +307,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.preds.append(preds_br) self.target.append(target_br) - def compute(self) -> tuple[Tensor, Tensor, Tensor, Tensor]: + def compute(self) -> dict[str, torch.Tensor]: """Compute confusion matrix.""" return _mean_brier_score_and_decomposition( dim_zero_cat(self.target), dim_zero_cat(self.preds), self.confmat.float() diff --git a/src/torchmetrics/functional/classification/brier.py b/src/torchmetrics/functional/classification/brier.py index a0843eb6fcf..e2ec42e1e02 100644 --- a/src/torchmetrics/functional/classification/brier.py +++ b/src/torchmetrics/functional/classification/brier.py @@ -126,11 +126,16 @@ def _mean_brier_score_and_decomposition( labels: torch.Tensor, probabilities: torch.Tensor = None, confusion_matrix: torch.Tensor = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> dict[str, torch.Tensor]: mean_brier = _mean_brier_score(labels, probabilities) uncertainty, resolution, reliability = _brier_decomposition(probabilities, confusion_matrix) - return mean_brier, uncertainty, resolution, reliability + return { + "MeanBrier": mean_brier, + "Uncertainty": uncertainty, + "Resolution": resolution, + "Reliability": reliability, + } def _adjust_threshold_arg( From 5343a79534e158b7f2f401b677632500deddf5d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Sep 2025 16:57:20 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/classification/__init__.py | 2 +- src/torchmetrics/classification/brier.py | 14 ++++++-------- .../functional/classification/brier.py | 17 +++++------------ tests/unittests/classification/test_brier.py | 13 +------------ 4 files changed, 13 insertions(+), 33 deletions(-) diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 3f68631fe11..f9e0ab658a4 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -20,8 +20,8 @@ MultilabelAveragePrecision, ) from torchmetrics.classification.brier import ( - Brier, BinaryBrier, + Brier, MulticlassBrier, MultilabelBrier, ) diff --git a/src/torchmetrics/classification/brier.py b/src/torchmetrics/classification/brier.py index 9717cd44c29..ff31c46f5ba 100644 --- a/src/torchmetrics/classification/brier.py +++ b/src/torchmetrics/classification/brier.py @@ -18,14 +18,17 @@ from typing_extensions import Literal from torchmetrics.classification.base import _ClassificationTaskWrapper +from torchmetrics.functional.classification.brier import ( + _binary_brier_format, + _mean_brier_score_and_decomposition, + _multiclass_brier_format, +) from torchmetrics.functional.classification.confusion_matrix import ( _binary_confusion_matrix_arg_validation, - _binary_confusion_matrix_compute, _binary_confusion_matrix_format, _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, _multiclass_confusion_matrix_arg_validation, - _multiclass_confusion_matrix_compute, _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, _multiclass_confusion_matrix_update, @@ -35,16 +38,11 @@ _multilabel_confusion_matrix_tensor_validation, _multilabel_confusion_matrix_update, ) -from torchmetrics.functional.classification.brier import ( - _mean_brier_score_and_decomposition, - _binary_brier_format, - _multiclass_brier_format, -) from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _CMAP_TYPE, _PLOT_OUT_TYPE, plot_confusion_matrix -from torchmetrics.utilities.data import dim_zero_cat if not _MATPLOTLIB_AVAILABLE: __doctest_skip__ = [ diff --git a/src/torchmetrics/functional/classification/brier.py b/src/torchmetrics/functional/classification/brier.py index e2ec42e1e02..3b965c83071 100644 --- a/src/torchmetrics/functional/classification/brier.py +++ b/src/torchmetrics/functional/classification/brier.py @@ -12,21 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Sequence -from typing import List, Optional, Union +from typing import Optional, Union import torch -from torch import Tensor, tensor -from torch.nn import functional as F # noqa: N812 +from torch import Tensor from typing_extensions import Literal -from torchmetrics.utilities.checks import _check_same_shape -from torchmetrics.utilities.compute import _safe_divide, interp, normalize_logits_if_needed -from torchmetrics.utilities.data import _bincount, _cumsum -from torchmetrics.utilities.enums import ClassificationTask -from torchmetrics.utilities.prints import rank_zero_warn -import torch -import torch.nn.functional as F +from torchmetrics.utilities.compute import normalize_logits_if_needed def _brier_decomposition( @@ -62,6 +54,7 @@ def _brier_decomposition( resolution: Tensor, scalar, the resolution component of the decomposition. reliability: Tensor, scalar, the reliability component of the decomposition. + """ n, nlabels = probabilities.shape # Implicit rank check. @@ -109,8 +102,8 @@ def _mean_brier_score(labels: torch.Tensor, probabilities: torch.Tensor = None) Returns: torch.Tensor: Tensor of shape [N1, N2, ...] consisting of the Brier score contribution from each element. The full-dataset Brier score is the average of these values. - """ + """ nlabels = probabilities.shape[-1] flat_probabilities = probabilities.view(-1, nlabels) flat_labels = labels.view(-1) diff --git a/tests/unittests/classification/test_brier.py b/tests/unittests/classification/test_brier.py index b7aeb443112..3e2ab351cba 100644 --- a/tests/unittests/classification/test_brier.py +++ b/tests/unittests/classification/test_brier.py @@ -11,21 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial -import numpy as np -import pytest -import torch -from scipy.special import expit as sigmoid # from sklearn.metrics import brier as sk_brier from torch import tensor from torchmetrics.classification.brier import ( BinaryBrier, - Brier, MulticlassBrier, - MultilabelBrier, ) # from torchmetrics.functional.classification.brier import ( @@ -33,12 +26,8 @@ # multiclass_brier, # multilabel_brier, # ) -from torchmetrics.metric import Metric -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 -from unittests import NUM_CLASSES, THRESHOLD from unittests._helpers import seed_all -from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index -from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests._helpers.testers import MetricTester seed_all(42)