diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ccf8cae705..b0a273ea235 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `ClassificationReport` with support for binary, multiclass, and multilabel classification tasks ([#3116](https://github.com/Lightning-AI/torchmetrics/pull/3116)) + + - diff --git a/docs/source/classification/classification_report.rst b/docs/source/classification/classification_report.rst new file mode 100644 index 00000000000..e25766561a0 --- /dev/null +++ b/docs/source/classification/classification_report.rst @@ -0,0 +1,55 @@ +.. customcarditem:: + :header: Classification Report + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Classification + +.. include:: ../links.rst + +##################### +Classification Report +##################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.ClassificationReport + :exclude-members: update, compute + :special-members: __new__ + +BinaryClassificationReport +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryClassificationReport + :exclude-members: update, compute + +MulticlassClassificationReport +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassClassificationReport + :exclude-members: update, compute + +MultilabelClassificationReport +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelClassificationReport + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.classification.classification_report + +binary_classification_report +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_classification_report + +multiclass_classification_report +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_classification_report + +multilabel_classification_report +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. 
autofunction:: torchmetrics.functional.classification.multilabel_classification_report diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index d660d3354b9..86d207583b4 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -59,6 +59,7 @@ Accuracy, AveragePrecision, CalibrationError, + ClassificationReport, CohenKappa, ConfusionMatrix, ExactMatch, @@ -177,6 +178,7 @@ "CalibrationError", "CatMetric", "CharErrorRate", + "ClassificationReport", "ClasswiseWrapper", "CohenKappa", "ConcordanceCorrCoef", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 3e41f565879..dd11fa8fcf8 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -24,6 +24,12 @@ CalibrationError, MulticlassCalibrationError, ) +from torchmetrics.classification.classification_report import ( + BinaryClassificationReport, + ClassificationReport, + MulticlassClassificationReport, + MultilabelClassificationReport, +) from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa from torchmetrics.classification.confusion_matrix import ( BinaryConfusionMatrix, @@ -139,6 +145,7 @@ "BinaryAccuracy", "BinaryAveragePrecision", "BinaryCalibrationError", + "BinaryClassificationReport", "BinaryCohenKappa", "BinaryConfusionMatrix", "BinaryEER", @@ -163,6 +170,7 @@ "BinarySpecificityAtSensitivity", "BinaryStatScores", "CalibrationError", + "ClassificationReport", "CohenKappa", "ConfusionMatrix", "ExactMatch", @@ -177,6 +185,7 @@ "MulticlassAccuracy", "MulticlassAveragePrecision", "MulticlassCalibrationError", + "MulticlassClassificationReport", "MulticlassCohenKappa", "MulticlassConfusionMatrix", "MulticlassEER", @@ -202,6 +211,7 @@ "MultilabelAUROC", "MultilabelAccuracy", "MultilabelAveragePrecision", + "MultilabelClassificationReport", "MultilabelConfusionMatrix", "MultilabelCoverageError", "MultilabelEER", diff --git a/src/torchmetrics/classification/classification_report.py b/src/torchmetrics/classification/classification_report.py new file mode 100644 index 00000000000..12455d89fdf --- /dev/null +++ b/src/torchmetrics/classification/classification_report.py @@ -0,0 +1,550 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections.abc import Sequence +from typing import Any, Dict, List, Optional, Union + +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.classification.base import _ClassificationTaskWrapper +from torchmetrics.functional.classification.classification_report import ( + binary_classification_report, + multiclass_classification_report, + multilabel_classification_report, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = [ + "BinaryClassificationReport.plot", + "MulticlassClassificationReport.plot", + "MultilabelClassificationReport.plot", + "ClassificationReport.plot", + ] + +__all__ = [ + "BinaryClassificationReport", + "ClassificationReport", + "MulticlassClassificationReport", + "MultilabelClassificationReport", +] + + +class _BaseClassificationReport(Metric): + """Base class for classification reports with shared functionality.""" + + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + + # Make mypy aware of the dynamically added states + preds: List[Tensor] + target: List[Tensor] + + def __init__( + self, + target_names: Optional[Sequence[str]] = None, + sample_weight: Optional[Tensor] = None, + digits: int = 2, + output_dict: bool = False, + zero_division: Union[str, int] = "warn", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.provided_target_names = target_names + self.sample_weight = sample_weight + self.digits = digits + self.output_dict = output_dict + self.zero_division = zero_division + self.target_names: List[str] = [] + + # Add states for tracking data + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update metric with predictions and targets.""" + self.preds.append(preds) + self.target.append(target) + + def compute(self) -> Union[Dict[str, Union[float, Dict[str, Union[float, int]]]], str]: + """Compute the classification report using functional interface.""" + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + + return self._call_functional_report(preds, target) + + def _call_functional_report( + self, preds: Tensor, target: Tensor + ) -> Union[Dict[str, Union[float, Dict[str, Union[float, int]]]], str]: + """Call the appropriate functional classification report.""" + raise NotImplementedError("Subclasses must implement _call_functional_report") + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. 
If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + """ + if not self.output_dict: + raise ValueError("Plotting is only supported when output_dict=True") + return self._plot(val, ax) + + +class BinaryClassificationReport(_BaseClassificationReport): + r"""Compute precision, recall, F-measure and support for binary classification tasks. + + The classification report provides detailed metrics for each class in a binary classification task: + precision, recall, F1-score, and support. + + .. math:: + \text{Precision} = \frac{TP}{TP + FP} + + \text{Recall} = \frac{TP}{TP + FN} + + \text{F1} = 2 * \frac{\text{Precision} * \text{Recall}}{\text{Precision} + \text{Recall}} + + \text{Support} = \sum_i^N 1(y_i = k) + + Where :math:`TP` is true positives, :math:`FP` is false positives, :math:`FN` is false negatives, + :math:`y` is a tensor of target values, :math:`k` is the class, and :math:`N` is the number of samples. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): A tensor of predictions of shape ``(N, ...)`` where ``N`` is + the batch size. If preds is a floating point tensor with values outside [0,1] range we consider + the input to be logits and will auto apply sigmoid per element. Additionally, we convert to int + tensor with thresholding using the value in ``threshold``. + - ``target`` (:class:`~torch.Tensor`): A tensor of targets of shape ``(N, ...)`` where ``N`` is the batch size. + + As output to ``forward`` and ``compute`` the metric returns either: + + - A formatted string report if ``output_dict=False`` + - A dictionary of metrics if ``output_dict=True`` + + Args: + threshold: Threshold for transforming probability to binary (0,1) predictions + target_names: Optional list of names for each class + sample_weight: Optional weights for each sample + digits: Number of decimal places to display in the report + output_dict: If True, return a dict instead of a string report + zero_division: Value to use when dividing by zero + + Example: + >>> from torch import tensor + >>> from torchmetrics.classification.classification_report import binary_classification_report + >>> target = tensor([0, 1, 0, 1]) + >>> preds = tensor([0, 1, 1, 1]) + >>> target_names = ['0', '1'] + >>> report = binary_classification_report( + ... preds=preds, + ... target=target, + ... target_names=target_names, + ... digits=2 + ... 
) + >>> print(report) # doctest: +NORMALIZE_WHITESPACE + precision recall f1-score support + + 0 1.00 0.50 0.67 2 + 1 0.67 1.00 0.80 2 + + accuracy 0.75 4 + macro avg 0.83 0.75 0.73 4 + weighted avg 0.83 0.75 0.73 4 + + """ + + def __init__( + self, + threshold: float = 0.5, + target_names: Optional[Sequence[str]] = None, + sample_weight: Optional[Tensor] = None, + digits: int = 2, + output_dict: bool = False, + zero_division: Union[str, int] = "warn", + ignore_index: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__( + target_names=target_names, + sample_weight=sample_weight, + digits=digits, + output_dict=output_dict, + zero_division=zero_division, + **kwargs, + ) + self.threshold = threshold + self.ignore_index = ignore_index + + # Set target names if they were provided + if target_names is not None: + self.target_names = list(target_names) + else: + self.target_names = ["0", "1"] + + def _call_functional_report( + self, preds: Tensor, target: Tensor + ) -> Union[Dict[str, Union[float, Dict[str, Union[float, int]]]], str]: + """Call binary classification report from functional interface.""" + return binary_classification_report( + preds=preds, + target=target, + threshold=self.threshold, + target_names=self.target_names, + digits=self.digits, + output_dict=self.output_dict, + zero_division=self.zero_division, + ignore_index=self.ignore_index, + ) + + +class MulticlassClassificationReport(_BaseClassificationReport): + r"""Compute precision, recall, F-measure and support for multiclass classification tasks. + + The classification report provides detailed metrics for each class in a multiclass classification task: + precision, recall, F1-score, and support. + + .. math:: + \text{Precision} = \frac{TP}{TP + FP} + + \text{Recall} = \frac{TP}{TP + FN} + + \text{F1} = 2 * \frac{\text{Precision} * \text{Recall}}{\text{Precision} + \text{Recall}} + + \text{Support} = \sum_i^N 1(y_i = k) + + Where :math:`TP` is true positives, :math:`FP` is false positives, :math:`FN` is false negatives, + :math:`y` is a tensor of target values, :math:`k` is the class, and :math:`N` is the number of samples. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): A tensor of predictions. If preds is a floating point tensor with values + outside [0,1] range we consider the input to be logits and will auto apply softmax per sample. + Additionally, we convert to int tensor with argmax. + - ``target`` (:class:`~torch.Tensor`): A tensor of integer targets. + + As output to ``forward`` and ``compute`` the metric returns either: + + - A formatted string report if ``output_dict=False`` + - A dictionary of metrics if ``output_dict=True`` + + Args: + num_classes: Number of classes in the dataset + target_names: Optional list of names for each class + sample_weight: Optional weights for each sample + digits: Number of decimal places to display in the report + output_dict: If True, return a dict instead of a string report + zero_division: Value to use when dividing by zero + top_k: Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. 
+ + Example: + >>> from torch import tensor + >>> from torchmetrics.classification.classification_report import multiclass_classification_report + >>> target = tensor([0, 1, 2, 2, 2]) + >>> preds = tensor([0, 0, 2, 2, 1]) + >>> target_names = ["class 0", "class 1", "class 2"] + >>> report = multiclass_classification_report( + ... preds=preds, + ... target=target, + ... num_classes=3, + ... target_names=target_names, + ... digits=2 + ... ) + >>> print(report) # doctest: +NORMALIZE_WHITESPACE + precision recall f1-score support + + class 0 0.50 1.00 0.67 1 + class 1 0.00 0.00 0.00 1 + class 2 1.00 0.67 0.80 3 + + accuracy 0.60 5 + macro avg 0.50 0.56 0.49 5 + weighted avg 0.70 0.60 0.61 5 + + """ + + plot_legend_name: str = "Class" + + def __init__( + self, + num_classes: int, + target_names: Optional[Sequence[str]] = None, + sample_weight: Optional[Tensor] = None, + digits: int = 2, + output_dict: bool = False, + zero_division: Union[str, int] = "warn", + ignore_index: Optional[int] = None, + top_k: int = 1, + **kwargs: Any, + ) -> None: + super().__init__( + target_names=target_names, + sample_weight=sample_weight, + digits=digits, + output_dict=output_dict, + zero_division=zero_division, + **kwargs, + ) + self.num_classes = num_classes + self.ignore_index = ignore_index + self.top_k = top_k + + # Set target names if they were provided + if target_names is not None: + self.target_names = list(target_names) + else: + self.target_names = [str(i) for i in range(num_classes)] + + def _call_functional_report( + self, preds: Tensor, target: Tensor + ) -> Union[Dict[str, Union[float, Dict[str, Union[float, int]]]], str]: + """Call multiclass classification report from functional interface.""" + return multiclass_classification_report( + preds=preds, + target=target, + num_classes=self.num_classes, + target_names=self.target_names, + digits=self.digits, + output_dict=self.output_dict, + zero_division=self.zero_division, + ignore_index=self.ignore_index, + top_k=self.top_k, + ) + + +class MultilabelClassificationReport(_BaseClassificationReport): + r"""Compute precision, recall, F-measure and support for multilabel classification tasks. + + The classification report provides detailed metrics for each class in a multilabel classification task: + precision, recall, F1-score, and support. + + .. math:: + \text{Precision} = \frac{TP}{TP + FP} + + \text{Recall} = \frac{TP}{TP + FN} + + \text{F1} = 2 * \frac{\text{Precision} * \text{Recall}}{\text{Precision} + \text{Recall}} + + \text{Support} = \sum_i^N 1(y_i = k) + + Where :math:`TP` is true positives, :math:`FP` is false positives, :math:`FN` is false negatives, + :math:`y` is a tensor of target values, :math:`k` is the class, and :math:`N` is the number of samples. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): A tensor of predictions of shape ``(N, C)`` where ``N`` is the + batch size and ``C`` is the number of labels. If preds is a floating point tensor with values + outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. + Additionally, we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (:class:`~torch.Tensor`): A tensor of targets of shape ``(N, C)`` where ``N`` is the + batch size and ``C`` is the number of labels. 
+ + As output to ``forward`` and ``compute`` the metric returns either: + + - A formatted string report if ``output_dict=False`` + - A dictionary of metrics if ``output_dict=True`` + + Args: + num_labels: Number of labels in the dataset + target_names: Optional list of names for each label + threshold: Threshold for transforming probability to binary (0,1) predictions + sample_weight: Optional weights for each sample + digits: Number of decimal places to display in the report + output_dict: If True, return a dict instead of a string report + zero_division: Value to use when dividing by zero + + Example: + >>> from torch import tensor + >>> from torchmetrics.classification.classification_report import multilabel_classification_report + >>> target = tensor([[1, 0, 1], [0, 1, 0], [1, 1, 0]]) + >>> preds = tensor([[1, 0, 1], [0, 1, 1], [1, 0, 0]]) + >>> target_names = ["Label A", "Label B", "Label C"] + >>> report = multilabel_classification_report( + ... preds=preds, + ... target=target, + ... num_labels=len(target_names), + ... target_names=target_names, + ... digits=2, + ... ) + >>> print(report) # doctest: +NORMALIZE_WHITESPACE + precision recall f1-score support + + Label A 1.00 1.00 1.00 2 + Label B 1.00 0.50 0.67 2 + Label C 0.50 1.00 0.67 1 + + micro avg 0.80 0.80 0.80 5 + macro avg 0.83 0.83 0.78 5 + weighted avg 0.90 0.80 0.80 5 + samples avg 0.83 0.83 0.78 5 + + """ + + plot_legend_name: str = "Label" + + def __init__( + self, + num_labels: int, + target_names: Optional[Sequence[str]] = None, + threshold: float = 0.5, + sample_weight: Optional[Tensor] = None, + digits: int = 2, + output_dict: bool = False, + zero_division: Union[str, int] = "warn", + ignore_index: Optional[int] = None, + **kwargs: Any, + ) -> None: + super().__init__( + target_names=target_names, + sample_weight=sample_weight, + digits=digits, + output_dict=output_dict, + zero_division=zero_division, + **kwargs, + ) + self.threshold = threshold + self.num_labels = num_labels + self.ignore_index = ignore_index + + # Set target names if they were provided + if target_names is not None: + self.target_names = list(target_names) + else: + self.target_names = [str(i) for i in range(num_labels)] + + def _call_functional_report( + self, preds: Tensor, target: Tensor + ) -> Union[Dict[str, Union[float, Dict[str, Union[float, int]]]], str]: + """Call multilabel classification report from functional interface.""" + return multilabel_classification_report( + preds=preds, + target=target, + num_labels=self.num_labels, + threshold=self.threshold, + target_names=self.target_names, + digits=self.digits, + output_dict=self.output_dict, + zero_division=self.zero_division, + ignore_index=self.ignore_index, + ) + + +class ClassificationReport(_ClassificationTaskWrapper): + r"""Compute precision, recall, F-measure and support for each class. + + .. math:: + \text{Precision} = \frac{TP}{TP + FP} + + \text{Recall} = \frac{TP}{TP + FN} + + \text{F1} = 2 * \frac{\text{Precision} * \text{Recall}}{\text{Precision} + \text{Recall}} + + \text{Support} = \sum_i^N 1(y_i = k) + + Where :math:`TP` is true positives, :math:`FP` is false positives, :math:`FN` is false negatives, + :math:`y` is a tensor of target values, :math:`k` is the class, and :math:`N` is the number of samples. + + This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. 
See the documentation of + :class:`~torchmetrics.classification.BinaryClassificationReport`, + :class:`~torchmetrics.classification.MulticlassClassificationReport` and + :class:`~torchmetrics.classification.MultilabelClassificationReport` for the specific details of each argument + influence and examples. + + Example (Binary Classification): + >>> from torch import tensor + >>> from torchmetrics.classification import ClassificationReport + >>> target = tensor([0, 1, 0, 1]) + >>> preds = tensor([0, 1, 1, 1]) + >>> target_names = ['0', '1'] + >>> report = ClassificationReport( + ... task="binary", + ... target_names=target_names, + ... digits=2 + ... ) + >>> report.update(preds, target) + >>> print(report.compute()) # doctest: +NORMALIZE_WHITESPACE + precision recall f1-score support + + 0 1.00 0.50 0.67 2 + 1 0.67 1.00 0.80 2 + + accuracy 0.75 4 + macro avg 0.83 0.75 0.73 4 + weighted avg 0.83 0.75 0.73 4 + + """ + + def __new__( # type: ignore[misc] + cls: type["ClassificationReport"], + task: Literal["binary", "multiclass", "multilabel"], + threshold: float = 0.5, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + target_names: Optional[Sequence[str]] = None, + sample_weight: Optional[Tensor] = None, + digits: int = 2, + output_dict: bool = False, + zero_division: Union[str, int] = "warn", + ignore_index: Optional[int] = None, + top_k: int = 1, + **kwargs: Any, + ) -> Metric: + """Initialize task metric.""" + task = ClassificationTask.from_str(task) + + kwargs.update({ + "target_names": target_names, + "sample_weight": sample_weight, + "digits": digits, + "output_dict": output_dict, + "zero_division": zero_division, + "ignore_index": ignore_index, + }) + + if task == ClassificationTask.BINARY: + return BinaryClassificationReport(threshold, **kwargs) + if task == ClassificationTask.MULTICLASS: + if not isinstance(num_classes, int): + raise ValueError( + f"Optional arg `num_classes` must be type `int` when task is {task}. Got {type(num_classes)}" + ) + kwargs.update({"top_k": top_k}) + return MulticlassClassificationReport(num_classes, **kwargs) + if task == ClassificationTask.MULTILABEL: + if not isinstance(num_labels, int): + raise ValueError( + f"Optional arg `num_labels` must be type `int` when task is {task}. 
Got {type(num_labels)}" + ) + return MultilabelClassificationReport(num_labels, **kwargs, threshold=threshold) + raise ValueError(f"Not handled value: {task}") diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index d3847b37ce1..8fcf45ae1fd 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -25,8 +25,10 @@ accuracy, auroc, average_precision, + binary_classification_report, binary_eer, calibration_error, + classification_report, cohen_kappa, confusion_matrix, eer, @@ -38,7 +40,9 @@ jaccard_index, logauc, matthews_corrcoef, + multiclass_classification_report, multiclass_eer, + multilabel_classification_report, multilabel_eer, negative_predictive_value, precision, @@ -149,11 +153,13 @@ "accuracy", "auroc", "average_precision", + "binary_classification_report", "binary_eer", "bleu_score", "calibration_error", "char_error_rate", "chrf_score", + "classification_report", "cohen_kappa", "concordance_corrcoef", "confusion_matrix", @@ -185,7 +191,9 @@ "mean_squared_error", "mean_squared_log_error", "minkowski_distance", + "multiclass_classification_report", "multiclass_eer", + "multilabel_classification_report", "multilabel_eer", "multiscale_structural_similarity_index_measure", "negative_predictive_value", diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 6deb86fce28..6b03b004f0a 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -29,6 +29,12 @@ calibration_error, multiclass_calibration_error, ) +from torchmetrics.functional.classification.classification_report import ( + binary_classification_report, + classification_report, + multiclass_classification_report, + multilabel_classification_report, +) from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, cohen_kappa, multiclass_cohen_kappa from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, @@ -156,6 +162,7 @@ "binary_auroc", "binary_average_precision", "binary_calibration_error", + "binary_classification_report", "binary_cohen_kappa", "binary_confusion_matrix", "binary_eer", @@ -180,6 +187,7 @@ "binary_specificity_at_sensitivity", "binary_stat_scores", "calibration_error", + "classification_report", "cohen_kappa", "confusion_matrix", "demographic_parity", @@ -197,6 +205,7 @@ "multiclass_auroc", "multiclass_average_precision", "multiclass_calibration_error", + "multiclass_classification_report", "multiclass_cohen_kappa", "multiclass_confusion_matrix", "multiclass_eer", @@ -222,6 +231,7 @@ "multilabel_accuracy", "multilabel_auroc", "multilabel_average_precision", + "multilabel_classification_report", "multilabel_confusion_matrix", "multilabel_coverage_error", "multilabel_eer", diff --git a/src/torchmetrics/functional/classification/classification_report.py b/src/torchmetrics/functional/classification/classification_report.py new file mode 100644 index 00000000000..6fb9177f9d1 --- /dev/null +++ b/src/torchmetrics/functional/classification/classification_report.py @@ -0,0 +1,925 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Optional, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.classification.accuracy import ( + binary_accuracy, + multiclass_accuracy, + multilabel_accuracy, +) +from torchmetrics.functional.classification.f_beta import ( + binary_fbeta_score, + multiclass_fbeta_score, + multilabel_fbeta_score, +) +from torchmetrics.functional.classification.precision_recall import ( + binary_precision, + binary_recall, + multiclass_precision, + multiclass_recall, + multilabel_precision, + multilabel_recall, +) +from torchmetrics.utilities.enums import ClassificationTask + + +def _handle_zero_division(value: float, zero_division: Union[str, float]) -> float: + """Handle NaN values based on zero_division parameter.""" + if torch.isnan(torch.tensor(value)): + if zero_division == "warn": + return 0.0 + if isinstance(zero_division, (int, float)): + return float(zero_division) + return value + + +def _compute_averages( + class_metrics: Dict[str, Dict[str, Union[float, int]]], + micro_metrics: Optional[Dict[str, float]] = None, + show_micro_avg: bool = False, + is_multilabel: bool = False, + preds: Optional[Tensor] = None, + target: Optional[Tensor] = None, + threshold: float = 0.5, +) -> Dict[str, Dict[str, Union[float, int]]]: + """Compute macro, micro, weighted, and samples averages for the classification report.""" + total_support = int(sum(metrics["support"] for metrics in class_metrics.values())) + num_classes = len(class_metrics) + + averages: Dict[str, Dict[str, Union[float, int]]] = {} + + # Add micro average if provided and should be shown + if micro_metrics is not None and show_micro_avg: + averages["micro avg"] = { + "precision": micro_metrics["precision"], + "recall": micro_metrics["recall"], + "f1-score": micro_metrics["f1-score"], + "support": total_support, + } + + # Calculate macro and weighted averages + for avg_name in ["macro avg", "weighted avg"]: + is_weighted = avg_name == "weighted avg" + + if total_support == 0: + avg_precision = avg_recall = avg_f1 = 0.0 + else: + if is_weighted: + weights = [float(metrics["support"]) / float(total_support) for metrics in class_metrics.values()] + else: + weights = [1.0 / float(num_classes) for _ in range(num_classes)] + + # Calculate weighted metrics more efficiently + metric_names = ["precision", "recall", "f1-score"] + avg_metrics = {} + + for metric_name in metric_names: + avg_metrics[metric_name] = sum( + float(metrics.get(metric_name, 0.0)) * w for metrics, w in zip(class_metrics.values(), weights) + ) + + avg_precision = avg_metrics["precision"] + avg_recall = avg_metrics["recall"] + avg_f1 = avg_metrics["f1-score"] + + averages[avg_name] = { + "precision": avg_precision, + "recall": avg_recall, + "f1-score": avg_f1, + "support": total_support, + } + + # Add samples average for multilabel classification + if is_multilabel and preds is not None and target is not None: + # Convert to binary predictions + binary_preds = (preds >= threshold).float() + + # Calculate per-sample metrics + n_samples = preds.shape[0] + sample_precision = 
torch.zeros(n_samples, dtype=torch.float32) + sample_recall = torch.zeros(n_samples, dtype=torch.float32) + sample_f1 = torch.zeros(n_samples, dtype=torch.float32) + + for i in range(n_samples): + true_positives = torch.sum(binary_preds[i] * target[i]) + pred_positives = torch.sum(binary_preds[i]) + actual_positives = torch.sum(target[i]) + + if pred_positives > 0: + sample_precision[i] = true_positives / pred_positives + if actual_positives > 0: + sample_recall[i] = true_positives / actual_positives + if pred_positives > 0 and actual_positives > 0: + sample_f1[i] = 2 * (sample_precision[i] * sample_recall[i]) / (sample_precision[i] + sample_recall[i]) + + # Average across samples + avg_precision = torch.mean(sample_precision).item() + avg_recall = torch.mean(sample_recall).item() + avg_f1 = torch.mean(sample_f1).item() + + averages["samples avg"] = { + "precision": avg_precision, + "recall": avg_recall, + "f1-score": avg_f1, + "support": total_support, + } + + return averages + + +def _format_report( + class_metrics: Dict[str, Dict[str, Union[float, int]]], + accuracy: float, + target_names: Optional[List[str]] = None, + digits: int = 2, + output_dict: bool = False, + micro_metrics: Optional[Dict[str, float]] = None, + show_micro_avg: bool = False, + is_multilabel: bool = False, + preds: Optional[Tensor] = None, + target_tensor: Optional[Tensor] = None, + threshold: float = 0.5, +) -> Union[str, Dict[str, Union[float, Dict[str, Union[float, int]]]]]: + """Format metrics into a classification report.""" + if output_dict: + result_dict: Dict[str, Union[float, Dict[str, Union[float, int]]]] = {} + + # Add class metrics + for i, (class_name, metrics) in enumerate(class_metrics.items()): + display_name = target_names[i] if target_names is not None and i < len(target_names) else str(class_name) + result_dict[display_name] = { + "precision": round(float(metrics["precision"]), digits), + "recall": round(float(metrics["recall"]), digits), + "f1-score": round(float(metrics["f1-score"]), digits), + "support": metrics["support"], + } + + # Add accuracy (only for non-multilabel) and averages + if not is_multilabel: + result_dict["accuracy"] = accuracy + + result_dict.update( + _compute_averages( + class_metrics, micro_metrics, show_micro_avg, is_multilabel, preds, target_tensor, threshold + ) + ) + + return result_dict + + # String formatting + headers = ["precision", "recall", "f1-score", "support"] + + # Convert numpy array to list if necessary + if target_names is not None and hasattr(target_names, "tolist"): + target_names = target_names.tolist() + + # Calculate widths needed for formatting + name_width = max(len(str(name)) for name in class_metrics) + if target_names: + name_width = max(name_width, max(len(str(name)) for name in target_names)) + + # Add extra width for average methods + name_width = max(name_width, len("weighted avg")) + if is_multilabel: + name_width = max(name_width, len("samples avg")) + + # Determine width for each metric column + width = max(digits + 6, len(headers[0])) + + # Format header + head = " " * name_width + " " + for h in headers: + head += "{:>{width}} ".format(h, width=width) + + report_lines = [head, ""] + + # Format rows for each class + for i, (class_name, metrics) in enumerate(class_metrics.items()): + display_name = target_names[i] if target_names and i < len(target_names) else str(class_name) + # Right-align the class/label name for scikit-learn compatibility + row = "{:>{name_width}} ".format(display_name, name_width=name_width) + + row += 
"{:>{width}.{digits}f} ".format(metrics.get("precision", 0.0), width=width, digits=digits) + row += "{:>{width}.{digits}f} ".format(metrics.get("recall", 0.0), width=width, digits=digits) + row += "{:>{width}.{digits}f} ".format(metrics.get("f1-score", 0.0), width=width, digits=digits) + row += "{:>{width}} ".format(metrics.get("support", 0), width=width) + report_lines.append(row) + + # Add a blank line + report_lines.append("") + + # Format accuracy row (only for non-multilabel) + if not is_multilabel: + total_support = sum(metrics["support"] for metrics in class_metrics.values()) + acc_row = "{:>{name_width}} ".format("accuracy", name_width=name_width) + acc_row += "{:>{width}} ".format("", width=width) + acc_row += "{:>{width}} ".format("", width=width) + acc_row += "{:>{width}.{digits}f} ".format(accuracy, width=width, digits=digits) + acc_row += "{:>{width}} ".format(total_support, width=width) + report_lines.append(acc_row) + + # Format averages rows + averages = _compute_averages( + class_metrics, micro_metrics, show_micro_avg, is_multilabel, preds, target_tensor, threshold + ) + for avg_name, avg_metrics in averages.items(): + row = "{:>{name_width}} ".format(avg_name, name_width=name_width) + + row += "{:>{width}.{digits}f} ".format(avg_metrics["precision"], width=width, digits=digits) + row += "{:>{width}.{digits}f} ".format(avg_metrics["recall"], width=width, digits=digits) + row += "{:>{width}.{digits}f} ".format(avg_metrics["f1-score"], width=width, digits=digits) + row += "{:>{width}} ".format(avg_metrics["support"], width=width) + report_lines.append(row) + + return "\n".join(report_lines) + + +def _compute_binary_metrics( + preds: Tensor, target: Tensor, threshold: float, ignore_index: Optional[int], validate_args: bool +) -> Dict[int, Dict[str, Union[float, int]]]: + """Compute metrics for binary classification.""" + class_metrics = {} + + for class_idx in [0, 1]: + if class_idx == 0: + # For class 0 (negative class), we need to invert both preds and target + # But first we need to handle ignore_index properly + if ignore_index is not None: + # Create a mask for valid indices + mask = target != ignore_index + # Create inverted target only for valid indices, preserving ignore_index + inv_target = target.clone() + inv_target[mask] = 1 - target[mask] + # Invert predictions for all indices + inv_preds = 1 - preds + else: + inv_preds = 1 - preds + inv_target = 1 - target + + precision_val = binary_precision( + inv_preds, inv_target, threshold, ignore_index=ignore_index, validate_args=validate_args + ).item() + recall_val = binary_recall( + inv_preds, inv_target, threshold, ignore_index=ignore_index, validate_args=validate_args + ).item() + f1_val = binary_fbeta_score( + inv_preds, + inv_target, + beta=1.0, + threshold=threshold, + ignore_index=ignore_index, + validate_args=validate_args, + ).item() + else: + # For class 1 (positive class), use binary metrics directly + precision_val = binary_precision( + preds, target, threshold, ignore_index=ignore_index, validate_args=validate_args + ).item() + recall_val = binary_recall( + preds, target, threshold, ignore_index=ignore_index, validate_args=validate_args + ).item() + f1_val = binary_fbeta_score( + preds, target, beta=1.0, threshold=threshold, ignore_index=ignore_index, validate_args=validate_args + ).item() + + # Calculate support, accounting for ignore_index + if ignore_index is not None: + mask = target != ignore_index + support_val = int(((target == class_idx) & mask).sum().item()) + else: + support_val = int((target == 
class_idx).sum().item()) + + class_metrics[class_idx] = { + "precision": precision_val, + "recall": recall_val, + "f1-score": f1_val, + "support": support_val, + } + + return class_metrics + + +def _compute_multiclass_metrics( + preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int], validate_args: bool, top_k: int = 1 +) -> Dict[int, Dict[str, Union[float, int]]]: + """Compute metrics for multiclass classification.""" + # Calculate per-class metrics + precision_vals = multiclass_precision( + preds, + target, + num_classes=num_classes, + average=None, + top_k=top_k, + ignore_index=ignore_index, + validate_args=validate_args, + ) + recall_vals = multiclass_recall( + preds, + target, + num_classes=num_classes, + average=None, + top_k=top_k, + ignore_index=ignore_index, + validate_args=validate_args, + ) + f1_vals = multiclass_fbeta_score( + preds, + target, + beta=1.0, + num_classes=num_classes, + average=None, + top_k=top_k, + ignore_index=ignore_index, + validate_args=validate_args, + ) + + # Calculate support for each class + if ignore_index is not None: + mask = target != ignore_index + class_counts = torch.bincount(target[mask].flatten(), minlength=num_classes) + else: + class_counts = torch.bincount(target.flatten(), minlength=num_classes) + + class_metrics = {} + for class_idx in range(num_classes): + class_metrics[class_idx] = { + "precision": precision_vals[class_idx].item(), + "recall": recall_vals[class_idx].item(), + "f1-score": f1_vals[class_idx].item(), + "support": int(class_counts[class_idx].item()), + } + + return class_metrics + + +def _compute_multilabel_metrics( + preds: Tensor, target: Tensor, num_labels: int, threshold: float, ignore_index: Optional[int], validate_args: bool +) -> Dict[int, Dict[str, Union[float, int]]]: + """Compute metrics for multilabel classification.""" + # Calculate per-label metrics + precision_vals = multilabel_precision( + preds, + target, + num_labels=num_labels, + threshold=threshold, + average=None, + ignore_index=ignore_index, + validate_args=validate_args, + ) + recall_vals = multilabel_recall( + preds, + target, + num_labels=num_labels, + threshold=threshold, + average=None, + ignore_index=ignore_index, + validate_args=validate_args, + ) + f1_vals = multilabel_fbeta_score( + preds, + target, + beta=1.0, + num_labels=num_labels, + threshold=threshold, + average=None, + ignore_index=ignore_index, + validate_args=validate_args, + ) + + # Calculate support for each label, accounting for ignore_index + if ignore_index is not None: + # For multilabel, support is the number of positive labels (target=1) excluding ignore_index + mask = target != ignore_index + supports = ((target == 1) & mask).sum(dim=0).int() + else: + supports = (target == 1).sum(dim=0).int() + + class_metrics = {} + for label_idx in range(num_labels): + class_metrics[label_idx] = { + "precision": precision_vals[label_idx].item(), + "recall": recall_vals[label_idx].item(), + "f1-score": f1_vals[label_idx].item(), + "support": int(supports[label_idx].item()), + } + + return class_metrics + + +def _apply_zero_division_handling( + class_metrics: Dict[int, Dict[str, Union[float, int]]], zero_division: Union[str, float] +) -> None: + """Apply zero division handling to all class metrics in-place.""" + for metrics in class_metrics.values(): + metrics["precision"] = _handle_zero_division(metrics["precision"], zero_division) + metrics["recall"] = _handle_zero_division(metrics["recall"], zero_division) + metrics["f1-score"] = 
_handle_zero_division(metrics["f1-score"], zero_division) + + +def classification_report( + preds: Tensor, + target: Tensor, + task: Literal["binary", "multiclass", "multilabel"], + threshold: float = 0.5, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + target_names: Optional[List[str]] = None, + digits: int = 2, + output_dict: bool = False, + zero_division: Union[str, float] = 0.0, + ignore_index: Optional[int] = None, + validate_args: bool = True, + labels: Optional[List[int]] = None, + top_k: int = 1, +) -> Union[str, Dict[str, Union[float, Dict[str, Union[float, int]]]]]: + """Compute a classification report for various classification tasks. + + The classification report shows the precision, recall, F1 score, and support for each class/label. + + Args: + preds: Tensor with predictions + target: Tensor with ground truth labels + task: The classification task - either 'binary', 'multiclass', or 'multilabel' + threshold: Threshold for converting probabilities to binary predictions (for binary and multilabel tasks) + num_classes: Number of classes (for multiclass tasks) + num_labels: Number of labels (for multilabel tasks) + target_names: Optional list of names for the classes/labels + digits: Number of decimal places to display in the report + output_dict: If True, return a dict instead of a string report + zero_division: Value to use when dividing by zero + ignore_index: Optional index to ignore in the target (for multiclass tasks) + validate_args: bool indicating if input arguments and tensors should be validated for correctness + labels: Optional list of label indices to include in the report (for multiclass tasks) + top_k: Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits and task is 'multiclass'. + + Returns: + If output_dict=True, a dictionary with the classification report data. + Otherwise, a formatted string with the classification report. + + Examples: + >>> from torch import tensor + >>> from torchmetrics.functional.classification.classification_report import classification_report + >>> + >>> # Binary classification example + >>> binary_target = tensor([0, 1, 0, 1]) + >>> binary_preds = tensor([0, 1, 1, 1]) + >>> binary_report = classification_report( + ... preds=binary_preds, + ... target=binary_target, + ... task="binary", + ... target_names=['Class 0', 'Class 1'], + ... digits=2 + ... ) + >>> print(binary_report) # doctest: +NORMALIZE_WHITESPACE + precision recall f1-score support + + Class 0 1.00 0.50 0.67 2 + Class 1 0.67 1.00 0.80 2 + + accuracy 0.75 4 + macro avg 0.83 0.75 0.73 4 + weighted avg 0.83 0.75 0.73 4 + >>> + >>> # Multiclass classification example + >>> multiclass_target = tensor([0, 1, 2, 2, 2]) + >>> multiclass_preds = tensor([0, 0, 2, 2, 1]) + >>> multiclass_report = classification_report( + ... preds=multiclass_preds, + ... target=multiclass_target, + ... task="multiclass", + ... num_classes=3, + ... target_names=["Class 0", "Class 1", "Class 2"], + ... digits=2 + ... 
) + >>> print(multiclass_report) # doctest: +NORMALIZE_WHITESPACE + precision recall f1-score support + + Class 0 0.50 1.00 0.67 1 + Class 1 0.00 0.00 0.00 1 + Class 2 1.00 0.67 0.80 3 + + accuracy 0.60 5 + macro avg 0.50 0.56 0.49 5 + weighted avg 0.70 0.60 0.61 5 + >>> + >>> # Multilabel classification example + >>> multilabel_target = tensor([[1, 0, 1], [0, 1, 0], [1, 1, 0]]) + >>> multilabel_preds = tensor([[1, 0, 1], [0, 1, 1], [1, 0, 0]]) + >>> multilabel_report = classification_report( + ... preds=multilabel_preds, + ... target=multilabel_target, + ... task="multilabel", + ... num_labels=3, + ... target_names=["Label A", "Label B", "Label C"], + ... digits=2 + ... ) + >>> print(multilabel_report) # doctest: +NORMALIZE_WHITESPACE + precision recall f1-score support + + Label A 1.00 1.00 1.00 2 + Label B 1.00 0.50 0.67 2 + Label C 0.50 1.00 0.67 1 + + micro avg 0.80 0.80 0.80 5 + macro avg 0.83 0.83 0.78 5 + weighted avg 0.90 0.80 0.80 5 + samples avg 0.83 0.83 0.78 5 + + """ + # Determine if micro average should be shown in the report based on classification task + # Following scikit-learn's logic: + # - Show for multilabel classification (always) + # - Show for multiclass when using a subset of classes + # - Don't show for binary classification (micro avg is same as accuracy) + # - Don't show for full multiclass classification with all classes (micro avg is same as accuracy) + show_micro_avg = False + is_multilabel = task == ClassificationTask.MULTILABEL + + # Compute task-specific metrics + if task == ClassificationTask.BINARY: + class_metrics = _compute_binary_metrics(preds, target, threshold, ignore_index, validate_args) + accuracy_val = binary_accuracy( + preds, target, threshold, ignore_index=ignore_index, validate_args=validate_args + ).item() + + # Calculate micro metrics (same as accuracy for binary classification) + micro_metrics = {"precision": accuracy_val, "recall": accuracy_val, "f1-score": accuracy_val} + # For binary classification, don't show micro avg (it's same as accuracy) + show_micro_avg = False + + elif task == ClassificationTask.MULTICLASS: + if num_classes is None: + raise ValueError("num_classes must be provided for multiclass classification") + + class_metrics = _compute_multiclass_metrics(preds, target, num_classes, ignore_index, validate_args, top_k) + + # Filter metrics by labels if provided + if labels is not None: + # Create a new dict with only the specified labels + filtered_metrics = { + class_idx: metrics for class_idx, metrics in class_metrics.items() if class_idx in labels + } + class_metrics = filtered_metrics + show_micro_avg = True # Always show micro avg when specific labels are requested + else: + # For multiclass, check if we have a subset of classes with support + classes_with_support = sum(1 for metrics in class_metrics.values() if metrics["support"] > 0) + show_micro_avg = classes_with_support < num_classes + + accuracy_val = multiclass_accuracy( + preds, + target, + num_classes=num_classes, + average="micro", + top_k=top_k, + ignore_index=ignore_index, + validate_args=validate_args, + ).item() + + # Calculate micro-averaged metrics + micro_precision = multiclass_precision( + preds, + target, + num_classes=num_classes, + average="micro", + top_k=top_k, + ignore_index=ignore_index, + validate_args=validate_args, + ).item() + micro_recall = multiclass_recall( + preds, + target, + num_classes=num_classes, + average="micro", + top_k=top_k, + ignore_index=ignore_index, + validate_args=validate_args, + ).item() + micro_f1 = 
multiclass_fbeta_score( + preds, + target, + beta=1.0, + num_classes=num_classes, + average="micro", + top_k=top_k, + ignore_index=ignore_index, + validate_args=validate_args, + ).item() + + micro_metrics = {"precision": micro_precision, "recall": micro_recall, "f1-score": micro_f1} + + elif task == ClassificationTask.MULTILABEL: + if num_labels is None: + raise ValueError("num_labels must be provided for multilabel classification") + + class_metrics = _compute_multilabel_metrics(preds, target, num_labels, threshold, ignore_index, validate_args) + accuracy_val = multilabel_accuracy( + preds, + target, + num_labels=num_labels, + threshold=threshold, + average="micro", + ignore_index=ignore_index, + validate_args=validate_args, + ).item() + + # Calculate micro-averaged metrics + micro_precision = multilabel_precision( + preds, + target, + num_labels=num_labels, + threshold=threshold, + average="micro", + ignore_index=ignore_index, + validate_args=validate_args, + ).item() + micro_recall = multilabel_recall( + preds, + target, + num_labels=num_labels, + threshold=threshold, + average="micro", + ignore_index=ignore_index, + validate_args=validate_args, + ).item() + micro_f1 = multilabel_fbeta_score( + preds, + target, + beta=1.0, + num_labels=num_labels, + threshold=threshold, + average="micro", + ignore_index=ignore_index, + validate_args=validate_args, + ).item() + + micro_metrics = {"precision": micro_precision, "recall": micro_recall, "f1-score": micro_f1} + + # Always show micro avg for multilabel + show_micro_avg = True + + else: + raise ValueError(f"Invalid Classification: expected one of (binary, multiclass, multilabel) but got {task}") + + # Apply zero division handling + _apply_zero_division_handling(class_metrics, zero_division) + + # Filter metrics by labels if provided - this needs to happen after computing all metrics + # to ensure proper calculation of overall statistics, but before formatting + if task == ClassificationTask.MULTICLASS and labels is not None: + # Create a new dict with only the specified labels + filtered_metrics = {class_idx: metrics for class_idx, metrics in class_metrics.items() if class_idx in labels} + class_metrics = filtered_metrics + + # Convert integer keys to strings for compatibility with _format_report + class_metrics_str = {str(k): v for k, v in class_metrics.items()} + + # Apply zero_division to micro metrics + for key in micro_metrics: + micro_metrics[key] = _handle_zero_division(micro_metrics[key], zero_division) + + return _format_report( + class_metrics_str, + accuracy_val, + target_names, + digits, + output_dict, + micro_metrics, + show_micro_avg, + is_multilabel, + preds if is_multilabel else None, + target if is_multilabel else None, + threshold, + ) + + +def binary_classification_report( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + target_names: Optional[List[str]] = None, + digits: int = 2, + output_dict: bool = False, + zero_division: Union[str, float] = 0.0, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[str, Dict[str, Union[float, Dict[str, Union[float, int]]]]]: + """Compute a classification report for binary classification tasks. + + The classification report shows the precision, recall, F1 score, and support for each class. 
+ + Args: + preds: Tensor with predictions + target: Tensor with ground truth labels + threshold: Threshold for converting probabilities to binary predictions + target_names: Optional list of names for the classes + digits: Number of decimal places to display in the report + output_dict: If True, return a dict instead of a string report + zero_division: Value to use when dividing by zero + ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness + + Returns: + If output_dict=True, a dictionary with the classification report data. + Otherwise, a formatted string with the classification report. + + Example: + >>> from torch import tensor + >>> from torchmetrics.functional.classification.classification_report import binary_classification_report + >>> target = tensor([0, 1, 0, 1]) + >>> preds = tensor([0, 1, 1, 1]) + >>> target_names = ['0', '1'] + >>> report = binary_classification_report( + ... preds=preds, + ... target=target, + ... target_names=target_names, + ... digits=2 + ... ) + >>> print(report) # doctest: +NORMALIZE_WHITESPACE + precision recall f1-score support + + 0 1.00 0.50 0.67 2 + 1 0.67 1.00 0.80 2 + + accuracy 0.75 4 + macro avg 0.83 0.75 0.73 4 + weighted avg 0.83 0.75 0.73 4 + + """ + return classification_report( + preds, + target, + task="binary", + threshold=threshold, + target_names=target_names, + digits=digits, + output_dict=output_dict, + zero_division=zero_division, + ignore_index=ignore_index, + validate_args=validate_args, + ) + + +def multiclass_classification_report( + preds: Tensor, + target: Tensor, + num_classes: int, + target_names: Optional[List[str]] = None, + digits: int = 2, + output_dict: bool = False, + zero_division: Union[str, float] = 0.0, + ignore_index: Optional[int] = None, + validate_args: bool = True, + labels: Optional[List[int]] = None, + top_k: int = 1, +) -> Union[str, Dict[str, Union[float, Dict[str, Union[float, int]]]]]: + """Compute a classification report for multiclass classification tasks. + + The classification report shows the precision, recall, F1 score, and support for each class. + + Args: + preds: Tensor with predictions of shape (N, ...) or (N, C, ...) where C is the number of classes + target: Tensor with ground truth labels of shape (N, ...) + num_classes: Number of classes + target_names: Optional list of names for the classes + digits: Number of decimal places to display in the report + output_dict: If True, return a dict instead of a string report + zero_division: Value to use when dividing by zero + ignore_index: Optional index to ignore in the target + validate_args: bool indicating if input arguments and tensors should be validated for correctness + labels: Optional list of label indices to include in the report + top_k: Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + + Returns: + If output_dict=True, a dictionary with the classification report data. + Otherwise, a formatted string with the classification report. + + Example: + >>> from torch import tensor + >>> from torchmetrics.functional.classification.classification_report import multiclass_classification_report + >>> target = tensor([0, 1, 2, 2, 2]) + >>> preds = tensor([0, 0, 2, 2, 1]) + >>> target_names = ["class 0", "class 1", "class 2"] + >>> report = multiclass_classification_report( + ... preds=preds, + ... 
target=target, + ... num_classes=3, + ... target_names=target_names, + ... digits=2 + ... ) + >>> print(report) # doctest: +NORMALIZE_WHITESPACE + precision recall f1-score support + + class 0 0.50 1.00 0.67 1 + class 1 0.00 0.00 0.00 1 + class 2 1.00 0.67 0.80 3 + + accuracy 0.60 5 + macro avg 0.50 0.56 0.49 5 + weighted avg 0.70 0.60 0.61 5 + + """ + return classification_report( + preds, + target, + task="multiclass", + num_classes=num_classes, + target_names=target_names, + digits=digits, + output_dict=output_dict, + zero_division=zero_division, + ignore_index=ignore_index, + validate_args=validate_args, + labels=labels, + top_k=top_k, + ) + + +def multilabel_classification_report( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + target_names: Optional[List[str]] = None, + digits: int = 2, + output_dict: bool = False, + zero_division: Union[str, float] = 0.0, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[str, Dict[str, Union[float, Dict[str, Union[float, int]]]]]: + """Compute a classification report for multilabel classification tasks. + + The classification report shows the precision, recall, F1 score, and support for each label. + + Args: + preds: Tensor with predictions of shape (N, L, ...) where L is the number of labels + target: Tensor with ground truth labels of shape (N, L, ...) + num_labels: Number of labels + threshold: Threshold for converting probabilities to binary predictions + target_names: Optional list of names for the labels + digits: Number of decimal places to display in the report + output_dict: If True, return a dict instead of a string report + zero_division: Value to use when dividing by zero + ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness + + Returns: + If output_dict=True, a dictionary with the classification report data. + Otherwise, a formatted string with the classification report. + + Example: + >>> from torch import tensor + >>> from torchmetrics.functional.classification.classification_report import multilabel_classification_report + >>> target = tensor([[1, 0, 1], [0, 1, 0], [1, 1, 0]]) + >>> preds = tensor([[1, 0, 1], [0, 1, 1], [1, 0, 0]]) + >>> target_names = ["Label A", "Label B", "Label C"] + >>> report = multilabel_classification_report( + ... preds=preds, + ... target=target, + ... num_labels=len(target_names), + ... target_names=target_names, + ... digits=2, + ... ) + >>> print(report) # doctest: +NORMALIZE_WHITESPACE + precision recall f1-score support + + Label A 1.00 1.00 1.00 2 + Label B 1.00 0.50 0.67 2 + Label C 0.50 1.00 0.67 1 + + micro avg 0.80 0.80 0.80 5 + macro avg 0.83 0.83 0.78 5 + weighted avg 0.90 0.80 0.80 5 + samples avg 0.83 0.83 0.78 5 + + """ + return classification_report( + preds, + target, + task="multilabel", + num_labels=num_labels, + threshold=threshold, + target_names=target_names, + digits=digits, + output_dict=output_dict, + zero_division=zero_division, + ignore_index=ignore_index, + validate_args=validate_args, + ) diff --git a/tests/unittests/classification/test_classification_report.py b/tests/unittests/classification/test_classification_report.py new file mode 100644 index 00000000000..826bb1aa299 --- /dev/null +++ b/tests/unittests/classification/test_classification_report.py @@ -0,0 +1,973 @@ +# Copyright The Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest +import torch +from sklearn import datasets +from sklearn.metrics import classification_report +from sklearn.svm import SVC +from sklearn.utils import check_random_state + +from torchmetrics.classification import ClassificationReport +from torchmetrics.functional.classification.classification_report import ( + binary_classification_report, + multiclass_classification_report, + multilabel_classification_report, +) +from torchmetrics.functional.classification.classification_report import ( + classification_report as functional_classification_report, +) +from unittests._helpers import seed_all + +seed_all(42) + + +def make_prediction(dataset=None, binary=False): + """Make some classification predictions on a toy dataset using a SVC. + + If binary is True restrict to a binary classification problem instead of a multiclass classification problem. + + This is adapted from scikit-learn's test_classification.py. + + """ + if dataset is None: + # import some data to play with + dataset = datasets.load_iris() + + x = dataset.data + y = dataset.target + + if binary: + # restrict to a binary classification task + x, y = x[y < 2], y[y < 2] + + n_samples, n_features = x.shape + p = np.arange(n_samples) + + rng = check_random_state(37) + rng.shuffle(p) + x, y = x[p], y[p] + half = int(n_samples / 2) + + # add noisy features to make the problem harder and avoid perfect results + rng = np.random.RandomState(0) + x = np.c_[x, rng.randn(n_samples, 200 * n_features)] + + # run classifier, get class probabilities and label predictions + clf = SVC(kernel="linear", probability=True, random_state=0) + y_pred_proba = clf.fit(x[:half], y[:half]).predict_proba(x[half:]) + + if binary: + # only interested in probabilities of the positive case + y_pred_proba = y_pred_proba[:, 1] + + y_pred = clf.predict(x[half:]) + y_true = y[half:] + return y_true, y_pred, y_pred_proba + + +# Define fixtures for test data with different scenarios +@pytest.fixture( + params=[ + ("binary", "get_binary_test_data"), + ("multiclass", "get_multiclass_test_data"), + ("multiclass", "get_balanced_multiclass_test_data"), + ("multilabel", "get_multilabel_test_data"), + ] +) +def classification_test_data(request): + """Return test data for different classification scenarios.""" + task, data_fn = request.param + + # Get the appropriate test data function + data_function = globals()[data_fn] + + if task == "multilabel": + y_true, y_pred, y_prob, target_names = data_function() + return task, y_true, y_pred, target_names, y_prob + y_true, y_pred, target_names = data_function() + return task, y_true, y_pred, target_names, None + + +def get_test_data_with_ignore_index(task): + """Generate test data with ignore_index scenario for different tasks.""" + if task == "binary": + preds = torch.tensor([0, 1, 1, 0, 1, 0]) + target = torch.tensor([0, 1, -1, 0, 1, -1]) # -1 will be ignored + ignore_index = -1 + expected_support = 4 # Only 4 valid samples + return preds, target, 
ignore_index, expected_support + if task == "multiclass": + preds = torch.tensor([0, 1, 2, 1, 2, 0, 1]) + target = torch.tensor([0, 1, 2, -1, 2, 0, -1]) # -1 will be ignored + ignore_index = -1 + expected_support = 5 # Only 5 valid samples + return preds, target, ignore_index, expected_support + if task == "multilabel": + preds = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1]]) + target = torch.tensor([[1, 0, 1], [0, -1, 0], [1, 1, -1], [0, 0, 1]]) # -1 will be ignored + ignore_index = -1 + expected_support = [2, 1, 2] # Per-label support counts + return preds, target, ignore_index, expected_support + return None, None, None, None + + +# Define test cases for different scenarios +def get_multiclass_test_data(): + """Get test data for multiclass scenarios.""" + iris = datasets.load_iris() + y_true, y_pred, _ = make_prediction(dataset=iris, binary=False) + return y_true, y_pred, iris.target_names + + +def get_binary_test_data(): + """Get test data for binary scenarios.""" + iris = datasets.load_iris() + y_true, y_pred, _ = make_prediction(dataset=iris, binary=True) + return y_true, y_pred, iris.target_names[:2] + + +def get_balanced_multiclass_test_data(): + """Get balanced multiclass test data.""" + y_true = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]) + y_pred = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]) + return y_true, y_pred, None + + +def get_multilabel_test_data(): + """Get test data for multilabel scenarios.""" + # Create a multilabel dataset with 3 labels + num_samples = 100 # Increased for more stable metrics + num_labels = 3 + + # Generate random predictions and targets with some correlation + rng = np.random.RandomState(42) + y_true = rng.randint(0, 2, size=(num_samples, num_labels)) + + # Generate predictions that are mostly correct but with some noise + y_pred = y_true.copy() + flip_mask = rng.random(y_true.shape) < 0.2 # 20% chance of flipping a label + y_pred[flip_mask] = 1 - y_pred[flip_mask] + + # Generate probability predictions (not strictly proper probabilities, but good for testing) + y_prob = np.zeros_like(y_pred, dtype=float) + y_prob[y_pred == 1] = rng.uniform(0.5, 1.0, size=y_pred[y_pred == 1].shape) + y_prob[y_pred == 0] = rng.uniform(0.0, 0.5, size=y_pred[y_pred == 0].shape) + + # Create label names + label_names = [f"Label_{i}" for i in range(num_labels)] + + return y_true, y_pred, y_prob, label_names + + +class _BaseTestClassificationReport: + """Base class for ClassificationReport tests.""" + + def _assert_dicts_equal(self, d1, d2, atol=1e-8): + """Helper to assert two dictionaries are approximately equal.""" + assert set(d1.keys()) == set(d2.keys()) + for k in d1: + if isinstance(d1[k], dict): + self._assert_dicts_equal(d1[k], d2[k], atol) + elif isinstance(d1[k], (int, np.integer)): + assert d1[k] == d2[k], f"Mismatch for key {k}: {d1[k]} != {d2[k]}" + else: + # Handle NaN values specially - if both are NaN, consider them equal + if np.isnan(d1[k]) and np.isnan(d2[k]): + continue + assert np.allclose(d1[k], d2[k], atol=atol), f"Mismatch for key {k}: {d1[k]} != {d2[k]}" + + def _assert_dicts_equal_with_tolerance(self, expected_dict, actual_dict): + """Compare two classification report dictionaries for approximate equality.""" + # The keys might be different between scikit-learn and torchmetrics + # especially for binary classification, where class ordering might be different + # Here we primarily verify that the important aggregate metrics are present + + # Check accuracy + if "accuracy" in expected_dict and "accuracy" in actual_dict: + expected_accuracy = 
expected_dict["accuracy"] + actual_accuracy = actual_dict["accuracy"] + # Handle tensor vs float + if hasattr(actual_accuracy, "item"): + actual_accuracy = actual_accuracy.item() + assert abs(expected_accuracy - actual_accuracy) < 1e-2, ( + f"Accuracy metric doesn't match: {expected_accuracy} vs {actual_accuracy}" + ) + + # Check if aggregate metrics exist + for avg_key in ["macro avg", "weighted avg"]: + if avg_key in expected_dict: + # Either the exact key or a variant might exist + found_key = None + for key in actual_dict: + if key.replace("-", " ") == avg_key: + found_key = key + break + + # Skip detailed comparison as implementations may differ + assert found_key is not None, f"Missing aggregate metric: {avg_key}" + + # For individual classes, just check presence rather than exact values + # as binary classification can have significant implementation differences + for cls_key in expected_dict: + if isinstance(expected_dict[cls_key], dict) and cls_key not in ["macro avg", "weighted avg", "micro avg"]: + # For individual classes, just check if metrics exist + class_exists = False + for key in actual_dict: + if isinstance(actual_dict[key], dict) and key not in ["macro avg", "weighted avg", "micro avg"]: + class_exists = True + break + assert class_exists, f"Missing class metrics for class: {cls_key}" + + def _verify_string_report(self, report): + """Verify that a string report has the expected format.""" + assert isinstance(report, str) + assert "precision" in report + assert "recall" in report + assert "f1-score" in report + assert "support" in report + + # Check for aggregate metrics + assert any( + metric in report for metric in ["accuracy", "macro avg", "weighted avg", "macro-avg", "weighted-avg"] + ) + + +@pytest.mark.parametrize("output_dict", [False, True]) +class TestClassificationReport(_BaseTestClassificationReport): + """Unified test class for all ClassificationReport types.""" + + @pytest.mark.parametrize("with_target_names", [True, False]) + @pytest.mark.parametrize("use_probabilities", [False, True]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_classification_report( + self, classification_test_data, output_dict, with_target_names, use_probabilities, ignore_index + ): + """Test the classification report across different scenarios.""" + task, y_true, y_pred, target_names, y_prob = classification_test_data + + # Skip irrelevant combinations + if task != "multilabel" and use_probabilities: + pytest.skip("Probabilities only relevant for multilabel tasks") + + # Use ignore_index test data if ignore_index is specified + if ignore_index is not None: + y_pred, y_true, ignore_index, expected_support = get_test_data_with_ignore_index(task) + target_names = ["0", "1", "2"] if task in ["multiclass", "multilabel"] else ["0", "1"] + + # Create common parameters for all tasks + common_params = { + "task": task, + "output_dict": output_dict, + "ignore_index": ignore_index, + } + + # Add task-specific parameters + if task == "binary": + common_params["num_classes"] = len(np.unique(y_true)) if ignore_index is None else 2 + elif task == "multiclass": + common_params["num_classes"] = len(np.unique(y_true)) if ignore_index is None else 3 + elif task == "multilabel": + common_params["num_labels"] = y_true.shape[1] if ignore_index is None else 3 + common_params["threshold"] = 0.5 + + # Handle target names + if with_target_names and target_names is not None: + common_params["target_names"] = target_names + + # Create metric and update with data + torchmetrics_report = 
ClassificationReport(**common_params) + + # Use probabilities if applicable (only for multilabel currently) + if task == "multilabel" and use_probabilities and y_prob is not None and ignore_index is None: + torchmetrics_report.update(torch.tensor(y_prob), torch.tensor(y_true)) + else: + torchmetrics_report.update(torch.tensor(y_pred), torch.tensor(y_true)) + + # Compute result + result = torchmetrics_report.compute() + + # For comparison, generate sklearn report when possible + if ( + task != "multilabel" and ignore_index is None + ): # sklearn doesn't support multilabel or ignore_index in the same way + # Generate sklearn report + sklearn_params = { + "output_dict": output_dict, + } + + if with_target_names and target_names is not None: + sklearn_params["target_names"] = target_names + sklearn_params["labels"] = np.arange(len(target_names)) + + report_scikit = classification_report(y_true, y_pred, **sklearn_params) + + # Verify results + if output_dict: + self._assert_dicts_equal_with_tolerance(report_scikit, result) + else: + self._verify_string_report(result) + else: + # For multilabel or ignore_index cases, we don't have a direct sklearn comparison + # Verify the format is correct + if output_dict: + # Check basic structure + if with_target_names and target_names is not None: + for label in target_names: + assert label in result + assert "precision" in result[label] + assert "recall" in result[label] + assert "f1-score" in result[label] + assert "support" in result[label] + + # Check for aggregate metrics + possible_avg_keys = ["micro avg", "macro avg", "weighted avg", "micro-avg", "macro-avg", "weighted-avg"] + assert any(key in result for key in possible_avg_keys) + + # Additional tests for ignore_index functionality + if ignore_index is not None: + self._test_ignore_index_functionality(task, result, expected_support) + else: + self._verify_string_report(result) + + def _test_ignore_index_functionality(self, task, tm_report, expected_support): + """Test that ignore_index functionality works correctly.""" + if task in ["binary", "multiclass"]: + # Check that total support matches expected (ignored samples excluded) + total_support = sum( + tm_report[key]["support"] + for key in tm_report + if key + not in ["accuracy", "macro avg", "weighted avg", "macro-avg", "weighted-avg", "micro avg", "micro-avg"] + ) + assert total_support == expected_support + elif task == "multilabel": + # For multilabel, check per-label support + for i, label_key in enumerate(["0", "1", "2"]): + if label_key in tm_report: + assert tm_report[label_key]["support"] == expected_support[i] + + @pytest.mark.parametrize("task", ["binary", "multiclass", "multilabel"]) + def test_functional_equivalence(self, task, output_dict): + """Test that the functional and class implementations are equivalent.""" + # Create test data based on task + if task == "binary": + y_true, y_pred, target_names = get_binary_test_data() + y_prob = None + elif task == "multiclass": + y_true, y_pred, target_names = get_multiclass_test_data() + y_prob = None + else: # multilabel + y_true, y_pred, y_prob, target_names = get_multilabel_test_data() + + # Create common parameters + common_params = { + "output_dict": output_dict, + "target_names": target_names, + } + + # Add task-specific parameters + if task == "binary": + common_params["threshold"] = 0.5 + elif task == "multiclass": + common_params["num_classes"] = len(np.unique(y_true)) + elif task == "multilabel": + common_params["num_labels"] = y_true.shape[1] + common_params["threshold"] = 0.5 + 
+ # Get class implementation result + class_metric = ClassificationReport(task=task, **common_params) + class_metric.update(torch.tensor(y_pred), torch.tensor(y_true)) + class_result = class_metric.compute() + + # Get functional implementation result + if task == "binary": + func_result = binary_classification_report(torch.tensor(y_pred), torch.tensor(y_true), **common_params) + elif task == "multiclass": + func_result = multiclass_classification_report(torch.tensor(y_pred), torch.tensor(y_true), **common_params) + elif task == "multilabel": + func_result = multilabel_classification_report(torch.tensor(y_pred), torch.tensor(y_true), **common_params) + + # Also test the general functional implementation + general_result = functional_classification_report( + torch.tensor(y_pred), torch.tensor(y_true), task=task, **common_params + ) + + # Verify results are equivalent + if output_dict: + self._assert_dicts_equal(class_result, func_result) + self._assert_dicts_equal(class_result, general_result) + else: + # For string output, check they have the same key content + for metric in ["precision", "recall", "f1-score", "support"]: + assert metric in func_result + assert metric in general_result + + @pytest.mark.parametrize("task", ["binary", "multiclass", "multilabel"]) + @pytest.mark.parametrize("ignore_value", [-1, 99]) + def test_ignore_index_specific_functionality(self, task, ignore_value, output_dict): + """Test specific ignore_index functionality and edge cases.""" + # Create test data with ignore_index values + if task == "binary": + preds = torch.tensor([0, 1, 1, 0, 1, 0]) + target = torch.tensor([0, 1, ignore_value, 0, 1, ignore_value]) + expected_support = 4 + num_classes = 2 + func_call = binary_classification_report + common_params = {"threshold": 0.5} + elif task == "multiclass": + preds = torch.tensor([0, 1, 2, 1, 2, 0, 1]) + target = torch.tensor([0, 1, 2, ignore_value, 2, 0, ignore_value]) + expected_support = 5 + num_classes = 3 + func_call = multiclass_classification_report + common_params = {"num_classes": num_classes} + else: # multilabel + preds = torch.tensor([[1, 0, 1], [0, 1, 0], [1, 1, 0], [0, 0, 1]]) + target = torch.tensor([[1, 0, 1], [0, ignore_value, 0], [1, 1, ignore_value], [0, 0, 1]]) + expected_support = [2, 1, 2] # Per-label support + func_call = multilabel_classification_report + common_params = {"num_labels": 3, "threshold": 0.5} + + # Test functional version + result = func_call(preds=preds, target=target, ignore_index=ignore_value, output_dict=True, **common_params) + + # Test modular version + metric_params = {"task": task, "ignore_index": ignore_value, "output_dict": True} + if task == "binary" or task == "multiclass": + metric_params.update(common_params) + else: # multilabel + metric_params.update(common_params) + + metric = ClassificationReport(**metric_params) + metric.update(preds, target) + result_modular = metric.compute() + + # Verify support counts + if task in ["binary", "multiclass"]: + total_support = sum(result[str(i)]["support"] for i in range(num_classes)) + total_support_modular = sum(result_modular[str(i)]["support"] for i in range(num_classes)) + assert total_support == expected_support + assert total_support_modular == expected_support + else: # multilabel + for i in range(3): + assert result[str(i)]["support"] == expected_support[i] + assert result_modular[str(i)]["support"] == expected_support[i] + + # Test that ignore_index=None behaves like no ignore_index + result_none = func_call( + preds=preds, + target=torch.where(target == 
ignore_value, 0, target), # Replace ignore values with valid ones + ignore_index=None, + output_dict=True, + **common_params, + ) + + result_no_param = func_call( + preds=preds, target=torch.where(target == ignore_value, 0, target), output_dict=True, **common_params + ) + + # These should be equivalent + if task in ["binary", "multiclass"]: + for i in range(num_classes): + if str(i) in result_none and str(i) in result_no_param: + assert abs(result_none[str(i)]["support"] - result_no_param[str(i)]["support"]) < 1e-6 + else: # multilabel + for i in range(3): + if str(i) in result_none and str(i) in result_no_param: + assert abs(result_none[str(i)]["support"] - result_no_param[str(i)]["support"]) < 1e-6 + + def test_ignore_index_accuracy_calculation(self, output_dict): + """Test that ignore_index properly affects accuracy calculation.""" + # Create scenario where ignored indices would change accuracy + preds = torch.tensor([0, 1, 0, 1]) + target = torch.tensor([0, 1, -1, -1]) # Last two are ignored + + result = binary_classification_report(preds=preds, target=target, ignore_index=-1, output_dict=True) + + # With ignore_index, accuracy should be 1.0 (2/2 correct) + assert result["accuracy"] == 1.0 + + # Compare with case where we have wrong predictions for ignored indices + preds_wrong = torch.tensor([0, 1, 1, 0]) # Wrong predictions for what would be ignored + target_wrong = torch.tensor([0, 1, -1, -1]) + + result_wrong = binary_classification_report( + preds=preds_wrong, target=target_wrong, ignore_index=-1, output_dict=True + ) + + # Should still be 1.0 because ignored indices don't affect accuracy + assert result_wrong["accuracy"] == 1.0 + + +@pytest.mark.parametrize( + ("y_true", "y_pred", "output_dict", "expected_avg_keys"), + [ + ( + np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]), + np.array([0, 1, 2, 0, 1, 2, 0, 1, 2]), + True, + ["macro avg", "weighted avg"], + ), + ], +) +def test_classification_report_dict_format(y_true, y_pred, output_dict, expected_avg_keys): + """Test the format of classification report when output_dict=True.""" + num_classes = len(np.unique(y_true)) + torchmetrics_report = ClassificationReport(output_dict=output_dict, task="multiclass", num_classes=num_classes) + torchmetrics_report.update(torch.tensor(y_pred), torch.tensor(y_true)) + result_dict = torchmetrics_report.compute() + + # Check dictionary format + for key in expected_avg_keys: + assert key in result_dict, f"Key '{key}' is missing from the classification report" + + # Check class keys are present + unique_classes = np.unique(y_true) + for cls in unique_classes: + assert str(cls) in result_dict, f"Class '{cls}' is missing from the report" + + # Check metrics structure + for cls_key in [str(cls) for cls in unique_classes]: + for metric in ["precision", "recall", "f1-score", "support"]: + assert metric in result_dict[cls_key], f"Metric '{metric}' missing for class '{cls_key}'" + + +def test_task_validation(): + """Test validation of task parameter.""" + with pytest.raises(ValueError, match="Invalid Classification: expected one of"): + _ = ClassificationReport(task="invalid_task") + + +def test_functional_invalid_task(): + """Test validation of task parameter in functional classification_report.""" + y_true = torch.tensor([0, 1, 0, 1]) + y_pred = torch.tensor([0, 0, 1, 1]) + + with pytest.raises(ValueError, match="Invalid Classification: expected one of"): + functional_classification_report(y_pred, y_true, task="invalid_task") + + +# Add parameterized tests for various edge cases +@pytest.mark.parametrize("task", 
["binary", "multiclass", "multilabel"]) +@pytest.mark.parametrize("output_dict", [True, False]) +@pytest.mark.parametrize("zero_division", [0, 1, "warn"]) +def test_zero_division_handling(task, output_dict, zero_division): + """Test zero_division parameter works correctly across all classification types.""" + # Create edge case data with some classes having no support + if task == "binary": + # Create data where class 1 never appears in target + y_true = np.array([0, 0, 0, 0]) + y_pred = np.array([0, 1, 0, 1]) + params = {"threshold": 0.5} + elif task == "multiclass": + # Create data where class 2 never appears in target + y_true = np.array([0, 0, 1, 1]) + y_pred = np.array([0, 2, 1, 2]) + params = {"num_classes": 3} + else: # multilabel + # Create data where second label never appears + y_true = np.array([[1, 0, 1], [1, 0, 0], [0, 0, 1], [1, 0, 0]]) + y_pred = np.array([[1, 1, 1], [0, 1, 0], [1, 0, 1], [1, 1, 0]]) + params = {"num_labels": 3, "threshold": 0.5} + + # Create report with zero_division parameter + report = ClassificationReport(task=task, output_dict=output_dict, zero_division=zero_division, **params) + + report.update(torch.tensor(y_pred), torch.tensor(y_true)) + result = report.compute() + + # Check the results + if output_dict: + # Verify that a result is produced + if task == "binary": + # Verify class '1' is in the result if it was predicted + if "1" in result: + # Just check that precision exists - actual value depends on implementation + assert "precision" in result["1"] + + # For zero_division=0, precision should always be 0 for classes with no support + if zero_division == 0: + assert result["1"]["precision"] == 0.0 + + elif task == "multiclass" and "2" in result: + # Just check that precision exists - actual value depends on implementation + assert "precision" in result["2"] + + # For zero_division=0, precision should always be 0 for classes with no support + if zero_division == 0: + assert result["2"]["precision"] == 0.0 + else: + # For string output, just verify it's a valid string + assert isinstance(result, str) + + +# Tests for top_k functionality +@pytest.mark.parametrize("output_dict", [True, False]) +@pytest.mark.parametrize("top_k", [1, 2, 3]) +def test_multiclass_classification_report_top_k(output_dict, top_k): + """Test top_k functionality in multiclass classification report.""" + # Create simple test data where top_k can make a difference + num_classes = 3 + + # Create predictions with specific pattern for testing top_k + preds = torch.tensor([ + [0.1, 0.8, 0.1], # Class 1 is top-1, class 0 is top-2 -> target: 0 + [0.7, 0.2, 0.1], # Class 0 is top-1, class 1 is top-2 -> target: 1 + [0.1, 0.1, 0.8], # Class 2 is top-1, class 0 is top-2 -> target: 2 + [0.4, 0.5, 0.1], # Class 1 is top-1, class 0 is top-2 -> target: 0 + [0.3, 0.6, 0.1], # Class 1 is top-1, class 0 is top-2 -> target: 1 + [0.2, 0.1, 0.7], # Class 2 is top-1, class 0 is top-2 -> target: 2 + [0.6, 0.3, 0.1], # Class 0 is top-1, class 1 is top-2 -> target: 0 + [0.2, 0.7, 0.1], # Class 1 is top-1, class 0 is top-2 -> target: 1 + [0.1, 0.2, 0.7], # Class 2 is top-1, class 1 is top-2 -> target: 2 + [0.5, 0.4, 0.1], # Class 0 is top-1, class 1 is top-2 -> target: 0 + [0.1, 0.8, 0.1], # Class 1 is top-1, class 0 is top-2 -> target: 1 + [0.1, 0.3, 0.6], # Class 2 is top-1, class 1 is top-2 -> target: 2 + ]) + + target = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]) + + # Test functional interface + result_functional = multiclass_classification_report( + preds=preds, target=target, 
num_classes=num_classes, top_k=top_k, output_dict=output_dict + ) + + # Test class interface + metric = ClassificationReport(task="multiclass", num_classes=num_classes, top_k=top_k, output_dict=output_dict) + metric.update(preds, target) + result_class = metric.compute() + + # Verify both interfaces produce same result + if output_dict: + assert isinstance(result_functional, dict) + assert isinstance(result_class, dict) + # Check that accuracy improves with higher top_k (should be non-decreasing) + if "accuracy" in result_functional: + assert result_functional["accuracy"] >= 0.0 + assert result_functional["accuracy"] <= 1.0 + else: + assert isinstance(result_functional, str) + assert isinstance(result_class, str) + # Verify standard metrics are present in string output + assert "precision" in result_functional + assert "recall" in result_functional + assert "f1-score" in result_functional + assert "support" in result_functional + + # Verify that functional and class methods produce identical results + assert result_functional == result_class + + +@pytest.mark.parametrize("top_k", [1, 2, 3]) +def test_multiclass_classification_report_top_k_accuracy_monotonic(top_k): + """Test that accuracy is monotonic non-decreasing with increasing top_k.""" + num_classes = 4 + batch_size = 20 + + # Create random but consistent test data + torch.manual_seed(42) + preds = torch.randn(batch_size, num_classes).softmax(dim=1) + target = torch.randint(0, num_classes, (batch_size,)) + + result = multiclass_classification_report( + preds=preds, target=target, num_classes=num_classes, top_k=top_k, output_dict=True + ) + + # Basic sanity checks + assert "accuracy" in result + assert 0.0 <= result["accuracy"] <= 1.0 + + # Check that all class metrics are present + for i in range(num_classes): + assert str(i) in result + class_metrics = result[str(i)] + assert "precision" in class_metrics + assert "recall" in class_metrics + assert "f1-score" in class_metrics + assert "support" in class_metrics + + +def test_multiclass_classification_report_top_k_comparison(): + """Test that higher top_k generally leads to equal or better accuracy.""" + num_classes = 5 + batch_size = 50 + + # Create test data where top_k makes a significant difference + torch.manual_seed(123) + preds = torch.randn(batch_size, num_classes).softmax(dim=1) + target = torch.randint(0, num_classes, (batch_size,)) + + accuracies = {} + + for k in [1, 2, 3, 4, 5]: + result = multiclass_classification_report( + preds=preds, target=target, num_classes=num_classes, top_k=k, output_dict=True + ) + accuracies[k] = result["accuracy"] + + # Verify accuracy is non-decreasing + for k in range(1, 5): + assert accuracies[k] <= accuracies[k + 1], ( + f"Accuracy should be non-decreasing with top_k: " + f"top_{k}={accuracies[k]:.3f} > top_{k + 1}={accuracies[k + 1]:.3f}" + ) + + # At top_k = num_classes, accuracy should be 1.0 + assert accuracies[5] == 1.0, f"Accuracy at top_k=num_classes should be 1.0, got {accuracies[5]}" + + +@pytest.mark.parametrize("ignore_index", [None, -1]) +@pytest.mark.parametrize("top_k", [1, 2]) +def test_multiclass_classification_report_top_k_with_ignore_index(ignore_index, top_k): + """Test top_k functionality works correctly with ignore_index.""" + num_classes = 3 + + preds = torch.tensor([ + [0.6, 0.3, 0.1], # pred: 0, target: 0 (correct) + [0.2, 0.7, 0.1], # pred: 1, target: 1 (correct) + [0.1, 0.2, 0.7], # pred: 2, target: ignored + [0.4, 0.5, 0.1], # pred: 1, target: 0 (wrong for top-1, correct for top-2) + ]) + + target = 
torch.tensor([0, 1, ignore_index, 0]) if ignore_index is not None else torch.tensor([0, 1, 2, 0]) + + result = multiclass_classification_report( + preds=preds, target=target, num_classes=num_classes, top_k=top_k, ignore_index=ignore_index, output_dict=True + ) + + # Basic verification + assert "accuracy" in result + assert 0.0 <= result["accuracy"] <= 1.0 + + # With ignore_index, the third sample should be ignored + if ignore_index is not None and top_k == 2: + # With top_k=2, the last prediction [0.4, 0.5, 0.1] should be correct + # since target=0 and both classes 0 and 1 are in top-2 + expected_accuracy = 1.0 # 3 out of 3 valid samples correct + assert abs(result["accuracy"] - expected_accuracy) < 1e-6 + + +def test_classification_report_wrapper_top_k(): + """Test that the wrapper ClassificationReport correctly handles top_k.""" + num_classes = 3 + preds = torch.tensor([ + [0.1, 0.8, 0.1], + [0.7, 0.2, 0.1], + [0.1, 0.1, 0.8], + ]) + target = torch.tensor([0, 1, 2]) + + # Test with different top_k values + for top_k in [1, 2, 3]: + report = ClassificationReport(task="multiclass", num_classes=num_classes, top_k=top_k, output_dict=True) + + report.update(preds, target) + result = report.compute() + + assert "accuracy" in result + assert 0.0 <= result["accuracy"] <= 1.0 + + # Check that all expected classes are present + for i in range(num_classes): + assert str(i) in result + + +@pytest.mark.parametrize("top_k", [1, 2]) +def test_functional_classification_report_top_k(top_k): + """Test that the main functional classification_report interface supports top_k.""" + num_classes = 3 + preds = torch.tensor([ + [0.1, 0.8, 0.1], + [0.7, 0.2, 0.1], + [0.1, 0.1, 0.8], + ]) + target = torch.tensor([0, 1, 2]) + + result = functional_classification_report( + preds=preds, target=target, task="multiclass", num_classes=num_classes, top_k=top_k, output_dict=True + ) + + assert "accuracy" in result + assert 0.0 <= result["accuracy"] <= 1.0 + + # Verify structure is correct + for i in range(num_classes): + assert str(i) in result + metrics = result[str(i)] + assert "precision" in metrics + assert "recall" in metrics + assert "f1-score" in metrics + assert "support" in metrics + + +def test_top_k_binary_task_ignored(): + """Test that top_k parameter is ignored for binary tasks (should not cause errors).""" + preds = torch.tensor([0.1, 0.9, 0.3, 0.8]) + target = torch.tensor([0, 1, 0, 1]) + + # top_k should be ignored for binary classification + result1 = functional_classification_report(preds=preds, target=target, task="binary", top_k=1, output_dict=True) + + result2 = functional_classification_report( + preds=preds, + target=target, + task="binary", + top_k=5, # Should be ignored + output_dict=True, + ) + + # Results should be identical since top_k is ignored for binary + assert result1 == result2 + + +def test_top_k_multilabel_task_ignored(): + """Test that top_k parameter is ignored for multilabel tasks.""" + preds = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]) + target = torch.tensor([[0, 1], [1, 0], [0, 1]]) + + # top_k should be ignored for multilabel classification + result1 = functional_classification_report( + preds=preds, target=target, task="multilabel", num_labels=2, top_k=1, output_dict=True + ) + + result2 = functional_classification_report( + preds=preds, + target=target, + task="multilabel", + num_labels=2, + top_k=5, # Should be ignored + output_dict=True, + ) + + # Results should be identical since top_k is ignored for multilabel + assert result1 == result2 + + +class 
TestTopKFunctionality:
+    """Test class specifically for top_k functionality in multiclass classification."""
+
+    def test_top_k_basic_functionality(self):
+        """Test basic top_k functionality with probabilities."""
+        # Create predictions where top-1 prediction is wrong but top-2 includes correct label
+        preds = torch.tensor([
+            [0.1, 0.8, 0.1],  # Predicted: 1, True: 0 (wrong for top-1, correct for top-2)
+            [0.2, 0.3, 0.5],  # Predicted: 2, True: 2 (correct for both)
+            [0.6, 0.3, 0.1],  # Predicted: 0, True: 1 (wrong for top-1, correct for top-2)
+        ])
+        target = torch.tensor([0, 2, 1])
+
+        # Baseline report with top_k=1
+        result_k1 = multiclass_classification_report(
+            preds=preds, target=target, num_classes=3, top_k=1, output_dict=True
+        )
+
+        # Test top_k=2 (should have higher accuracy)
+        result_k2 = multiclass_classification_report(
+            preds=preds, target=target, num_classes=3, top_k=2, output_dict=True
+        )
+
+        # With top_k=2, accuracy should be 3/3 = 1.0 (all samples have correct label in top-2)
+        assert result_k2["accuracy"] == 1.0
+
+        # Per-class metrics should also improve with top_k=2
+        assert result_k2["0"]["recall"] >= result_k1["0"]["recall"]
+        assert result_k2["1"]["recall"] >= result_k1["1"]["recall"]
+
+    def test_top_k_with_logits(self):
+        """Test top_k functionality with logits (unnormalized scores)."""
+        # Logits that will be converted to probabilities via softmax
+        preds = torch.tensor([
+            [1.0, 3.0, 1.0],  # After softmax: highest prob for class 1, true label is 0
+            [2.0, 1.0, 4.0],  # After softmax: highest prob for class 2, true label is 2
+            [3.0, 2.0, 1.0],  # After softmax: highest prob for class 0, true label is 1
+        ])
+        target = torch.tensor([0, 2, 1])
+
+        result_k1 = multiclass_classification_report(
+            preds=preds, target=target, num_classes=3, top_k=1, output_dict=True
+        )
+
+        result_k2 = multiclass_classification_report(
+            preds=preds, target=target, num_classes=3, top_k=2, output_dict=True
+        )
+
+        # top_k=2 should perform better than or equal to top_k=1
+        assert result_k2["accuracy"] >= result_k1["accuracy"]
+
+    def test_top_k_with_class_wrapper(self):
+        """Test top_k functionality through the ClassificationReport wrapper class."""
+        preds = torch.tensor([
+            [0.1, 0.8, 0.1],
+            [0.2, 0.3, 0.5],
+            [0.6, 0.3, 0.1],
+        ])
+        target = torch.tensor([0, 2, 1])
+
+        # Test with class-based implementation
+        metric_k1 = ClassificationReport(task="multiclass", num_classes=3, top_k=1, output_dict=True)
+        metric_k1.update(preds, target)
+        result_k1 = metric_k1.compute()
+
+        metric_k2 = ClassificationReport(task="multiclass", num_classes=3, top_k=2, output_dict=True)
+        metric_k2.update(preds, target)
+        result_k2 = metric_k2.compute()
+
+        # top_k=2 should perform better
+        assert result_k2["accuracy"] >= result_k1["accuracy"]
+
+        # Test equivalence with functional implementation
+        func_result_k2 = multiclass_classification_report(
+            preds=preds, target=target, num_classes=3, top_k=2, output_dict=True
+        )
+
+        assert result_k2["accuracy"] == func_result_k2["accuracy"]
+
+    @pytest.mark.parametrize("top_k", [1, 2, 3])
+    def test_top_k_edge_cases(self, top_k):
+        """Test top_k with different values and edge cases."""
+        # Simple case where all predictions are correct for top-1
+        preds = torch.tensor([
+            [0.9, 0.05, 0.05],  # Correct: class 0
+            [0.05, 0.9, 0.05],  # Correct: class 1
+            [0.05, 0.05, 0.9],  # Correct: class 2
+        ])
+        target = torch.tensor([0, 1, 2])
+
+        result = multiclass_classification_report(
+            preds=preds, target=target, num_classes=3, top_k=top_k, output_dict=True
+        )
+
+        # Should always be perfect accuracy regardless of top_k value
+        assert result["accuracy"] == 1.0
+ + def test_top_k_larger_than_num_classes(self): + """Test behavior when top_k is larger than number of classes.""" + preds = torch.tensor([ + [0.1, 0.8, 0.1], + [0.2, 0.3, 0.5], + ]) + target = torch.tensor([0, 2]) + + # top_k=5 > num_classes=3, should raise an error as per torchmetrics validation + with pytest.raises(ValueError, match="Expected argument `top_k` to be smaller or equal to `num_classes`"): + multiclass_classification_report(preds=preds, target=target, num_classes=3, top_k=5, output_dict=True) + + def test_top_k_with_hard_predictions(self): + """Test that top_k works correctly with hard predictions (class indices).""" + # When predictions are already class indices, top_k > 1 should raise an error + # because hard predictions are 1D and can't support top_k > 1 + preds = torch.tensor([1, 2, 0]) # Hard predictions + target = torch.tensor([0, 2, 1]) + + multiclass_classification_report(preds=preds, target=target, num_classes=3, top_k=1, output_dict=True) + + # With hard predictions, top_k > 1 should raise an error + with pytest.raises(RuntimeError, match="selected index k out of range"): + multiclass_classification_report(preds=preds, target=target, num_classes=3, top_k=2, output_dict=True)
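Taken together, the new metric can be exercised end to end roughly as follows. This is a minimal sketch assuming the API lands as written in this diff; the tensors are borrowed from the multiclass docstring example above, and the functional import path matches the imports used in the test file.

```python
import torch

from torchmetrics import ClassificationReport
from torchmetrics.functional.classification.classification_report import multiclass_classification_report

preds = torch.tensor([0, 0, 2, 2, 1])
target = torch.tensor([0, 1, 2, 2, 2])

# Module interface: accumulate over one or more batches, then compute the report.
report = ClassificationReport(task="multiclass", num_classes=3, output_dict=True)
report.update(preds, target)
result = report.compute()  # dict keyed by class label plus "accuracy", "macro avg", "weighted avg"

# Functional interface: one-shot formatted string, as in the docstring examples.
print(multiclass_classification_report(preds, target, num_classes=3, digits=2))
```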