New metric: ClassificationReport #3116
Conversation
for more information, see https://pre-commit.ci
Great addition!
- Remove direct imports of classification metrics at module level
- Implement lazy imports using @property for Binary/Multiclass/MultilabelClassificationReport
- Move metric initialization to property getters to break circular dependencies
- Maintain all existing functionality while improving import structure
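A minimal sketch of the lazy-import pattern described in this commit message, assuming a simplified class with a single metric attribute (the real property names and imports in the PR may differ):

```python
from typing import Optional


class BinaryClassificationReport:
    """Illustrative only: defer metric imports until first access to break circular imports."""

    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold
        self._precision = None  # created lazily on first access

    @property
    def precision(self):
        # Importing inside the getter means loading this module no longer
        # pulls in the metric module at import time.
        if self._precision is None:
            from torchmetrics.classification import BinaryPrecision

            self._precision = BinaryPrecision(threshold=self.threshold)
        return self._precision
```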
for more information, see https://pre-commit.ci
Thank you very much! @Borda Could you please help me understand why the TM.unittests are failing? They seem to be linked with Azure and I am not familiar with the error logs there. Regarding the ruff failures, for some reason my local checks pass, so I'll look into that. But is it okay if I tend to those once the actual code is vetted? Sorry for the delay in getting back to you, I had some immediate personal work to get done. CC: @rittik9 I think it is ready now?
@aymuos15 I had a quick look and opened a PR fixing the pre-commit errors.
Doctests regarding the classification report seem to be failing, mind having a look?
Going to take a detailed look later.
fix pre-commit errors
for more information, see https://pre-commit.ci
Thanks a lot for the quick fix @rittik9. Will look into the doctests now. Sorry about the close and open, I think the close was automatically triggered by the local merge for some reason? Not sure.
@rittik9 Thanks again for the detailed review. I have tried to address everything bar the classwise accuracy score. That is inherently different from the scikit-learn version of the metric. Are you okay if we raise a separate issue for that, discuss how exactly we want to add it, and then do a separate PR?
Overall it looks good. I just think it would be better if we could have it more as a wrapper, so you can select which metrics you want in the report (probably by passing string metric/class names), and then have precision, recall, and F-measure as the default configuration.
see suggestion by @rittik9 in #2580 (comment)
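A hypothetical sketch of that wrapper-style interface (the class name, constructor arguments, and return structure are illustrative assumptions, not the PR's actual API):

```python
import torch
from torchmetrics.classification import MulticlassF1Score, MulticlassPrecision, MulticlassRecall

# Illustrative default configuration: precision, recall, F-measure.
_DEFAULT_METRICS = {
    "precision": MulticlassPrecision,
    "recall": MulticlassRecall,
    "f1-score": MulticlassF1Score,
}


class ClassificationReportSketch:
    """Select report columns by string name; defaults mirror precision/recall/F1."""

    def __init__(self, num_classes: int, metrics=None, class_names=None) -> None:
        names = metrics or list(_DEFAULT_METRICS)
        self.class_names = class_names or [str(i) for i in range(num_classes)]
        self.metrics = {
            name: _DEFAULT_METRICS[name](num_classes=num_classes, average=None) for name in names
        }

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        for metric in self.metrics.values():
            metric.update(preds, target)

    def compute(self) -> dict:
        # One row per class, one column per selected metric.
        per_metric = {name: metric.compute() for name, metric in self.metrics.items()}
        return {
            cls: {name: float(values[i]) for name, values in per_metric.items()}
            for i, cls in enumerate(self.class_names)
        }
```

Selecting only "f1-score", or adding an extra column such as specificity, would then just be a matter of extending the lookup table.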
Thank you very much for the review. I definitely like the idea and agree with it. However, I think I would have to pretty much revamp everything. Would it be okay to keep this as the base and proceed with that separately in another PR, or should I just continue here?
you can continue here 🐰
Haha, alright. I'll take some time to understand the best way to approach this and write the tests so edge cases do not come up. Thanks again!
Hi @aymuos15, while you're at it, could you please ensure that the 'micro' averaging option is included by default for all three cases (binary, multiclass, and multilabel)? In case the results don't align with sklearn, we can always verify them manually.
Yup, sure. Thank you for reminding me. Since we are now going to include everything anyway, I'll make all options exhaustive.
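As a reference for the manual verification mentioned above, a minimal cross-check of one micro-averaged score against scikit-learn (the example tensors are arbitrary; both library calls shown are existing APIs):

```python
import torch
from sklearn.metrics import precision_score
from torchmetrics.functional.classification import multiclass_precision

preds = torch.tensor([0, 2, 1, 2, 0, 1, 2, 0])
target = torch.tensor([0, 1, 1, 2, 0, 2, 2, 1])

tm_micro = multiclass_precision(preds, target, num_classes=3, average="micro")
sk_micro = precision_score(target.numpy(), preds.numpy(), average="micro")

# Both should agree (micro precision equals accuracy when all classes are present).
assert torch.isclose(tm_micro, torch.tensor(sk_micro, dtype=tm_micro.dtype))
```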
We can figure out how to split it into smaller pieces so it lands more smoothly and faster.
That sounds good to me. Could you please let me know what that would entail? This is what I had in mind:

TorchMetrics Classification Report Examples
============================================================
BINARY CLASSIFICATION
============================================================
Predictions: [0, 1, 1, 0, 1, 0, 1, 1]
True Labels: [0, 1, 0, 0, 1, 1, 1, 0]
--- CASE 1: DEFAULT METRICS (precision, recall, f1-score) ---
precision recall f1-score support
0 0.67 0.50 0.57 4
1 0.60 0.75 0.67 4
accuracy 0.62 8
micro avg 0.62 0.62 0.62 8
macro avg 0.63 0.62 0.62 8
weighted avg 0.63 0.62 0.62 8
--- CASE 2: ADD ACCURACY TO DEFAULT METRICS ---
precision recall f1-score accuracy support
0 0.67 0.50 0.57 0.50 4
1 0.60 0.75 0.67 0.75 4
accuracy 0.62 8
micro avg 0.62 0.62 0.62 0.62 8
macro avg 0.63 0.62 0.62 0.62 8
weighted avg 0.63 0.62 0.62 0.62 8
--- CASE 3: SPECIFICITY ONLY ---
specificity support
0 0.75 4
1 0.50 4
accuracy 0.62 8
micro avg 0.50 8
macro avg 0.62 8
weighted avg 0.62 8
Verification - Direct accuracy calculation: 0.6250
============================================================
MULTICLASS CLASSIFICATION
============================================================
Predictions: [0, 2, 1, 2, 0, 1, 2, 0]
True Labels: [0, 1, 1, 2, 0, 2, 2, 1]
--- CASE 1: DEFAULT METRICS ---
precision recall f1-score support
0 0.67 1.00 0.80 2
1 0.50 0.33 0.40 3
2 0.67 0.67 0.67 3
accuracy 0.62 8
micro avg 0.62 0.62 0.62 8
macro avg 0.61 0.67 0.62 8
weighted avg 0.60 0.63 0.60 8
--- CASE 2: ADD ACCURACY TO DEFAULT METRICS ---
precision recall f1-score accuracy support
0 0.67 1.00 0.80 1.00 2
1 0.50 0.33 0.40 0.33 3
2 0.67 0.67 0.67 0.67 3
accuracy 0.62 8
micro avg 0.62 0.62 0.62 0.62 8
macro avg 0.61 0.67 0.62 0.67 8
weighted avg 0.60 0.63 0.60 0.63 8
--- CASE 3: SPECIFICITY ONLY ---
specificity support
0 0.83 2
1 0.80 3
2 0.80 3
accuracy 0.62 8
micro avg 0.81 8
macro avg 0.81 8
weighted avg 0.81 8
Verification - Direct accuracy calculation: 0.6667
============================================================
MULTILABEL CLASSIFICATION
============================================================
Predictions:
Sample 1: [1, 0, 1]
Sample 2: [0, 1, 0]
Sample 3: [1, 1, 0]
Sample 4: [0, 0, 1]
True Labels:
Sample 1: [1, 0, 0]
Sample 2: [0, 1, 1]
Sample 3: [1, 1, 0]
Sample 4: [0, 0, 1]
--- CASE 1: DEFAULT METRICS ---
precision recall f1-score support
0 1.00 1.00 1.00 2
1 1.00 1.00 1.00 2
2 0.50 0.50 0.50 2
micro avg 0.83 0.83 0.83 6
macro avg 0.83 0.83 0.83 6
weighted avg 0.83 0.83 0.83 6
samples avg 0.88 0.88 0.83 6
--- CASE 2: ADD ACCURACY TO DEFAULT METRICS ---
precision recall f1-score accuracy support
0 1.00 1.00 1.00 1.00 2
1 1.00 1.00 1.00 1.00 2
2 0.50 0.50 0.50 0.50 2
micro avg 0.83 0.83 0.83 0.83 6
macro avg 0.83 0.83 0.83 0.83 6
weighted avg 0.83 0.83 0.83 0.83 6
samples avg 0.88 0.88 0.83 0.83 6
--- CASE 3: SPECIFICITY ONLY ---
specificity support
0 1.00 2
1 1.00 2
2 0.50 2
micro avg 0.83 6
macro avg 0.83 6
weighted avg 0.83 6
samples avg 6
Verification - Direct accuracy calculation: 0.8333

Essentially, is the above okay? I have not committed anything yet because the code is very messy. Once we agree on a path forward, I will trigger the next commit if that's okay.
I had rather in mind that we could trim this PR to keep the printing functions and table formatting, and in a following PR have the extension of the Collection metrics, either as a new subclass or a new method...
Ah okay! I will push a commit for the formatting and printing tonight. Thank you.
Keeping in mind the below rigidness for the micro avg:

    # Determine if micro average should be shown in the report based on classification task
    # Following scikit-learn's logic:
    # - Show for multilabel classification (always)
    # - Show for multiclass when using a subset of classes
    # - Don't show for binary classification (micro avg is same as accuracy)
    # - Don't show for full multiclass classification with all classes (micro avg is same as accuracy)
    show_micro_avg = False
    is_multilabel = task == ClassificationTask.MULTILABEL

After going through everything again, I think the PR as of right now does exactly the bare minimum? The main additions after the initial few commits were the ignore_index and the micro avg. So just a few quick questions:
Thank you very much!
Still, the Module-like metric is not derived from the collection, which would save some computations.
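A minimal sketch of what that collection-based direction could look like, using torchmetrics' existing MetricCollection with compute groups so precision, recall, and F1 share their underlying statistics (an assumption about the approach, not this PR's implementation):

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassF1Score, MulticlassPrecision, MulticlassRecall

num_classes = 3
report_metrics = MetricCollection(
    {
        "precision": MulticlassPrecision(num_classes=num_classes, average=None),
        "recall": MulticlassRecall(num_classes=num_classes, average=None),
        "f1-score": MulticlassF1Score(num_classes=num_classes, average=None),
    },
    compute_groups=True,  # metrics sharing identical state are updated only once
)

preds = torch.tensor([0, 2, 1, 2, 0, 1, 2, 0])
target = torch.tensor([0, 1, 1, 2, 0, 2, 2, 1])
report_metrics.update(preds, target)
per_class = report_metrics.compute()  # dict of per-class tensors, one entry per metric
```

A ClassificationReport class could then be a thin formatting layer over such a collection rather than holding three independent metric states.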
Okay, thank you. So just to confirm -- I would only push
How is it going here?
Hi @Borda, I didn't realise this was still in the pipeline. Happy to revisit this. Do you think it'll be okay if I gave it a go with the collection style to complete the whole PR? Thanks.
Pull Request Overview
This PR implements a new ClassificationReport metric for TorchMetrics, providing comprehensive precision, recall, F1-score, and support statistics for binary, multiclass, and multilabel classification tasks. This metric mirrors scikit-learn's classification_report functionality but is adapted for PyTorch tensors, with additional features like top-k support for multiclass tasks.
- Adds functional implementations for binary, multiclass, and multilabel classification reports
- Implements class-based metrics with state management and distributed support
- Includes comprehensive test coverage with comparison against scikit-learn implementations
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.
File | Description
--- | ---
src/torchmetrics/functional/classification/classification_report.py | Core functional implementation with task-specific report generation
src/torchmetrics/classification/classification_report.py | Class-based metric implementations with state management
tests/unittests/classification/test_classification_report.py | Comprehensive test suite covering all tasks and edge cases
src/torchmetrics/functional/classification/__init__.py | Exports for functional classification report functions
src/torchmetrics/classification/__init__.py | Exports for classification report classes
docs/source/classification/classification_report.rst | Documentation for the new metric
def _handle_zero_division(value: float, zero_division: Union[str, float]) -> float:
    """Handle NaN values based on zero_division parameter."""
    if torch.isnan(torch.tensor(value)):
        if zero_division == "warn":
            return 0.0
        if isinstance(zero_division, (int, float)):
            return float(zero_division)
    return value
Copilot AI · Sep 19, 2025
Creating a tensor from a scalar value just to check if it's NaN is inefficient. Use math.isnan(value) instead of torch.isnan(torch.tensor(value)).
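A sketch of the revision that suggestion implies, keeping the helper's name and semantics from the snippet above:

```python
import math
from typing import Union


def _handle_zero_division(value: float, zero_division: Union[str, float]) -> float:
    """Handle NaN values based on the zero_division parameter."""
    if math.isnan(value):  # no tensor allocation just for a scalar NaN check
        if zero_division == "warn":
            return 0.0
        if isinstance(zero_division, (int, float)):
            return float(zero_division)
    return value
```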
def _compute_binary_metrics(
    preds: Tensor, target: Tensor, threshold: float, ignore_index: Optional[int], validate_args: bool
) -> Dict[int, Dict[str, Union[float, int]]]:
    """Compute metrics for binary classification."""
Copilot AI · Sep 19, 2025
Missing docstring for this internal function. While it's a private function, given its complexity and importance in the classification report computation, it would benefit from documentation explaining the parameters and return value structure.
"""Compute metrics for binary classification.""" | |
""" | |
Compute per-class precision, recall, f1-score, and support for binary classification. | |
For class 0 (negative class), predictions and targets are inverted to compute metrics. | |
Handles optional ignore_index for excluding certain targets from metric computation. | |
Args: | |
preds (Tensor): Predicted labels or probabilities for each sample. | |
target (Tensor): Ground truth labels for each sample. | |
threshold (float): Threshold for converting probabilities/logits to binary predictions. | |
ignore_index (Optional[int]): Specifies a target value that is ignored and does not contribute to the metric computation. | |
validate_args (bool): If True, performs input validation. | |
Returns: | |
Dict[int, Dict[str, Union[float, int]]]: A dictionary mapping class indices (0 and 1) to dictionaries containing | |
the following metrics for each class: | |
- "precision": Precision score (float) | |
- "recall": Recall score (float) | |
- "f1-score": F1 score (float) | |
- "support": Number of true instances for the class (int) | |
""" |
def _compute_multiclass_metrics(
    preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int], validate_args: bool, top_k: int = 1
) -> Dict[int, Dict[str, Union[float, int]]]:
    """Compute metrics for multiclass classification."""
Copilot AI · Sep 19, 2025
Missing docstring for this internal function. The function parameters and return structure should be documented, especially the top_k parameter which adds complexity to the multiclass computation.
"""Compute metrics for multiclass classification.""" | |
""" | |
Compute precision, recall, f1-score, and support for each class in multiclass classification. | |
Args: | |
preds (Tensor): Predicted labels or probabilities for each sample. | |
target (Tensor): Ground truth labels for each sample. | |
num_classes (int): Number of classes in the classification task. | |
ignore_index (Optional[int]): Specifies a target value that is ignored and does not contribute to the metric computation. | |
validate_args (bool): If True, validates the input arguments for correctness. | |
top_k (int, optional): Number of highest probability or logit predictions considered to find the correct label. Defaults to 1. | |
If greater than 1, the metric is computed considering whether the true label is among the top_k predicted classes. | |
Returns: | |
Dict[int, Dict[str, Union[float, int]]]: A dictionary mapping each class index to a dictionary containing: | |
- "precision": Precision score for the class. | |
- "recall": Recall score for the class. | |
- "f1-score": F1 score for the class. | |
- "support": Number of true instances for the class. | |
""" |
def _compute_multilabel_metrics(
    preds: Tensor, target: Tensor, num_labels: int, threshold: float, ignore_index: Optional[int], validate_args: bool
) -> Dict[int, Dict[str, Union[float, int]]]:
    """Compute metrics for multilabel classification."""
Copilot AI · Sep 19, 2025
Missing docstring for this internal function. Documentation would help explain the multilabel-specific computation logic and parameter meanings.
"""Compute metrics for multilabel classification.""" | |
""" | |
Compute multilabel classification metrics (precision, recall, f1-score, support) for each label. | |
For each label, computes precision, recall, and f1-score using the provided predictions and targets. | |
The function also calculates the support (number of true instances) for each label, optionally | |
accounting for an `ignore_index` to exclude certain samples from the computation. | |
Args: | |
preds (Tensor): Predicted outputs of shape (N, num_labels), where N is the number of samples. | |
target (Tensor): Ground truth labels of shape (N, num_labels). | |
num_labels (int): Number of labels in the multilabel classification task. | |
threshold (float): Threshold for converting predicted probabilities/logits to binary predictions. | |
ignore_index (Optional[int]): Specifies a target value that is ignored and does not contribute to the metric computation. | |
validate_args (bool): If True, validates the input arguments for correctness. | |
Returns: | |
Dict[int, Dict[str, Union[float, int]]]: A dictionary mapping each label index to a dictionary containing | |
the computed precision, recall, f1-score, and support for that label. | |
""" |
What does this PR do?
Fixes #2580
Before submitting
PR review
Classification report is a nice-to-have metric. The motivation follows naturally from its scikit-learn counterpart, and having it ready for tensors is definitely a good step forward.
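For context, the scikit-learn behaviour being mirrored, using the binary example data from the discussion above (classification_report is the real sklearn API):

```python
from sklearn.metrics import classification_report

y_pred = [0, 1, 1, 0, 1, 0, 1, 1]
y_true = [0, 1, 0, 0, 1, 1, 1, 0]

# Prints per-class precision/recall/f1/support plus accuracy, macro avg and weighted avg rows.
print(classification_report(y_true, y_pred, digits=2))
```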
Notes:
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.
Did you have fun?
Was a really nice way of getting familiar with the codebase :)
📚 Documentation preview 📚: https://torchmetrics--3116.org.readthedocs.build/en/3116/