
Conversation

@aymuos15 commented Jun 3, 2025

What does this PR do?

Fixes #2580

Before submitting
  • Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements) - Yes, details in the issue mentioned above.
  • Did you read the contributor guideline, Pull Request section? - Yes.
  • Did you make sure to update the docs? - Not yet. Can do once everything is finalised in terms of code.
  • Did you write any new necessary tests? - Yes
PR review

A classification report is a nice-to-have metric; the motivation comes directly from its scikit-learn counterpart. Having it available for tensors is definitely a good step forward.

Notes:

  1. The original issue talks about having the option to include more metrics. I personally feel that is not required, as this should be a quick, standalone way to get the most relevant (and usually required) metric info in a nicely presentable way.
  2. Maybe we need a more in-depth discussion to expand the testing. I do not have access to a distributed setup at the moment, but if it is essential, I can look into it.

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Was a really nice way of getting familiar with the codebase :)


📚 Documentation preview 📚: https://torchmetrics--3116.org.readthedocs.build/en/3116/

@Borda (Contributor) commented Jun 4, 2025

Great addition!
Please check the failing checks; they are caused by a circular import.

aymuos15 and others added 5 commits June 6, 2025 11:46
- Remove direct imports of classification metrics at module level
- Implement lazy imports using @property for Binary/Multiclass/MultilabelClassificationReport (see the sketch below)
- Move metric initialization to property getters to break circular dependencies
- Maintain all existing functionality while improving import structure
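
For context, a minimal sketch of the lazy-import-via-property pattern described above (class and attribute names are illustrative, not the exact PR code):

from typing import Any

import torch
from torchmetrics import Metric


class BinaryClassificationReport(Metric):
    """Sketch: defer importing other classification metrics until first use to avoid circular imports."""

    def __init__(self, threshold: float = 0.5, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.threshold = threshold
        self._precision = None  # created lazily on first access

    @property
    def precision_metric(self) -> Metric:
        # Importing inside the property (instead of at module level) breaks the
        # circular dependency between this module and torchmetrics.classification.
        if self._precision is None:
            from torchmetrics.classification import BinaryPrecision

            self._precision = BinaryPrecision(threshold=self.threshold)
        return self._precision

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        self.precision_metric.update(preds, target)

    def compute(self) -> torch.Tensor:
        return self.precision_metric.compute()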
@aymuos15 (Author) commented Jun 6, 2025

Thank you very much! @Borda

Could you please help me understand why the TM.unittests jobs are failing? They seem to be linked to Azure, and I am not familiar with the error logs there.

Regarding the ruff failures, for some reason my local checks pass, so I'll look into that. But is it okay if I address those once the actual code is vetted?

Sorry for the delay in getting back to you; I had some immediate personal work to get done.

CC: @rittik9 I think it is ready now?

@rittik9 (Collaborator) left a comment

@aymuos15 I had a quick look and opened a PR to fix the pre-commit errors.
The doctests for the classification report seem to be failing; mind having a look?
I'm going to take a detailed look later.

@aymuos15 (Author) commented Jun 6, 2025

Thanks a lot for the quick fix, @rittik9.

I will look into the doctests now.

Sorry about the close and reopen; I think the close was triggered automatically by the local merge for some reason? Not sure.

@rittik9 (Collaborator) left a comment

Hi @aymuos15, please update the __init__.py files in the functional interface and the docs.
You can refer to #3090.
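
For reference, the kind of export update meant here would look roughly like this (the exact symbol names depend on the final implementation):

# src/torchmetrics/functional/classification/__init__.py (excerpt, symbol names assumed)
from torchmetrics.functional.classification.classification_report import (
    binary_classification_report,
    classification_report,
    multiclass_classification_report,
    multilabel_classification_report,
)

__all__ = [
    # ... existing exports ...
    "binary_classification_report",
    "classification_report",
    "multiclass_classification_report",
    "multilabel_classification_report",
]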

@aymuos15 (Author) commented:

@rittik9 Thanks again for the detailed review.

I have tried to address everything except the classwise accuracy score, which is inherently different from the scikit-learn version of the metric. Are you okay with raising a separate issue for that, discussing how exactly we want to add it, and then doing a separate PR?

@Borda (Contributor) left a comment

Overall it looks good. I am just wondering whether it would be better to have it as a wrapper, so you can select which metrics you want in the report (probably passed as string metric/class names), with precision, recall, and F-measure as the default configuration.

See the suggestion by @rittik9 in #2580 (comment).
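
For illustration, a rough sketch of such a wrapper-style usage built on MetricCollection (names, arguments, and defaults are placeholders rather than a final API):

import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryF1Score, BinaryPrecision, BinaryRecall

# Default configuration: precision, recall, F-measure; a report wrapper could let
# users pass their own metric names / class names to customize which columns appear.
default_metrics = MetricCollection({
    "precision": BinaryPrecision(),
    "recall": BinaryRecall(),
    "f1": BinaryF1Score(),
})

preds = torch.tensor([0, 1, 1, 0, 1, 0, 1, 1])
target = torch.tensor([0, 1, 0, 0, 1, 1, 1, 0])
default_metrics.update(preds, target)
print(default_metrics.compute())  # a formatting step would turn this dict into the report table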

@aymuos15 (Author) commented:

Thank you very much for the review.

I definitely like the idea and agree with it.

However, I think I would have to pretty much revamp everything. Would it be okay to keep this as the base and proceed with that separately in another PR or should I just continue here?

@Borda (Contributor) commented Jun 18, 2025

However, I think I would have to pretty much revamp everything. Would it be okay to keep this as the base and proceed with that separately in another PR or should I just continue here?

you can continue here 🐰

@aymuos15 (Author) commented Jun 18, 2025

Haha, alright. I'll take some time to understand the best way to approach this and to write the tests so that edge cases do not come up. Thanks again!

@rittik9 (Collaborator) commented Jun 18, 2025

Hi @aymuos15 While you're at it, could you please ensure that the 'micro' averaging option is included by default for all three cases—binary, multiclass, and multilabel? In case the results don't align with sklearn, we can always verify them manually.
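
For example, a manual cross-check of the 'micro' numbers against sklearn could look like this (a sketch using the existing functional metrics rather than the new report):

import torch
from sklearn.metrics import precision_recall_fscore_support
from torchmetrics.functional.classification import multiclass_f1_score, multiclass_precision, multiclass_recall

preds = torch.tensor([0, 2, 1, 2, 0, 1, 2, 0])
target = torch.tensor([0, 1, 1, 2, 0, 2, 2, 1])

tm_micro = (
    multiclass_precision(preds, target, num_classes=3, average="micro"),
    multiclass_recall(preds, target, num_classes=3, average="micro"),
    multiclass_f1_score(preds, target, num_classes=3, average="micro"),
)
# sklearn expects (y_true, y_pred); the first three returned values are precision, recall, f-score
sk_micro = precision_recall_fscore_support(target.numpy(), preds.numpy(), average="micro")[:3]
print(tm_micro, sk_micro)  # the micro-averaged values should match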

@aymuos15 (Author) commented:

Yup, sure. Thank you for reminding me.

Since we are now going to include everything anyway, I'll make all the options exhaustive.

@Borda (Contributor) commented Jun 23, 2025

Since we are now going to include everything anyway, I'll make all the options exhaustive.

We can figure out how to split it into smaller pieces so it lands more smoothly and faster.

@aymuos15 (Author) commented:

That sounds good to me. Could you please let me know what that would entail?

This is what I had in mind:

TorchMetrics Classification Report Examples

============================================================
 BINARY CLASSIFICATION
============================================================
Predictions: [0, 1, 1, 0, 1, 0, 1, 1]
True Labels: [0, 1, 0, 0, 1, 1, 1, 0]

--- CASE 1: DEFAULT METRICS (precision, recall, f1-score) ---
             precision    recall  f1-score   support 

           0      0.67      0.50      0.57         4 
           1      0.60      0.75      0.67         4 

    accuracy                          0.62         8 
   micro avg      0.62      0.62      0.62         8 
   macro avg      0.63      0.62      0.62         8 
weighted avg      0.63      0.62      0.62         8 

--- CASE 2: ADD ACCURACY TO DEFAULT METRICS ---
             precision    recall  f1-score  accuracy   support 

           0      0.67      0.50      0.57      0.50         4 
           1      0.60      0.75      0.67      0.75         4 

    accuracy                                    0.62         8 
   micro avg      0.62      0.62      0.62      0.62         8 
   macro avg      0.63      0.62      0.62      0.62         8 
weighted avg      0.63      0.62      0.62      0.62         8 

--- CASE 3: SPECIFICITY ONLY ---
             specificity     support 

           0        0.75           4 
           1        0.50           4 

    accuracy        0.62           8 
   micro avg        0.50           8 
   macro avg        0.62           8 
weighted avg        0.62           8 

Verification - Direct accuracy calculation: 0.6250

============================================================
 MULTICLASS CLASSIFICATION
============================================================
Predictions: [0, 2, 1, 2, 0, 1, 2, 0]
True Labels: [0, 1, 1, 2, 0, 2, 2, 1]

--- CASE 1: DEFAULT METRICS ---
             precision    recall  f1-score   support 

           0      0.67      1.00      0.80         2 
           1      0.50      0.33      0.40         3 
           2      0.67      0.67      0.67         3 

    accuracy                          0.62         8 
   micro avg      0.62      0.62      0.62         8 
   macro avg      0.61      0.67      0.62         8 
weighted avg      0.60      0.63      0.60         8 

--- CASE 2: ADD ACCURACY TO DEFAULT METRICS ---
             precision    recall  f1-score  accuracy   support 

           0      0.67      1.00      0.80      1.00         2 
           1      0.50      0.33      0.40      0.33         3 
           2      0.67      0.67      0.67      0.67         3 

    accuracy                                    0.62         8 
   micro avg      0.62      0.62      0.62      0.62         8 
   macro avg      0.61      0.67      0.62      0.67         8 
weighted avg      0.60      0.63      0.60      0.63         8 

--- CASE 3: SPECIFICITY ONLY ---
             specificity     support 

           0        0.83           2 
           1        0.80           3 
           2        0.80           3 

    accuracy        0.62           8 
   micro avg        0.81           8 
   macro avg        0.81           8 
weighted avg        0.81           8 

Verification - Direct accuracy calculation: 0.6667

============================================================
 MULTILABEL CLASSIFICATION
============================================================
Predictions:
  Sample 1: [1, 0, 1]
  Sample 2: [0, 1, 0]
  Sample 3: [1, 1, 0]
  Sample 4: [0, 0, 1]
True Labels:
  Sample 1: [1, 0, 0]
  Sample 2: [0, 1, 1]
  Sample 3: [1, 1, 0]
  Sample 4: [0, 0, 1]

--- CASE 1: DEFAULT METRICS ---
             precision    recall  f1-score   support 

           0      1.00      1.00      1.00         2 
           1      1.00      1.00      1.00         2 
           2      0.50      0.50      0.50         2 

   micro avg      0.83      0.83      0.83         6 
   macro avg      0.83      0.83      0.83         6 
weighted avg      0.83      0.83      0.83         6 
 samples avg      0.88      0.88      0.83         6 

--- CASE 2: ADD ACCURACY TO DEFAULT METRICS ---
             precision    recall  f1-score  accuracy   support 

           0      1.00      1.00      1.00      1.00         2 
           1      1.00      1.00      1.00      1.00         2 
           2      0.50      0.50      0.50      0.50         2 

   micro avg      0.83      0.83      0.83      0.83         6 
   macro avg      0.83      0.83      0.83      0.83         6 
weighted avg      0.83      0.83      0.83      0.83         6 
 samples avg      0.88      0.88      0.83      0.83         6 

--- CASE 3: SPECIFICITY ONLY ---
             specificity     support 

           0        1.00           2 
           1        1.00           2 
           2        0.50           2 

   micro avg        0.83           6 
   macro avg        0.83           6 
weighted avg        0.83           6 
 samples avg                       6 

Verification - Direct accuracy calculation: 0.8333

Essentially:

  1. If anyone uses it off the shelf, it behaves exactly like scikit-learn.
  2. But we also have the option to add whatever has been discussed so far.

Is the above okay?
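
To make it concrete, the calls behind the binary example above would look roughly like this (the function name and the task/metrics arguments are working assumptions about the interface being discussed, not committed code):

import torch
from torchmetrics.functional.classification import classification_report  # name assumed, added by this PR

preds = torch.tensor([0, 1, 1, 0, 1, 0, 1, 1])
target = torch.tensor([0, 1, 0, 0, 1, 1, 1, 0])

# Case 1: off the shelf, scikit-learn-like defaults (precision, recall, f1-score, support)
print(classification_report(preds, target, task="binary"))

# Cases 2 and 3: opt-in selection of extra or alternative metrics (hypothetical `metrics` argument)
print(classification_report(preds, target, task="binary", metrics=["precision", "recall", "f1-score", "accuracy"]))
print(classification_report(preds, target, task="binary", metrics=["specificity"]))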

I have not committed anything yet because the code is very messy. Once we agree on a path forward, I will push the next commit, if that's okay.

@Borda (Contributor) commented Jun 24, 2025

I rather had in mind that we trim this PR down to the printing functions and table formatting, and in a following PR add the extension of the collection metrics, either as a new subclass or a new method...

@aymuos15 (Author) commented:

Ah okay! I will push a commit for the formatting and printing tonight. Thank you.

@aymuos15 (Author) commented:

@Borda

Keeping in mind the rigidity below for the micro avg:

    # Determine if micro average should be shown in the report based on classification task
    # Following scikit-learn's logic:
    # - Show for multilabel classification (always)
    # - Show for multiclass when using a subset of classes
    # - Don't show for binary classification (micro avg is same as accuracy)
    # - Don't show for full multiclass classification with all classes (micro avg is same as accuracy)
    show_micro_avg = False
    is_multilabel = task == ClassificationTask.MULTILABEL

After going through everything again, I think the PR as it stands does exactly the bare minimum. The main additions after the initial few commits were ignore_index and the micro avg. So just a few quick questions:

  1. Should I remove them as well and keep them for the follow-up PR?
  2. Do I keep the base of the collection metrics (as it is not implemented in the current commit), or leave that out as well?
  3. The test file has multiple tests covering all the bases; does that need to change?

Thank you very much!

@Borda (Contributor) commented Jun 25, 2025

Still, the Module-like metric is not derived from the collection to save some computations.
I would just keep the functional part, as the module-like part needs to be redone.

@aymuos15 (Author) commented:

Okay, thank you. So just to confirm: I would only push src/torchmetrics/classification/classification_report.py and its corresponding files and tests?

The mergify bot removed the "has conflicts" label on Jun 30, 2025.
@Borda (Contributor) commented Sep 2, 2025

How is it going here?

@aymuos15 (Author) commented Sep 2, 2025

Hi @Borda, I didn't realise this was still in the pipeline. Happy to revisit it. Do you think it'll be okay if I gave the collection style a go to complete the whole PR? Thanks.

@Borda requested a review from Copilot, September 19, 2025 19:50

Copilot AI left a comment

Pull Request Overview

This PR implements a new ClassificationReport metric for TorchMetrics, providing comprehensive precision, recall, F1-score, and support statistics for binary, multiclass, and multilabel classification tasks. This metric mirrors scikit-learn's classification_report functionality but is adapted for PyTorch tensors with additional features like top-k support for multiclass tasks.

  • Adds functional implementations for binary, multiclass, and multilabel classification reports
  • Implements class-based metrics with state management and distributed support
  • Includes comprehensive test coverage with comparison against scikit-learn implementations

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.

Summary per file:

  • src/torchmetrics/functional/classification/classification_report.py: Core functional implementation with task-specific report generation
  • src/torchmetrics/classification/classification_report.py: Class-based metric implementations with state management
  • tests/unittests/classification/test_classification_report.py: Comprehensive test suite covering all tasks and edge cases
  • src/torchmetrics/functional/classification/__init__.py: Exports for functional classification report functions
  • src/torchmetrics/classification/__init__.py: Exports for classification report classes
  • docs/source/classification/classification_report.rst: Documentation for the new metric


Comment on lines +41 to +48
def _handle_zero_division(value: float, zero_division: Union[str, float]) -> float:
    """Handle NaN values based on zero_division parameter."""
    if torch.isnan(torch.tensor(value)):
        if zero_division == "warn":
            return 0.0
        if isinstance(zero_division, (int, float)):
            return float(zero_division)
    return value
Copilot AI commented Sep 19, 2025

Creating a tensor from a scalar value just to check if it's NaN is inefficient. Use math.isnan(value) instead of torch.isnan(torch.tensor(value)).
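
For reference, a minimal sketch of the function with this change applied (assuming value is always a plain Python float):

import math
from typing import Union


def _handle_zero_division(value: float, zero_division: Union[str, float]) -> float:
    """Handle NaN values based on zero_division parameter."""
    if math.isnan(value):  # avoids building a tensor just for the NaN check
        if zero_division == "warn":
            return 0.0
        if isinstance(zero_division, (int, float)):
            return float(zero_division)
    return value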


def _compute_binary_metrics(
    preds: Tensor, target: Tensor, threshold: float, ignore_index: Optional[int], validate_args: bool
) -> Dict[int, Dict[str, Union[float, int]]]:
    """Compute metrics for binary classification."""
Copilot AI commented Sep 19, 2025

Missing docstring for this internal function. While it's a private function, given its complexity and importance in the classification report computation, it would benefit from documentation explaining the parameters and return value structure.

Suggested change
-    """Compute metrics for binary classification."""
+    """Compute per-class precision, recall, f1-score, and support for binary classification.
+
+    For class 0 (negative class), predictions and targets are inverted to compute metrics.
+    Handles optional ignore_index for excluding certain targets from metric computation.
+
+    Args:
+        preds (Tensor): Predicted labels or probabilities for each sample.
+        target (Tensor): Ground truth labels for each sample.
+        threshold (float): Threshold for converting probabilities/logits to binary predictions.
+        ignore_index (Optional[int]): Specifies a target value that is ignored and does not contribute to the
+            metric computation.
+        validate_args (bool): If True, performs input validation.
+
+    Returns:
+        Dict[int, Dict[str, Union[float, int]]]: A dictionary mapping class indices (0 and 1) to dictionaries
+            containing the following metrics for each class:
+
+            - "precision": Precision score (float)
+            - "recall": Recall score (float)
+            - "f1-score": F1 score (float)
+            - "support": Number of true instances for the class (int)
+
+    """


def _compute_multiclass_metrics(
    preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int], validate_args: bool, top_k: int = 1
) -> Dict[int, Dict[str, Union[float, int]]]:
    """Compute metrics for multiclass classification."""
Copilot AI commented Sep 19, 2025

Missing docstring for this internal function. The function parameters and return structure should be documented, especially the top_k parameter which adds complexity to the multiclass computation.

Suggested change
-    """Compute metrics for multiclass classification."""
+    """Compute precision, recall, f1-score, and support for each class in multiclass classification.
+
+    Args:
+        preds (Tensor): Predicted labels or probabilities for each sample.
+        target (Tensor): Ground truth labels for each sample.
+        num_classes (int): Number of classes in the classification task.
+        ignore_index (Optional[int]): Specifies a target value that is ignored and does not contribute to the
+            metric computation.
+        validate_args (bool): If True, validates the input arguments for correctness.
+        top_k (int, optional): Number of highest probability or logit predictions considered to find the correct
+            label. Defaults to 1. If greater than 1, the metric is computed considering whether the true label is
+            among the top_k predicted classes.
+
+    Returns:
+        Dict[int, Dict[str, Union[float, int]]]: A dictionary mapping each class index to a dictionary containing:
+
+            - "precision": Precision score for the class.
+            - "recall": Recall score for the class.
+            - "f1-score": F1 score for the class.
+            - "support": Number of true instances for the class.
+
+    """


def _compute_multilabel_metrics(
    preds: Tensor, target: Tensor, num_labels: int, threshold: float, ignore_index: Optional[int], validate_args: bool
) -> Dict[int, Dict[str, Union[float, int]]]:
    """Compute metrics for multilabel classification."""
Copilot AI commented Sep 19, 2025

Missing docstring for this internal function. Documentation would help explain the multilabel-specific computation logic and parameter meanings.

Suggested change
-    """Compute metrics for multilabel classification."""
+    """Compute multilabel classification metrics (precision, recall, f1-score, support) for each label.
+
+    For each label, computes precision, recall, and f1-score using the provided predictions and targets.
+    The function also calculates the support (number of true instances) for each label, optionally
+    accounting for an `ignore_index` to exclude certain samples from the computation.
+
+    Args:
+        preds (Tensor): Predicted outputs of shape (N, num_labels), where N is the number of samples.
+        target (Tensor): Ground truth labels of shape (N, num_labels).
+        num_labels (int): Number of labels in the multilabel classification task.
+        threshold (float): Threshold for converting predicted probabilities/logits to binary predictions.
+        ignore_index (Optional[int]): Specifies a target value that is ignored and does not contribute to the
+            metric computation.
+        validate_args (bool): If True, validates the input arguments for correctness.
+
+    Returns:
+        Dict[int, Dict[str, Union[float, int]]]: A dictionary mapping each label index to a dictionary containing
+            the computed precision, recall, f1-score, and support for that label.
+
+    """

