diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7ccf8cae705..720cbfa8c89 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
--
+- Added `MaskedBinaryAUROC` to the classification domain ([#3096](https://github.com/Lightning-AI/torchmetrics/issues/3096))
 
 ### Changed
 
diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py
index 3e41f565879..7ea99f0d4ab 100644
--- a/src/torchmetrics/classification/__init__.py
+++ b/src/torchmetrics/classification/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from torchmetrics.classification.accuracy import Accuracy, BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy
-from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC
+from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MaskedBinaryAUROC, MulticlassAUROC, MultilabelAUROC
 from torchmetrics.classification.average_precision import (
     AveragePrecision,
     BinaryAveragePrecision,
@@ -172,6 +172,7 @@
     "HingeLoss",
     "JaccardIndex",
     "LogAUC",
+    "MaskedBinaryAUROC",
     "MatthewsCorrCoef",
     "MulticlassAUROC",
     "MulticlassAccuracy",
diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py
index 68960e203e9..9b4ae72b66e 100644
--- a/src/torchmetrics/classification/auroc.py
+++ b/src/torchmetrics/classification/auroc.py
@@ -14,6 +14,7 @@
 from collections.abc import Sequence
 from typing import Any, Optional, Union
 
+import torch
 from torch import Tensor
 from typing_extensions import Literal
 
@@ -38,7 +39,7 @@
 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
 
 if not _MATPLOTLIB_AVAILABLE:
-    __doctest_skip__ = ["BinaryAUROC.plot", "MulticlassAUROC.plot", "MultilabelAUROC.plot"]
+    __doctest_skip__ = ["BinaryAUROC.plot", "MaskedBinaryAUROC.plot", "MulticlassAUROC.plot", "MultilabelAUROC.plot"]
 
 
 class BinaryAUROC(BinaryPrecisionRecallCurve):
@@ -167,6 +168,124 @@ def plot(  # type: ignore[override]
         return self._plot(val, ax)
 
 
+class MaskedBinaryAUROC(BinaryAUROC):
+    r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks with masking.
+
+    The masked AUROC score summarizes the ROC curve into a single number that describes the performance of a model
+    across multiple thresholds at the same time, while a boolean mask selects which elements contribute to the score.
+    Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)`` containing probabilities or logits for
+      each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto
+      apply sigmoid per element.
+    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, which
+      should therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes
+      the positive class.
+    - ``mask`` (:class:`~torch.Tensor`): A boolean tensor of shape ``(N, ...)`` indicating which elements to include
+      in the metric computation. Elements with a value of `True` will be included, while elements with a value of
+      `False` will be ignored.
+
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``b_auroc`` (:class:`~torch.Tensor`): A single scalar with the AUROC score computed over the elements
+      selected by the mask.
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation supports calculating the metric both in a non-binned but accurate version and in a binned
+    version that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will
+    activate the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the
+    `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
+    size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
+
+    Args:
+        max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``.
+        thresholds:
+            Can be one of:
+
+            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
+              all the data. Most accurate but also most memory consuming approach.
+            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
+              0 to 1 as bins for the calculation.
+            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the
+              calculation.
+            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
+              bins for the calculation.
+
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example:
+        >>> import torch
+        >>> from torch import tensor
+        >>> from torchmetrics.classification import MaskedBinaryAUROC
+        >>> preds = tensor([0, 0.5, 0.7, 0.8])
+        >>> target = tensor([0, 1, 1, 0])
+        >>> mask = tensor([1, 1, 0, 1], dtype=torch.bool)
+        >>> metric = MaskedBinaryAUROC(thresholds=None)
+        >>> metric(preds, target, mask)
+        tensor(0.5000)
+        >>> b_auroc = MaskedBinaryAUROC(thresholds=5)
+        >>> b_auroc(preds, target, mask)
+        tensor(0.5000)
+
+    """
+
+    def update(self, preds: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> None:
+        """Update the metric state, keeping only the elements selected by ``mask``."""
+        if mask is not None:
+            if mask.dtype != torch.bool:
+                raise ValueError(f"Mask must be boolean, got {mask.dtype}")
+            if mask.shape != preds.shape:
+                raise ValueError(f"Mask shape {mask.shape} must match preds/target shape {preds.shape}")
+            # drop masked-out elements before delegating to the parent update
+            preds = preds[mask]
+            target = target[mask]
+        super().update(preds, target)  # call the original BinaryAUROC update
+
+    def plot(  # type: ignore[override]
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided will add plot to that axis.
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> import torch
+            >>> from torchmetrics.classification import MaskedBinaryAUROC
+            >>> metric = MaskedBinaryAUROC()
+            >>> metric.update(torch.rand(20,), torch.randint(2, (20,)), mask=torch.rand(20,) > 0.5)
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.classification import MaskedBinaryAUROC
+            >>> metric = MaskedBinaryAUROC()
+            >>> values = []
+            >>> for _ in range(10):
+            ...     values.append(metric(torch.rand(20,), torch.randint(2, (20,)), mask=torch.rand(20,) > 0.5))
+            >>> fig_, ax_ = metric.plot(values)
+
+        """
+        return self._plot(val, ax)
+
+
 class MulticlassAUROC(MulticlassPrecisionRecallCurve):
     r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks.
 
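Since `MaskedBinaryAUROC.update` filters `preds` and `target` by the mask before delegating to `BinaryAUROC.update`, the metric should agree with a plain `BinaryAUROC` evaluated on manually filtered inputs. A minimal sketch of that equivalence (illustrative values, not part of the diff):

```python
import torch

from torchmetrics.classification import BinaryAUROC, MaskedBinaryAUROC

preds = torch.tensor([0.1, 0.6, 0.8, 0.3, 0.7])
target = torch.tensor([0, 1, 1, 0, 1])
mask = torch.tensor([True, True, False, True, True])  # the third element is excluded

# the masked metric filters internally ...
masked_score = MaskedBinaryAUROC(thresholds=None)(preds, target, mask)
# ... which should match BinaryAUROC on the pre-filtered elements
plain_score = BinaryAUROC(thresholds=None)(preds[mask], target[mask])
assert torch.allclose(masked_score, plain_score)
```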
@@ -482,10 +601,11 @@ class AUROC(_ClassificationTaskWrapper):
     corresponds to random guessing.
 
     This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the
-    ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
-    :class:`~torchmetrics.classification.BinaryAUROC`, :class:`~torchmetrics.classification.MulticlassAUROC` and
-    :class:`~torchmetrics.classification.MultilabelAUROC` for the specific details of each argument influence and
-    examples.
+    ``task`` argument to either ``'binary'``, ``'maskedbinary'``, ``'multiclass'`` or ``'multilabel'``.
+    See the documentation of
+    :class:`~torchmetrics.classification.BinaryAUROC`, :class:`~torchmetrics.classification.MaskedBinaryAUROC`,
+    :class:`~torchmetrics.classification.MulticlassAUROC` and :class:`~torchmetrics.classification.MultilabelAUROC`
+    for the specific details of how each argument influences the metric, together with examples.
 
     Legacy Example:
         >>> from torch import tensor
@@ -509,7 +629,7 @@ class AUROC(_ClassificationTaskWrapper):
 
     def __new__(  # type: ignore[misc]
         cls: type["AUROC"],
-        task: Literal["binary", "multiclass", "multilabel"],
+        task: Literal["binary", "maskedbinary", "multiclass", "multilabel"],
         thresholds: Optional[Union[int, list[float], Tensor]] = None,
         num_classes: Optional[int] = None,
         num_labels: Optional[int] = None,
@@ -524,6 +644,8 @@ def __new__(  # type: ignore[misc]
         kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
         if task == ClassificationTask.BINARY:
             return BinaryAUROC(max_fpr, **kwargs)
+        if task == ClassificationTask.MASKEDBINARY:
+            return MaskedBinaryAUROC(max_fpr, **kwargs)
         if task == ClassificationTask.MULTICLASS:
             if not isinstance(num_classes, int):
                 raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`")
diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py
index 155f1bb8f60..02fa62ea46b 100644
--- a/src/torchmetrics/utilities/enums.py
+++ b/src/torchmetrics/utilities/enums.py
@@ -117,6 +117,7 @@ def _name() -> str:
         return "Classification"
 
     BINARY = "binary"
+    MASKEDBINARY = "maskedbinary"
     MULTICLASS = "multiclass"
     MULTILABEL = "multilabel"
 
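With the `MASKEDBINARY` enum value in place, the `AUROC` wrapper above dispatches `task="maskedbinary"` to `MaskedBinaryAUROC`. A small usage sketch of that wrapper path (values are illustrative):

```python
import torch

from torchmetrics.classification import AUROC, MaskedBinaryAUROC

preds = torch.tensor([0.1, 0.6, 0.8, 0.3])
target = torch.tensor([0, 1, 1, 0])
mask = torch.tensor([True, True, False, True])

# __new__ resolves the task string via ClassificationTask and returns the masked variant
metric = AUROC(task="maskedbinary")
assert isinstance(metric, MaskedBinaryAUROC)
score = metric(preds, target, mask)  # scalar AUROC over the mask-selected elements
```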
diff --git a/tests/unittests/classification/_inputs.py b/tests/unittests/classification/_inputs.py
index 9a4981f041c..e61942ff1fb 100644
--- a/tests/unittests/classification/_inputs.py
+++ b/tests/unittests/classification/_inputs.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, NamedTuple
 
 import pytest
 import torch
@@ -254,6 +254,32 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES):
     ),
 )
 
+
+class _MaskInput(NamedTuple):
+    preds: Tensor
+    target: Tensor
+    mask: Tensor
+
+
+_masked_binary_cases = (
+    pytest.param(
+        _MaskInput(
+            preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
+            target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
+            mask=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool),
+        ),
+        id="input[single_dim-probs]",
+    ),
+    pytest.param(
+        _MaskInput(
+            preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)),
+            target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
+            mask=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool),
+        ),
+        id="input[single_dim-logits]",
+    ),
+)
+
 # Generate edge multilabel edge case, where nothing matches (scores are undefined)
 __temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
 __temp_target = abs(__temp_preds - 1)
diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py
index df12366b941..c0e843dd9b9 100644
--- a/tests/unittests/classification/test_auroc.py
+++ b/tests/unittests/classification/test_auroc.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from functools import partial
+from typing import Any, Callable, Optional
 
 import numpy as np
 import pytest
@@ -20,16 +21,23 @@
 from scipy.special import expit as sigmoid
 from scipy.special import softmax
 from sklearn.metrics import roc_auc_score as sk_roc_auc_score
+from torch import Tensor
 
-from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC
+from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MaskedBinaryAUROC, MulticlassAUROC, MultilabelAUROC
 from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc, multilabel_auroc
 from torchmetrics.functional.classification.roc import binary_roc
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all
-from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
-from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases
+from unittests._helpers.testers import (
+    MetricTester,
+    _assert_requires_grad as _core_assert_requires_grad,
+    inject_ignore_index,
+    remove_ignore_index,
+    remove_ignore_index_groups,
+)
+from unittests.classification._inputs import _binary_cases, _masked_binary_cases, _multiclass_cases, _multilabel_cases
 
 seed_all(42)
 
@@ -140,6 +148,157 @@ def test_binary_auroc_threshold_arg(self, inputs, threshold_fn):
         assert torch.allclose(ap1, ap2)
 
 
+def _reference_sklearn_masked_auroc_binary(preds, target, mask, max_fpr, ignore_index):
+    preds = preds.numpy().flatten()
+    target = target.numpy().flatten()
+    mask = mask.numpy().flatten() if mask is not None else None
+    if not ((preds > 0) & (preds < 1)).all():
+        preds = sigmoid(preds)
+    target, preds, mask = remove_ignore_index_groups(
+        target=target, preds=preds, groups=mask, ignore_index=ignore_index
+    )
+    if mask is not None:
+        preds, target = preds[mask], target[mask]
+    return sk_roc_auc_score(target, preds, max_fpr=max_fpr)
+
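The reference above removes `ignore_index` positions first (via the `remove_ignore_index_groups` helper, which keeps the mask aligned with `preds` and `target`) and only then applies the mask. A self-contained numpy/sklearn sketch of that ordering, assuming those helper semantics:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

preds = np.array([0.2, 0.7, 0.9, 0.4, 0.6])
target = np.array([0, 1, -1, 0, 1])  # -1 plays the role of ignore_index
mask = np.array([True, True, True, False, True])

keep = target != -1  # drop ignored positions first ...
preds, target, mask = preds[keep], target[keep], mask[keep]  # ... keeping the mask aligned
score = roc_auc_score(target[mask], preds[mask])  # AUROC over the remaining masked-in elements
```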
+
+def _assert_requires_grad(metric: Metric, pl_result: Any, key: Optional[str] = None) -> None:
+    if isinstance(pl_result, dict) and key is None:
+        for res in pl_result.values():
+            _core_assert_requires_grad(metric, res)
+    else:
+        _core_assert_requires_grad(metric, pl_result, key)
+
+
+class MaskedBinaryAUROCTester(MetricTester):
+    """Tester class for the `MaskedBinaryAUROC` metric, overriding some defaults to route the mask argument."""
+
+    @staticmethod
+    def run_differentiability_test(
+        preds: Tensor,
+        target: Tensor,
+        metric_module: Metric,
+        metric_functional: Optional[Callable] = None,
+        metric_args: Optional[dict] = None,
+        mask: Optional[Tensor] = None,
+    ) -> None:
+        """Test if a metric is differentiable or not.
+
+        Args:
+            preds: torch tensor with predictions
+            target: torch tensor with targets
+            metric_module: the metric module to test
+            metric_functional: functional version of the metric
+            metric_args: dict with additional arguments used for class initialization
+            mask: boolean tensor indicating which elements are valid
+
+        """
+        metric_args = metric_args or {}
+        metric = metric_module(**metric_args)
+        # only floating point tensors can require grad
+        if preds.is_floating_point():
+            preds.requires_grad = True
+            out = metric(preds[0, :2], target[0, :2], mask[0, :2] if mask is not None else None)
+
+            # check if requires_grad matches the is_differentiable attribute
+            _assert_requires_grad(metric, out)
+
+            if metric.is_differentiable and metric_functional is not None:
+                # check for numerical correctness
+                assert torch.autograd.gradcheck(
+                    partial(metric_functional, **metric_args), (preds[0, :2].double(), target[0, :2])
+                )
+
+            # reset as else it will carry over to other tests
+            preds.requires_grad = False
+
+
+@pytest.mark.parametrize("inputs", _masked_binary_cases)
+class TestMaskedBinaryAUROC(MaskedBinaryAUROCTester):
+    """Test class for `MaskedBinaryAUROC` metric."""
+
+    @pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5])
+    @pytest.mark.parametrize("ignore_index", [None, -1])
+    @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
+    def test_masked_binary_auroc(self, inputs, ddp, max_fpr, ignore_index):
+        """Test class implementation of metric."""
+        preds, target, mask = inputs
+
+        if ignore_index is not None:
+            target = inject_ignore_index(target, ignore_index)
+
+        self.run_class_metric_test(
+            ddp=ddp,
+            preds=preds,
+            target=target,
+            metric_class=MaskedBinaryAUROC,
+            reference_metric=partial(
+                _reference_sklearn_masked_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index
+            ),
+            metric_args={
+                "max_fpr": max_fpr,
+                "thresholds": None,
+                "ignore_index": ignore_index,
+            },
+            mask=mask,
+            fragment_kwargs=True,
+        )
+
+    def test_masked_binary_auroc_differentiability(self, inputs):
+        """Test the differentiability of the metric, according to its `is_differentiable` attribute."""
+        preds, target, mask = inputs
+        self.run_differentiability_test(
+            preds=preds,
+            target=target,
+            metric_module=MaskedBinaryAUROC,
+            metric_functional=binary_auroc,
+            metric_args={"thresholds": None},
+            mask=mask,
+        )
+
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_masked_binary_auroc_dtype_cpu(self, inputs, dtype):
+        """Test dtype support of the metric on CPU."""
+        preds, target, mask = inputs
+
+        if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1")
+        self.run_precision_test_cpu(
+            preds=preds,
+            target=target,
+            metric_module=MaskedBinaryAUROC,
+            metric_args={"thresholds": None},
+            dtype=dtype,
+            mask=mask,
+        )
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_masked_binary_auroc_dtype_gpu(self, inputs, dtype):
+        """Test dtype support of the metric on GPU."""
+        preds, target, mask = inputs
+        self.run_precision_test_gpu(
+            preds=preds,
+            target=target,
+            metric_module=MaskedBinaryAUROC,
+            metric_args={"thresholds": None},
+            dtype=dtype,
+            mask=mask,
+        )
+
+    def test_mask_error_on_wrong_dtype_and_shape(self, inputs):
+        """Test that errors are raised on wrong mask dtype and shape."""
+        preds, target, mask = inputs
+
+        # wrong dtype: mask should be boolean
+        mask_wrong_dtype = torch.randint(high=2, size=preds.shape, dtype=torch.int32)
+        with pytest.raises(ValueError, match="Mask must be boolean, got "):
+            MaskedBinaryAUROC()(preds, target, mask=mask_wrong_dtype)
+
+        # wrong shape: mask must match preds/target shape
+        mask_wrong_shape = torch.randint(high=2, size=(preds.shape[0],), dtype=torch.bool)
+        with pytest.raises(ValueError, match="Mask shape "):
+            MaskedBinaryAUROC()(preds, target, mask=mask_wrong_shape)
+
+
 def _reference_sklearn_auroc_multiclass(preds, target, average="macro", ignore_index=None):
     preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1]))
     target = target.numpy().flatten()
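The mask validation exercised by `test_mask_error_on_wrong_dtype_and_shape` happens inside `update`, so both failure modes surface on the first `forward` call. A short sketch (shapes are illustrative):

```python
import torch

from torchmetrics.classification import MaskedBinaryAUROC

metric = MaskedBinaryAUROC()
preds, target = torch.rand(8), torch.randint(2, (8,))

try:
    metric(preds, target, mask=torch.randint(2, (8,)))  # integer mask is rejected
except ValueError as err:
    print(err)  # Mask must be boolean, got torch.int64

try:
    metric(preds, target, mask=torch.rand(4) > 0.5)  # mismatched shape is rejected
except ValueError as err:
    print(err)  # Mask shape torch.Size([4]) must match preds/target shape torch.Size([8])
```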