From 4a5619bc59aba04c64fd3457c1adcd2adcb0bf86 Mon Sep 17 00:00:00 2001
From: VijayVignesh1
Date: Thu, 18 Sep 2025 17:16:10 -0400
Subject: [PATCH 1/7] Initial commit of masked binary auroc

---
 src/torchmetrics/classification/__init__.py  |   3 +-
 src/torchmetrics/classification/auroc.py     | 113 ++++++++++++++++++
 tests/unittests/classification/_inputs.py    |  28 +++++
 tests/unittests/classification/test_auroc.py | 116 ++++++++++++++++++-
 4 files changed, 256 insertions(+), 4 deletions(-)

diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py
index 3e41f565879..f24bacaa62c 100644
--- a/src/torchmetrics/classification/__init__.py
+++ b/src/torchmetrics/classification/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from torchmetrics.classification.accuracy import Accuracy, BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy
-from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC
+from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MaskedBinaryAUROC, MulticlassAUROC, MultilabelAUROC
 from torchmetrics.classification.average_precision import (
     AveragePrecision,
     BinaryAveragePrecision,
@@ -136,6 +136,7 @@
     "Accuracy",
     "AveragePrecision",
     "BinaryAUROC",
+    "MaskedBinaryAUROC",
     "BinaryAccuracy",
     "BinaryAveragePrecision",
     "BinaryCalibrationError",
diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py
index 68960e203e9..718b76ac93c 100644
--- a/src/torchmetrics/classification/auroc.py
+++ b/src/torchmetrics/classification/auroc.py
@@ -166,7 +166,120 @@ def plot(  # type: ignore[override]
         """
         return self._plot(val, ax)
+class MaskedBinaryAUROC(BinaryAUROC):
+    r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks with masking.
+
+    The Masked AUROC score summarizes the ROC curve into a single number that describes the performance of a model for
+    multiple thresholds at the same time, while a boolean mask selects which elements enter the computation. 
+    Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)`` containing probabilities or logits for
+      each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply
+      sigmoid per element.
+    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and
+      therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the
+      positive class.
+    - ``mask`` (:class:`~torch.Tensor`): A boolean tensor of shape ``(N, ...)`` indicating which elements to include
+      in the metric computation. Elements with a value of `True` will be included, while elements with a value of
+      `False` will be ignored.
+
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``b_auroc`` (:class:`~torch.Tensor`): A single scalar with the AUROC score of the unmasked elements.
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    The implementation both supports calculating the metric in a non-binned but accurate version and a
+    binned version that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will
+    activate the non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the
+    `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of
+    size :math:`\mathcal{O}(n_{thresholds})` (constant memory).
+
+    Args:
+        max_fpr: If not ``None``, calculates standardized partial AUC over the range ``[0, max_fpr]``.
+        thresholds:
+            Can be one of:
+
+            - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from
+              all the data. Most accurate but also most memory consuming approach.
+            - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from
+              0 to 1 as bins for the calculation.
+            - If set to a `list` of floats, will use the indicated thresholds in the list as bins for the calculation.
+            - If set to a 1d `tensor` of floats, will use the indicated thresholds in the tensor as
+              bins for the calculation.
+
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example:
+        >>> import torch
+        >>> from torch import tensor
+        >>> from torchmetrics.classification import MaskedBinaryAUROC
+        >>> preds = tensor([0, 0.5, 0.7, 0.8])
+        >>> target = tensor([0, 1, 1, 0])
+        >>> mask = tensor([1, 1, 0, 1], dtype=torch.bool)
+        >>> metric = MaskedBinaryAUROC(thresholds=None)
+        >>> metric(preds, target, mask)
+        tensor(0.5000)
+        >>> b_auroc = MaskedBinaryAUROC(thresholds=5)
+        >>> b_auroc(preds, target, mask)
+        tensor(0.5000)
+
+    """
+
+    def update(self, preds: Tensor, target: Tensor, mask: Tensor = None) -> None:
+        if mask is not None:
+            if mask.shape != preds.shape:
+                raise ValueError(f"Mask shape {mask.shape} must match preds/target shape {preds.shape}")
+            preds = preds[mask]
+            target = target[mask]
+        super().update(preds, target)  # call the original BinaryAUROC update
+
+    def plot(  # type: ignore[override]
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided will add plot to that axis
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> import torch
+            >>> from torchmetrics.classification import MaskedBinaryAUROC
+            >>> metric = MaskedBinaryAUROC()
+            >>> metric.update(torch.rand(20,), torch.randint(2, (20,)), mask=torch.rand(20,) > 0.5)
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.classification import MaskedBinaryAUROC
+            >>> metric = MaskedBinaryAUROC()
+            >>> values = [ ]
+            >>> for _ in range(10):
+            ...     values.append(metric(torch.rand(20,), torch.randint(2, (20,)), mask=torch.rand(20,) > 0.5))
+            >>> fig_, ax_ = metric.plot(values)
+
+        """
+        return self._plot(val, ax)
+
 class MulticlassAUROC(MulticlassPrecisionRecallCurve):
     r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks.
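[Note, not part of the diff] The hunk above implements masking purely by filtering: `MaskedBinaryAUROC.update` drops the masked-out elements from `preds`/`target` and then defers to `BinaryAUROC.update`. A minimal sketch of the resulting equivalence with hand-filtered inputs (illustrative only; assumes this branch is installed):

    import torch
    from torchmetrics.classification import BinaryAUROC, MaskedBinaryAUROC

    preds = torch.tensor([0.2, 0.8, 0.6, 0.3, 0.9, 0.1])
    target = torch.tensor([0, 1, 1, 0, 1, 0])
    mask = torch.tensor([True, True, False, True, True, True])

    masked = MaskedBinaryAUROC(thresholds=None)
    masked.update(preds, target, mask)

    plain = BinaryAUROC(thresholds=None)
    plain.update(preds[mask], target[mask])  # drop the masked-out element by hand

    assert torch.allclose(masked.compute(), plain.compute())

Because the filtering happens inside `update`, masked elements never reach the metric state, so accumulation across batches (and DDP reduction) only ever sees the kept elements.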
diff --git a/tests/unittests/classification/_inputs.py b/tests/unittests/classification/_inputs.py
index 9a4981f041c..14a43a543cf 100644
--- a/tests/unittests/classification/_inputs.py
+++ b/tests/unittests/classification/_inputs.py
@@ -112,6 +112,34 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor:
 )
 
 
+_masked_binary_cases = (
+    pytest.param(
+        _GroupInput(
+            preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
+            target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
+            groups=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool),
+        ),
+        id="input[single_dim-labels]",
+    ),
+    pytest.param(
+        _GroupInput(
+            preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
+            target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
+            groups=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool),
+        ),
+        id="input[single_dim-probs]",
+    ),
+    pytest.param(
+        _GroupInput(
+            preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)),
+            target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)),
+            groups=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool),
+        ),
+        id="input[single_dim-logits]",
+    ),
+)
+
+
 def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES):
     """Generate multiclass input where a class is missing.
 
diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py
index df12366b941..34354e1776c 100644
--- a/tests/unittests/classification/test_auroc.py
+++ b/tests/unittests/classification/test_auroc.py
@@ -21,15 +21,15 @@
 from scipy.special import softmax
 from sklearn.metrics import roc_auc_score as sk_roc_auc_score
 
-from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC
+from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MaskedBinaryAUROC, MulticlassAUROC, MultilabelAUROC
 from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc, multilabel_auroc
 from torchmetrics.functional.classification.roc import binary_roc
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 from unittests import NUM_CLASSES
 from unittests._helpers import seed_all
-from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index
-from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases
+from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index, remove_ignore_index_groups
+from unittests.classification._inputs import _binary_cases, _masked_binary_cases, _multiclass_cases, _multilabel_cases
 
 seed_all(42)
 
@@ -140,6 +140,116 @@ def test_binary_auroc_threshold_arg(self, inputs, threshold_fn):
         assert torch.allclose(ap1, ap2)
 
 
+def _reference_sklearn_masked_auroc_binary(preds, target, mask, max_fpr, ignore_index):
+
+    preds = preds.numpy().flatten()
+    target = target.numpy().flatten()
+    mask = mask.numpy().flatten() if mask is not None else None
+
+    if not ((preds > 0) & (preds < 1)).all():
+        preds = sigmoid(preds)
+    target, preds, mask = remove_ignore_index_groups(target=target, preds=preds, groups = mask, ignore_index=ignore_index)
+    if mask is not None:
+        preds, target = preds[mask], target[mask]
+    return sk_roc_auc_score(target, preds, max_fpr=max_fpr)
+
+
+@pytest.mark.parametrize("inputs", _masked_binary_cases)
+class TestMaskedBinaryAUROC(MetricTester):
+    """Test class for `MaskedBinaryAUROC` metric."""
+
+    @pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5])
@pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_binary_auroc(self, inputs, ddp, max_fpr, ignore_index): + """Test class implementation of metric.""" + preds, target, mask = inputs + + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MaskedBinaryAUROC, + reference_metric=partial(_reference_sklearn_masked_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index), + metric_args={ + "max_fpr": max_fpr, + "thresholds": None, + "ignore_index": ignore_index, + }, + mask=mask, + fragment_kwargs=True, + ) + # self.run_class_metric_test( + # ddp=ddp, + # preds=preds, + # target=target, + # metric_class=MaskedBinaryAUROC, + # reference_metric=partial(_reference_sklearn_masked_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index), + # metric_args={ + # "max_fpr": max_fpr, + # "thresholds": None, + # "ignore_index": ignore_index, + # }, + # mask=mask + # ) + + + # def test_binary_auroc_differentiability(self, inputs): + # """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + # preds, target, mask = inputs + # self.run_differentiability_test( + # preds=preds, + # target=target, + # metric_module=MaskedBinaryAUROC, + # metric_functional=binary_auroc, + # metric_args={"thresholds": None}, + # ) + + # @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + # def test_binary_auroc_dtype_cpu(self, inputs, dtype): + # """Test dtype support of the metric on CPU.""" + # preds, target = inputs + + # if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + # pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") + # self.run_precision_test_cpu( + # preds=preds, + # target=target, + # metric_module=BinaryAUROC, + # metric_functional=binary_auroc, + # metric_args={"thresholds": None}, + # dtype=dtype, + # ) + + # @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + # @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + # def test_binary_auroc_dtype_gpu(self, inputs, dtype): + # """Test dtype support of the metric on GPU.""" + # preds, target = inputs + # self.run_precision_test_gpu( + # preds=preds, + # target=target, + # metric_module=BinaryAUROC, + # metric_functional=binary_auroc, + # metric_args={"thresholds": None}, + # dtype=dtype, + # ) + + # @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + # def test_binary_auroc_threshold_arg(self, inputs, threshold_fn): + # """Test that different types of `thresholds` argument lead to same result.""" + # preds, target = inputs + + # for pred, true in zip(preds, target): + # _, _, t = binary_roc(pred, true, thresholds=None) + # ap1 = binary_auroc(pred, true, thresholds=None) + # ap2 = binary_auroc(pred, true, thresholds=threshold_fn(t.flip(0))) + # assert torch.allclose(ap1, ap2) + + def _reference_sklearn_auroc_multiclass(preds, target, average="macro", ignore_index=None): preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) target = target.numpy().flatten() From 129e656fddc18ca1bcdde1de856846a284510a2e Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Fri, 19 Sep 2025 10:56:56 -0400 Subject: [PATCH 2/7] Bebugging testcases and linting --- CHANGELOG.md | 2 +- src/torchmetrics/classification/__init__.py 
 src/torchmetrics/classification/auroc.py     |   7 +-
 tests/unittests/classification/_inputs.py    |  24 ++-
 tests/unittests/classification/test_auroc.py | 185 +++++++++++--------
 5 files changed, 128 insertions(+), 92 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7ccf8cae705..720cbfa8c89 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
--
+- Added `MaskedBinaryAUROC` implementation to the classification domain ([#3096](https://github.com/Lightning-AI/torchmetrics/issues/3096))
 
 ### Changed
 
diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py
index f24bacaa62c..7ea99f0d4ab 100644
--- a/src/torchmetrics/classification/__init__.py
+++ b/src/torchmetrics/classification/__init__.py
@@ -136,7 +136,6 @@
     "Accuracy",
     "AveragePrecision",
     "BinaryAUROC",
-    "MaskedBinaryAUROC",
     "BinaryAccuracy",
     "BinaryAveragePrecision",
     "BinaryCalibrationError",
@@ -173,6 +172,7 @@
     "HingeLoss",
     "JaccardIndex",
     "LogAUC",
+    "MaskedBinaryAUROC",
     "MatthewsCorrCoef",
     "MulticlassAUROC",
     "MulticlassAccuracy",
diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py
index 718b76ac93c..f194bfd452d 100644
--- a/src/torchmetrics/classification/auroc.py
+++ b/src/torchmetrics/classification/auroc.py
@@ -166,7 +166,12 @@ def plot(  # type: ignore[override]
         """
         return self._plot(val, ax)
+
+
 class MaskedBinaryAUROC(BinaryAUROC):
     r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for binary tasks with masking.
 
     The Masked AUROC score summarizes the ROC curve into a single number that describes the performance of a model for
-    multiple thresholds at the same time, while a boolean mask selects which elements enter the computation. 
+    multiple thresholds at the same time, while a boolean mask selects which elements enter the computation.
     Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing.
 
     As input to ``forward`` and ``update`` the metric accepts the following input:
@@ -231,6 +232,7 @@ class MaskedBinaryAUROC(BinaryAUROC):
     """
 
     def update(self, preds: Tensor, target: Tensor, mask: Tensor = None) -> None:
+        """Update the state with the new data."""
         if mask is not None:
             if mask.shape != preds.shape:
                 raise ValueError(f"Mask shape {mask.shape} must match preds/target shape {preds.shape}")
@@ -279,7 +281,8 @@ def plot(  # type: ignore[override]
         """
         return self._plot(val, ax)
-
+
+
 class MulticlassAUROC(MulticlassPrecisionRecallCurve):
     r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) for multiclass tasks.
 
diff --git a/tests/unittests/classification/_inputs.py b/tests/unittests/classification/_inputs.py
index 14a43a543cf..bc6c473b4b2 100644
--- a/tests/unittests/classification/_inputs.py
+++ b/tests/unittests/classification/_inputs.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any +from typing import Any, NamedTuple import pytest import torch @@ -112,28 +112,26 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: ) +class _MaskInput(NamedTuple): + preds: Tensor + target: Tensor + mask: Tensor + + _masked_binary_cases = ( pytest.param( - _GroupInput( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), - groups=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool), - ), - id="input[single_dim-labels]", - ), - pytest.param( - _GroupInput( + _MaskInput( preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), - groups=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool), + mask=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool), ), id="input[single_dim-probs]", ), pytest.param( - _GroupInput( + _MaskInput( preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), - groups=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool), + mask=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool), ), id="input[single_dim-logits]", ), diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 34354e1776c..ccc3aa89518 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import partial +from typing import Any, Callable, Optional import numpy as np import pytest @@ -20,6 +21,7 @@ from scipy.special import expit as sigmoid from scipy.special import softmax from sklearn.metrics import roc_auc_score as sk_roc_auc_score +from torch import Tensor from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MaskedBinaryAUROC, MulticlassAUROC, MultilabelAUROC from torchmetrics.functional.classification.auroc import binary_auroc, multiclass_auroc, multilabel_auroc @@ -28,7 +30,13 @@ from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from unittests import NUM_CLASSES from unittests._helpers import seed_all -from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index, remove_ignore_index_groups +from unittests._helpers.testers import ( + MetricTester, + inject_ignore_index, + remove_ignore_index, + remove_ignore_index_groups, +) +from unittests._helpers.testers import _assert_requires_grad as _core_assert_requires_grad from unittests.classification._inputs import _binary_cases, _masked_binary_cases, _multiclass_cases, _multilabel_cases seed_all(42) @@ -141,39 +149,90 @@ def test_binary_auroc_threshold_arg(self, inputs, threshold_fn): def _reference_sklearn_masked_auroc_binary(preds, target, mask, max_fpr, ignore_index): - preds = preds.numpy().flatten() target = target.numpy().flatten() mask = mask.numpy().flatten() if mask is not None else None - if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds, mask = remove_ignore_index_groups(target=target, preds=preds, groups = mask, ignore_index=ignore_index) + target, preds, mask = remove_ignore_index_groups(target=target, preds=preds, groups=mask, ignore_index=ignore_index) if mask is not None: preds, target = preds[mask], target[mask] return sk_roc_auc_score(target, preds, max_fpr=max_fpr) +def _assert_requires_grad(metric: Metric, pl_result: Any, key: Optional[str] = None) -> None: + if 
isinstance(pl_result, dict) and key is None: + for res in pl_result.values(): + _core_assert_requires_grad(metric, res) + else: + _core_assert_requires_grad(metric, pl_result, key) + + +class MaskedBinaryAUROCTester(MetricTester): + """Tester class for `MaskedBinaryAUROC` metric overriding some defaults.""" + + @staticmethod + def run_differentiability_test( + preds: Tensor, + target: Tensor, + metric_module: Metric, + metric_functional: Optional[Callable] = None, + metric_args: Optional[dict] = None, + mask: Optional[Tensor] = None, + ) -> None: + """Test if a metric is differentiable or not. + + Args: + preds: torch tensor with predictions + target: torch tensor with targets + metric_module: the metric module to test + metric_functional: functional version of the metric + metric_args: dict with additional arguments used for class initialization + mask: Tensor with binary mask indicating valid elements. + + """ + metric_args = metric_args or {} + # only floating point tensors can require grad + metric = metric_module(**metric_args) + if preds.is_floating_point(): + preds.requires_grad = True + out = metric(preds[0, :2], target[0, :2], mask[0, :2] if mask is not None else None) + + # Check if requires_grad matches is_differentiable attribute + _assert_requires_grad(metric, out) + + if metric.is_differentiable and metric_functional is not None: + # check for numerical correctness + assert torch.autograd.gradcheck( + partial(metric_functional, **metric_args), (preds[0, :2].double(), target[0, :2]) + ) + + # reset as else it will carry over to other tests + preds.requires_grad = False + + @pytest.mark.parametrize("inputs", _masked_binary_cases) -class TestMaskedBinaryAUROC(MetricTester): +class TestMaskedBinaryAUROC(MaskedBinaryAUROCTester): """Test class for `MaskedBinaryAUROC` metric.""" @pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_binary_auroc(self, inputs, ddp, max_fpr, ignore_index): + def test_masked_binary_auroc(self, inputs, ddp, max_fpr, ignore_index): """Test class implementation of metric.""" preds, target, mask = inputs - + if ignore_index is not None: target = inject_ignore_index(target, ignore_index) - + self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=MaskedBinaryAUROC, - reference_metric=partial(_reference_sklearn_masked_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index), + reference_metric=partial( + _reference_sklearn_masked_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index + ), metric_args={ "max_fpr": max_fpr, "thresholds": None, @@ -182,72 +241,48 @@ def test_binary_auroc(self, inputs, ddp, max_fpr, ignore_index): mask=mask, fragment_kwargs=True, ) - # self.run_class_metric_test( - # ddp=ddp, - # preds=preds, - # target=target, - # metric_class=MaskedBinaryAUROC, - # reference_metric=partial(_reference_sklearn_masked_auroc_binary, max_fpr=max_fpr, ignore_index=ignore_index), - # metric_args={ - # "max_fpr": max_fpr, - # "thresholds": None, - # "ignore_index": ignore_index, - # }, - # mask=mask - # ) - - - # def test_binary_auroc_differentiability(self, inputs): - # """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" - # preds, target, mask = inputs - # self.run_differentiability_test( - # preds=preds, - # target=target, - # metric_module=MaskedBinaryAUROC, - # metric_functional=binary_auroc, - # metric_args={"thresholds": None}, - # 
) - - # @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - # def test_binary_auroc_dtype_cpu(self, inputs, dtype): - # """Test dtype support of the metric on CPU.""" - # preds, target = inputs - - # if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: - # pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") - # self.run_precision_test_cpu( - # preds=preds, - # target=target, - # metric_module=BinaryAUROC, - # metric_functional=binary_auroc, - # metric_args={"thresholds": None}, - # dtype=dtype, - # ) - - # @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - # @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - # def test_binary_auroc_dtype_gpu(self, inputs, dtype): - # """Test dtype support of the metric on GPU.""" - # preds, target = inputs - # self.run_precision_test_gpu( - # preds=preds, - # target=target, - # metric_module=BinaryAUROC, - # metric_functional=binary_auroc, - # metric_args={"thresholds": None}, - # dtype=dtype, - # ) - - # @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) - # def test_binary_auroc_threshold_arg(self, inputs, threshold_fn): - # """Test that different types of `thresholds` argument lead to same result.""" - # preds, target = inputs - - # for pred, true in zip(preds, target): - # _, _, t = binary_roc(pred, true, thresholds=None) - # ap1 = binary_auroc(pred, true, thresholds=None) - # ap2 = binary_auroc(pred, true, thresholds=threshold_fn(t.flip(0))) - # assert torch.allclose(ap1, ap2) + + def test_masked_binary_auroc_differentiability(self, inputs): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + preds, target, mask = inputs + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MaskedBinaryAUROC, + metric_functional=binary_auroc, + metric_args={"thresholds": None}, + mask=mask, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_masked_binary_auroc_dtype_cpu(self, inputs, dtype): + """Test dtype support of the metric on CPU.""" + preds, target, mask = inputs + + if not _TORCH_GREATER_EQUAL_2_1 and (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision for torch<2.1") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MaskedBinaryAUROC, + metric_args={"thresholds": None}, + dtype=dtype, + mask=mask, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_masked_binary_auroc_dtype_gpu(self, inputs, dtype): + """Test dtype support of the metric on GPU.""" + preds, target, mask = inputs + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MaskedBinaryAUROC, + metric_args={"thresholds": None}, + dtype=dtype, + mask=mask, + ) def _reference_sklearn_auroc_multiclass(preds, target, average="macro", ignore_index=None): From 7e8d12f0a28ac930f7c24ca75cbf59814768becd Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Fri, 19 Sep 2025 11:20:14 -0400 Subject: [PATCH 3/7] Adding mask type and shape check along with testcase --- src/torchmetrics/classification/auroc.py | 3 +++ tests/unittests/classification/test_auroc.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/torchmetrics/classification/auroc.py 
b/src/torchmetrics/classification/auroc.py index f194bfd452d..20f9d94df8b 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -14,6 +14,7 @@ from collections.abc import Sequence from typing import Any, Optional, Union +import torch from torch import Tensor from typing_extensions import Literal @@ -234,6 +235,8 @@ class MaskedBinaryAUROC(BinaryAUROC): def update(self, preds: Tensor, target: Tensor, mask: Tensor = None) -> None: """Update the state with the new data.""" if mask is not None: + if mask.dtype != torch.bool: + raise ValueError(f"Mask must be boolean, got {mask.dtype}") if mask.shape != preds.shape: raise ValueError(f"Mask shape {mask.shape} must match preds/target shape {preds.shape}") preds = preds[mask] diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index ccc3aa89518..c0e843dd9b9 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -284,6 +284,20 @@ def test_masked_binary_auroc_dtype_gpu(self, inputs, dtype): mask=mask, ) + def test_mask_error_on_wrong_dtype_and_shape(self, inputs): + """Test that errors are raised on wrong mask dtype and shape.""" + preds, target, mask = inputs + + # wrong dtype: mask should be boolean + mask_wrong_dtype = torch.randint(high=2, size=preds.shape, dtype=torch.int32) + with pytest.raises(ValueError, match="Mask must be boolean, got "): + MaskedBinaryAUROC()(preds, target, mask=mask_wrong_dtype) + + # wrong shape: mask must match preds/target shape + mask_wrong_shape = torch.randint(high=2, size=(preds.shape[0],), dtype=torch.bool) + with pytest.raises(ValueError, match="Mask shape "): + MaskedBinaryAUROC()(preds, target, mask=mask_wrong_shape) + def _reference_sklearn_auroc_multiclass(preds, target, average="macro", ignore_index=None): preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1])) From 45670a2512e869a266c062d48e135e44bc10ff4a Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Fri, 19 Sep 2025 14:13:40 -0400 Subject: [PATCH 4/7] Skipping matplotlib test and adding masked binary to wrapper class --- src/torchmetrics/classification/auroc.py | 15 +++++++++------ src/torchmetrics/utilities/enums.py | 1 + 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 20f9d94df8b..0c2a5367e61 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -39,7 +39,7 @@ from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE if not _MATPLOTLIB_AVAILABLE: - __doctest_skip__ = ["BinaryAUROC.plot", "MulticlassAUROC.plot", "MultilabelAUROC.plot"] + __doctest_skip__ = ["BinaryAUROC.plot", "MaskedBinaryAUROC.plot", "MulticlassAUROC.plot", "MultilabelAUROC.plot"] class BinaryAUROC(BinaryPrecisionRecallCurve): @@ -601,10 +601,11 @@ class AUROC(_ClassificationTaskWrapper): corresponds to random guessing. This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the - ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of - :class:`~torchmetrics.classification.BinaryAUROC`, :class:`~torchmetrics.classification.MulticlassAUROC` and - :class:`~torchmetrics.classification.MultilabelAUROC` for the specific details of each argument influence and - examples. 
+ ``task`` argument to either ``'binary'``, ``'maskedbinary'``, ``'multiclass'`` or ``'multilabel'``. + See the documentation of + :class:`~torchmetrics.classification.BinaryAUROC`, :class:`~torchmetrics.classification.MaskedBinaryAUROC`, + :class:`~torchmetrics.classification.MulticlassAUROC` and :class:`~torchmetrics.classification.MultilabelAUROC` + for the specific details of each argument influence and examples. Legacy Example: >>> from torch import tensor @@ -628,7 +629,7 @@ class AUROC(_ClassificationTaskWrapper): def __new__( # type: ignore[misc] cls: type["AUROC"], - task: Literal["binary", "multiclass", "multilabel"], + task: Literal["binary", "maskedbinary", "multiclass", "multilabel"], thresholds: Optional[Union[int, list[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, @@ -643,6 +644,8 @@ def __new__( # type: ignore[misc] kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) if task == ClassificationTask.BINARY: return BinaryAUROC(max_fpr, **kwargs) + if task == ClassificationTask.MASKEDBINARY: + return MaskedBinaryAUROC(max_fpr, **kwargs) if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index 155f1bb8f60..02fa62ea46b 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -117,6 +117,7 @@ def _name() -> str: return "Classification" BINARY = "binary" + MASKEDBINARY = "maskedbinary" MULTICLASS = "multiclass" MULTILABEL = "multilabel" From 8f311a796fecdb6ab5b6947ff57705cdf4741df7 Mon Sep 17 00:00:00 2001 From: Vijay Vignesh Date: Fri, 19 Sep 2025 16:02:25 -0400 Subject: [PATCH 5/7] Update src/torchmetrics/classification/auroc.py Adding optional type for mask Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/torchmetrics/classification/auroc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 0c2a5367e61..9b4ae72b66e 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -232,7 +232,7 @@ class MaskedBinaryAUROC(BinaryAUROC): """ - def update(self, preds: Tensor, target: Tensor, mask: Tensor = None) -> None: + def update(self, preds: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> None: """Update the state with the new data.""" if mask is not None: if mask.dtype != torch.bool: From 303c68e7b3558a906e8a3dfea0b3f7a846b7fe56 Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Tue, 30 Sep 2025 16:04:48 -0400 Subject: [PATCH 6/7] Pushing _masked_binary_cases down in _inputs.py --- tests/unittests/classification/_inputs.py | 51 +++++++++++------------ 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/tests/unittests/classification/_inputs.py b/tests/unittests/classification/_inputs.py index bc6c473b4b2..e421aca9171 100644 --- a/tests/unittests/classification/_inputs.py +++ b/tests/unittests/classification/_inputs.py @@ -112,32 +112,6 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: ) -class _MaskInput(NamedTuple): - preds: Tensor - target: Tensor - mask: Tensor - - -_masked_binary_cases = ( - pytest.param( - _MaskInput( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), - 
mask=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool), - ), - id="input[single_dim-probs]", - ), - pytest.param( - _MaskInput( - preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), - mask=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool), - ), - id="input[single_dim-logits]", - ), -) - - def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): """Generate multiclass input where a class is missing. @@ -280,6 +254,31 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): ), ) +class _MaskInput(NamedTuple): + preds: Tensor + target: Tensor + mask: Tensor + + +_masked_binary_cases = ( + pytest.param( + _MaskInput( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + mask=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool), + ), + id="input[single_dim-probs]", + ), + pytest.param( + _MaskInput( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + mask=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE), dtype=torch.bool), + ), + id="input[single_dim-logits]", + ), +) + # Generate edge multilabel edge case, where nothing matches (scores are undefined) __temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) __temp_target = abs(__temp_preds - 1) From 7f04bc5336664c4ffde0b7ff0cbd1dbf7752dc0e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Sep 2025 20:05:44 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/classification/_inputs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unittests/classification/_inputs.py b/tests/unittests/classification/_inputs.py index e421aca9171..e61942ff1fb 100644 --- a/tests/unittests/classification/_inputs.py +++ b/tests/unittests/classification/_inputs.py @@ -254,6 +254,7 @@ def _multiclass_with_missing_class(*shape: Any, num_classes=NUM_CLASSES): ), ) + class _MaskInput(NamedTuple): preds: Tensor target: Tensor
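[Note, not part of the diff] With the full series applied, the metric should be reachable both directly and through the `AUROC` task wrapper wired up in PATCH 4/7. A short usage sketch (illustrative only; assumes this branch is installed):

    import torch
    from torchmetrics import AUROC
    from torchmetrics.classification import MaskedBinaryAUROC

    preds = torch.tensor([0.1, 0.6, 0.8, 0.4])
    target = torch.tensor([0, 1, 1, 0])
    mask = torch.tensor([True, True, True, False])

    # Direct use: only the first three elements are scored -> perfect separation.
    metric = MaskedBinaryAUROC(thresholds=None)
    print(metric(preds, target, mask))  # tensor(1.)

    # Via the wrapper, using the task string added to the ClassificationTask enum in PATCH 4/7.
    wrapped = AUROC(task="maskedbinary", thresholds=None)
    print(wrapped(preds, target, mask))

    # The validation added in PATCH 3/7 rejects non-boolean or mis-shaped masks:
    # metric(preds, target, torch.tensor([1, 1, 1, 0]))
    # -> ValueError: Mask must be boolean, got torch.int64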