Precision@k is removed after v0.11.0 #2002
You can manually implement Precision@k on top of `MultilabelPrecision`: select the top-k scores yourself, turn them into binary predictions, and let the metric count true and false positives. Here is a minimal example:

```python
import torch
from torchmetrics.classification import MultilabelPrecision


def precision_at_k(preds, target, k):
    """
    Args:
        preds: (batch_size, num_classes) tensor with prediction scores (floats)
        target: (batch_size, num_classes) binary tensor with ground truth
        k: int, top-k predictions to consider
    """
    num_labels = preds.size(1)
    # Get top-k indices for each sample
    topk_indices = torch.topk(preds, k=k, dim=1).indices
    # Create mask for top-k predictions (index tensors on the same device as preds)
    topk_mask = torch.zeros_like(preds, dtype=torch.bool)
    batch_indices = torch.arange(preds.size(0), device=preds.device).unsqueeze(1)
    topk_mask[batch_indices, topk_indices] = True
    # Use top-k mask as predictions (1 for selected, 0 otherwise)
    preds_topk = topk_mask.int()
    # Ground truth restricted to the top-k positions (TP and FP are unchanged by this)
    target_topk = target.bool() & topk_mask
    # Micro averaging pools TP and FP over all samples and labels
    precision_metric = MultilabelPrecision(num_labels=num_labels, average="micro").to(preds.device)
    precision = precision_metric(preds_topk, target_topk.int())
    return precision


# Example usage
preds = torch.tensor([[0.2, 0.8, 0.4, 0.9], [0.7, 0.1, 0.5, 0.3]])
target = torch.tensor([[0, 1, 0, 1], [1, 0, 1, 0]])
print(precision_at_k(preds, target, k=2))
```

If you find this approach useful, consider contributing a PR to TorchMetrics that adds a built-in Precision@k metric for multilabel tasks; many users face the same need.
Hi TorchMetrics team,
I am thankful for this powerful tool and really appreciate your hard work.
I'm evaluating multilabel classification results with precision@k, which I log with Lightning loggers. I noticed that since v0.11.0, the `Precision` module can no longer compute precision@k.
A pseudo example of `preds` and `target` in my case is:
To calculate precision@k, I take the indices of the top-k elements in `preds`, apply thresholding to them, and compute precision from the TP and FP counts of the resulting top-k preds/target pairs.
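The procedure described above can also be written in plain PyTorch, without any TorchMetrics class. A minimal sketch, assuming `preds` holds float scores and `target` a binary tensor of the same shape; the function name `precision_at_k_manual` is hypothetical:

```python
import torch


def precision_at_k_manual(preds: torch.Tensor, target: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: micro-averaged precision@k for multilabel scores.

    preds:  (batch_size, num_labels) float scores
    target: (batch_size, num_labels) binary ground truth
    Assumes k <= num_labels.
    """
    # Indices of the k highest-scoring labels for each sample
    topk_indices = torch.topk(preds, k=k, dim=1).indices
    # "Threshold" by marking only the top-k positions as positive predictions
    pred_pos = torch.zeros_like(preds).scatter_(1, topk_indices, 1.0).bool()
    target_pos = target.bool()
    # True positives: predicted positive and actually positive
    tp = (pred_pos & target_pos).sum()
    # False positives: predicted positive but actually negative
    fp = (pred_pos & ~target_pos).sum()
    return tp.float() / (tp + fp).float()


preds = torch.tensor([[0.2, 0.8, 0.4, 0.9], [0.7, 0.1, 0.5, 0.3]])
target = torch.tensor([[0, 1, 0, 0], [0, 0, 1, 0]])
print(precision_at_k_manual(preds, target, k=2))  # 2 hits out of 4 top-k predictions
```

With top-k selection every sample contributes exactly k positive predictions, so TP + FP always equals `batch_size * k`, and the result is the fraction of selected labels that are correct.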
Before v0.11.0, `Precision` handled this for me, so the TorchMetrics version in my project is currently pinned to 0.10.3. Is there a workaround for this? Thank you very much.