2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `MaskedBinaryAUROC` implementation to the classification domain ([#3096](https://github.com/Lightning-AI/torchmetrics/issues/3096))


### Changed
3 changes: 2 additions & 1 deletion src/torchmetrics/classification/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.classification.accuracy import Accuracy, BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy
from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC
from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MaskedBinaryAUROC, MulticlassAUROC, MultilabelAUROC
from torchmetrics.classification.average_precision import (
AveragePrecision,
BinaryAveragePrecision,
@@ -172,6 +172,7 @@
"HingeLoss",
"JaccardIndex",
"LogAUC",
"MaskedBinaryAUROC",
"MatthewsCorrCoef",
"MulticlassAUROC",
"MulticlassAccuracy",
134 changes: 128 additions & 6 deletions src/torchmetrics/classification/auroc.py
@@ -14,6 +14,7 @@
from collections.abc import Sequence
from typing import Any, Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

@@ -38,7 +39,7 @@
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["BinaryAUROC.plot", "MulticlassAUROC.plot", "MultilabelAUROC.plot"]
__doctest_skip__ = ["BinaryAUROC.plot", "MaskedBinaryAUROC.plot", "MulticlassAUROC.plot", "MultilabelAUROC.plot"]


class BinaryAUROC(BinaryPrecisionRecallCurve):
@@ -167,6 +168,124 @@ def plot( # type: ignore[override]
return self._plot(val, ax)


class MaskedBinaryAUROC(BinaryAUROC):
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks with masking.

The Masked AUROC score summarizes the ROC curve into a single number that describes the performance of a model for
multiple thresholds at the same time, computed only over the elements selected by a boolean mask.
Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.

As input to ``forward`` and ``update`` the metric accepts the following input:

- ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)`` containing probabilities or logits for
each observation. If preds has values outside the [0,1] range, we consider the input to be logits and will
auto-apply sigmoid per element.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and
therefore should only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the
positive class.
- ``mask`` (:class:`~torch.Tensor`): A boolean tensor of shape ``(N, ...)`` indicating which elements to include
in the metric computation. Elements with a value of `True` will be included, while elements with a value of
`False` will be ignored.

As output to ``forward`` and ``compute`` the metric returns the following output:

- ``b_auroc`` (:class:`~torch.Tensor`): A single scalar with the AUROC score computed over the elements selected by the mask.

Additional dimension ``...`` will be flattened into the batch dimension.

The implementation supports calculating the metric both in a non-binned but accurate version and in a
binned version that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will
activate the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the
`thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
size :math:`\mathcal{O}(n_{thresholds})` (constant memory).

Args:
max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``.
thresholds:
Can be one of:

- If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
all the data. Most accurate but also most memory consuming approach.
- If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
0 to 1 as bins for the calculation.
- If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation.
- If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.

validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> from torch import tensor
>>> from torchmetrics.classification import MaskedBinaryAUROC
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> mask = tensor([1, 1, 0, 1], dtype=torch.bool)
>>> metric = MaskedBinaryAUROC(thresholds=None)
>>> metric(preds, target, mask)
tensor(0.5000)
>>> b_auroc = MaskedBinaryAUROC(thresholds=5)
>>> b_auroc(preds, target, mask)
tensor(0.5000)

"""

def update(self, preds: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> None:
"""Update metric state with predictions, targets and an optional boolean mask selecting elements to score."""
if mask is not None:
if mask.dtype != torch.bool:
raise ValueError(f"Mask must be boolean, got {mask.dtype}")
if mask.shape != preds.shape:
raise ValueError(f"Mask shape {mask.shape} must match preds/target shape {preds.shape}")
preds = preds[mask]
target = target[mask]
super().update(preds, target) # call the original BinaryAUROC update

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: A matplotlib axis object. If provided, will add plot to that axis.

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single
>>> import torch
>>> from torchmetrics.classification import MaskedBinaryAUROC
>>> metric = MaskedBinaryAUROC()
>>> metric.update(torch.rand(20,), torch.randint(2, (20,)), mask=torch.rand(20,) > 0.5)
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import MaskedBinaryAUROC
>>> metric = MaskedBinaryAUROC()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.rand(20,), torch.randint(2, (20,)), mask=torch.rand(20,) > 0.5))
>>> fig_, ax_ = metric.plot(values)

"""
return self._plot(val, ax)


class MulticlassAUROC(MulticlassPrecisionRecallCurve):
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks.

@@ -482,10 +601,11 @@ class AUROC(_ClassificationTaskWrapper):
corresponds to random guessing.

This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the
``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
:class:`~torchmetrics.classification.BinaryAUROC`, :class:`~torchmetrics.classification.MulticlassAUROC` and
:class:`~torchmetrics.classification.MultilabelAUROC` for the specific details of each argument influence and
examples.
``task`` argument to either ``'binary'``, ``'maskedbinary'``, ``'multiclass'`` or ``'multilabel'``.
See the documentation of
:class:`~torchmetrics.classification.BinaryAUROC`, :class:`~torchmetrics.classification.MaskedBinaryAUROC`,
:class:`~torchmetrics.classification.MulticlassAUROC` and :class:`~torchmetrics.classification.MultilabelAUROC`
for the specific details of how each argument influences the metric, and for examples.

Legacy Example:
>>> from torch import tensor
@@ -509,7 +629,7 @@

def __new__( # type: ignore[misc]
cls: type["AUROC"],
task: Literal["binary", "multiclass", "multilabel"],
task: Literal["binary", "maskedbinary", "multiclass", "multilabel"],
thresholds: Optional[Union[int, list[float], Tensor]] = None,
num_classes: Optional[int] = None,
num_labels: Optional[int] = None,
@@ -524,6 +644,8 @@ def __new__( # type: ignore[misc]
kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
if task == ClassificationTask.BINARY:
return BinaryAUROC(max_fpr, **kwargs)
if task == ClassificationTask.MASKEDBINARY:
return MaskedBinaryAUROC(max_fpr, **kwargs)
if task == ClassificationTask.MULTICLASS:
if not isinstance(num_classes, int):
raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
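Before the remaining plumbing changes, a minimal usage sketch of the classes added above may help orientation. The random data, shapes, and variable names here are illustrative assumptions and not part of the diff; the final equivalence check simply mirrors what `MaskedBinaryAUROC.update` does internally (filter by the mask, then defer to `BinaryAUROC`).

import torch
from torchmetrics import AUROC
from torchmetrics.classification import BinaryAUROC, MaskedBinaryAUROC

# Illustrative batched inputs: probabilities, binary targets, and a boolean mask
# (True = include the element in the metric).
preds = torch.rand(4, 10)
target = torch.randint(2, (4, 10))
mask = torch.randint(2, (4, 10), dtype=torch.bool)

# Direct use: accumulate batch by batch, then compute once.
metric = MaskedBinaryAUROC(thresholds=None)
for p, t, m in zip(preds, target, mask):
    metric.update(p, t, mask=m)
masked_score = metric.compute()

# The wrapper dispatches to the new class when task="maskedbinary".
assert isinstance(AUROC(task="maskedbinary"), MaskedBinaryAUROC)

# Masking inside the metric is equivalent to dropping masked-out elements up front.
reference = BinaryAUROC(thresholds=None)(preds[mask], target[mask])
assert torch.isclose(masked_score, reference)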
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/enums.py
@@ -117,6 +117,7 @@ def _name() -> str:
return "Classification"

BINARY = "binary"
MASKEDBINARY = "maskedbinary"
MULTICLASS = "multiclass"
MULTILABEL = "multilabel"

28 changes: 27 additions & 1 deletion tests/unittests/classification/_inputs.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, NamedTuple

import pytest
import torch
@@ -112,6 +112,32 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor:
)


class _MaskInput(NamedTuple):
preds: Tensor
target: Tensor
mask: Tensor


_masked_binary_cases = (
pytest.param(
_MaskInput(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
mask=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool),
),
id="input[single_dim-probs]",
),
pytest.param(
_MaskInput(
preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
mask=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool),
),
id="input[single_dim-logits]",
),
)


def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES):
"""Generate multiclass input where a class is missing.

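Finally, a sketch of how the `_masked_binary_cases` fixtures defined above might be exercised in a parametrized test. The test name, the import path, and the comparison against a pre-filtered `BinaryAUROC` reference are assumptions for illustration and are not part of this diff.

import pytest
import torch

from torchmetrics.classification import BinaryAUROC, MaskedBinaryAUROC

from unittests.classification._inputs import _masked_binary_cases  # assumed import path


@pytest.mark.parametrize("inputs", _masked_binary_cases)
def test_masked_binary_auroc_matches_prefiltered_reference(inputs):
    """Masking inside the metric should equal filtering the inputs before a plain BinaryAUROC."""
    preds, target, mask = inputs.preds, inputs.target, inputs.mask

    # Accumulate per batch, passing the per-batch mask to the metric.
    metric = MaskedBinaryAUROC(thresholds=None)
    for p, t, m in zip(preds, target, mask):
        metric.update(p, t, mask=m)

    # Reference: drop masked-out elements first, then run the plain BinaryAUROC.
    reference = BinaryAUROC(thresholds=None)(preds[mask], target[mask])
    assert torch.allclose(metric.compute(), reference)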