Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `MaskedBinaryAUROC` implementation to classification domain ([#3096](https://github.com/Lightning-AI/torchmetrics/issues/3096))


### Changed
Expand Down
3 changes: 2 additions & 1 deletion src/torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.classification.accuracy import Accuracy, BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy
from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC
from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MaskedBinaryAUROC, MulticlassAUROC, MultilabelAUROC
from torchmetrics.classification.average_precision import (
AveragePrecision,
BinaryAveragePrecision,
Expand Down Expand Up @@ -172,6 +172,7 @@
"HingeLoss",
"JaccardIndex",
"LogAUC",
"MaskedBinaryAUROC",
"MatthewsCorrCoef",
"MulticlassAUROC",
"MulticlassAccuracy",
Expand Down
134 changes: 128 additions & 6 deletions src/torchmetrics/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from collections.abc import Sequence
from typing import Any, Optional, Union

import torch
from torch import Tensor
from typing_extensions import Literal

Expand All @@ -38,7 +39,7 @@
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["BinaryAUROC.plot", "MulticlassAUROC.plot", "MultilabelAUROC.plot"]
__doctest_skip__ = ["BinaryAUROC.plot", "MaskedBinaryAUROC.plot", "MulticlassAUROC.plot", "MultilabelAUROC.plot"]


class BinaryAUROC(BinaryPrecisionRecallCurve):
Expand Down Expand Up @@ -167,6 +168,124 @@ def plot( # type: ignore[override]
return self._plot(val, ax)


class MaskedBinaryAUROC(BinaryAUROC):
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks with masking.

The Masked AUROC score summarizes the ROC curve into an single number that describes the performance of a model for
multiple thresholds at the same time with an output mask.
Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.

As input to ``forward`` and ``update`` the metric accepts the following input:

- ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)`` containing probabilities or logits for
each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
sigmoid per element.
- ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and
therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the
positive class.
- ``mask`` (:class:`~torch.Tensor`): A boolean tensor of shape ``(N, ...)`` indicating which elements to include
in the metric computation. Elements with a value of `True` will be included, while elements with a value of
`False` will be ignored.

As output to ``forward`` and ``compute`` the metric returns the following output:

- ``b_auroc`` (:class:`~torch.Tensor`): A single scalar with the auroc score of unmasked elements.

Additional dimension ``...`` will be flattened into the batch dimension.

The implementation both supports calculating the metric in a non-binned but accurate version and a
binned version that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will
activate the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the
`thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
size :math:`\mathcal{O}(n_{thresholds})` (constant memory).

Args:
max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``.
thresholds:
Can be one of:

- If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
all the data. Most accurate but also most memory consuming approach.
- If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
0 to 1 as bins for the calculation.
- If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.

validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> import torch
>>> from torch import tensor
>>> from torchmetrics.classification import MaskedBinaryAUROC
>>> preds = tensor([0, 0.5, 0.7, 0.8])
>>> target = tensor([0, 1, 1, 0])
>>> mask = tensor([1, 1, 0, 1], dtype=torch.bool)
>>> metric = MaskedBinaryAUROC(thresholds=None)
>>> metric(preds, target, mask)
tensor(0.5000)
>>> b_auroc = MaskedBinaryAUROC(thresholds=5)
>>> b_auroc(preds, target, mask)
tensor(0.5000)

"""

def update(self, preds: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> None:
"""Update the state with the new data."""
if mask is not None:
if mask.dtype != torch.bool:
raise ValueError(f"Mask must be boolean, got {mask.dtype}")
if mask.shape != preds.shape:
raise ValueError(f"Mask shape {mask.shape} must match preds/target shape {preds.shape}")
preds = preds[mask]
target = target[mask]
super().update(preds, target) # call the original BinaryAUROC update

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single
>>> import torch
>>> from torchmetrics.classification import MaskedBinaryAUROC
>>> metric = MaskedBinaryAUROC()
>>> metric.update(torch.rand(20,), torch.randint(2, (20,)), mask=torch.rand(20,) > 0.5)
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.classification import MaskedBinaryAUROC
>>> metric = MaskedBinaryAUROC()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(torch.rand(20,), torch.randint(2, (20,)), mask=torch.rand(20,) > 0.5))
>>> fig_, ax_ = metric.plot(values)

"""
return self._plot(val, ax)


class MulticlassAUROC(MulticlassPrecisionRecallCurve):
r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks.

Expand Down Expand Up @@ -482,10 +601,11 @@ class AUROC(_ClassificationTaskWrapper):
corresponds to random guessing.

This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the
``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
:class:`~torchmetrics.classification.BinaryAUROC`, :class:`~torchmetrics.classification.MulticlassAUROC` and
:class:`~torchmetrics.classification.MultilabelAUROC` for the specific details of each argument influence and
examples.
``task`` argument to either ``'binary'``, ``'maskedbinary'``, ``'multiclass'`` or ``'multilabel'``.
See the documentation of
:class:`~torchmetrics.classification.BinaryAUROC`, :class:`~torchmetrics.classification.MaskedBinaryAUROC`,
:class:`~torchmetrics.classification.MulticlassAUROC` and :class:`~torchmetrics.classification.MultilabelAUROC`
for the specific details of each argument influence and examples.

Legacy Example:
>>> from torch import tensor
Expand All @@ -509,7 +629,7 @@ class AUROC(_ClassificationTaskWrapper):

def __new__( # type: ignore[misc]
cls: type["AUROC"],
task: Literal["binary", "multiclass", "multilabel"],
task: Literal["binary", "maskedbinary", "multiclass", "multilabel"],
thresholds: Optional[Union[int, list[float], Tensor]] = None,
num_classes: Optional[int] = None,
num_labels: Optional[int] = None,
Expand All @@ -524,6 +644,8 @@ def __new__( # type: ignore[misc]
kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
if task == ClassificationTask.BINARY:
return BinaryAUROC(max_fpr, **kwargs)
if task == ClassificationTask.MASKEDBINARY:
return MaskedBinaryAUROC(max_fpr, **kwargs)
if task == ClassificationTask.MULTICLASS:
if not isinstance(num_classes, int):
raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def _name() -> str:
return "Classification"

BINARY = "binary"
MASKEDBINARY = "maskedbinary"
MULTICLASS = "multiclass"
MULTILABEL = "multilabel"

Expand Down
28 changes: 27 additions & 1 deletion tests/unittests/classification/_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, NamedTuple

import pytest
import torch
Expand Down Expand Up @@ -254,6 +254,32 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES):
),
)


class _MaskInput(NamedTuple):
preds: Tensor
target: Tensor
mask: Tensor


_masked_binary_cases = (
pytest.param(
_MaskInput(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
mask=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool),
),
id="input[single_dim-probs]",
),
pytest.param(
_MaskInput(
preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
mask=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool),
),
id="input[single_dim-logits]",
),
)

# Generate edge multilabel edge case, where nothing matches (scores are undefined)
__temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
__temp_target = abs(__temp_preds - 1)
Expand Down
Loading