Convert a custom accuracy function into torchmetric? #501
Hello, thank you for this awesome project. I am new to TorchMetrics and have a question about custom functions. Suppose I have my own function to calculate accuracy, such as: How can I convert this function into a TorchMetrics metric where all the internal logic comes from my function? I just want to wrap my logic in a metric. Thanks!
Replies: 2 comments
Hi @Abhinav43,
@SkafteNicki pointed to the docs. Let me give a concrete end-to-end example that should cover most "I have a custom function, how do I wrap it?" cases (v1.9.0).

Say you have this custom accuracy function:

def my_custom_accuracy(preds, target, top_k=3):
    """Top-k accuracy: correct if true label is in top-k predictions."""
    top_k_preds = preds.topk(top_k, dim=-1).indices
    correct = (top_k_preds == target.unsqueeze(-1)).any(dim=-1)
    return correct.sum(), target.shape[0]

Wrap it as a TorchMetrics Metric:

import torch
from torchmetrics import Metric


class MyTopKAccuracy(Metric):
    # Metadata: helps Lightning and logging tools
    is_differentiable = False
    higher_is_better = True
    full_state_update = False  # update() doesn't need full state

    def __init__(self, top_k: int = 3, **kwargs):
        super().__init__(**kwargs)
        self.top_k = top_k
        # Register states with reduction function for DDP
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        correct, total = my_custom_accuracy(preds, target, self.top_k)
        self.correct += correct
        self.total += total

    def compute(self) -> torch.Tensor:
        return self.correct.float() / self.total

Usage:

metric = MyTopKAccuracy(top_k=5)
# Per-batch
for preds, target in dataloader:
    metric.update(preds, target)

# Epoch-level result (works across DDP automatically)
result = metric.compute()
metric.reset()

Key rules:
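One thing the example above relies on is that the states hold running counts and the division happens only in compute(). A minimal plain-Python sketch of why that matters, using made-up batch counts (no torch needed, and the numbers are purely illustrative):

```python
# Hypothetical (correct, total) pairs, shaped like the return value of
# my_custom_accuracy for three batches of unequal size.
batches = [(9, 10), (8, 10), (1, 4)]

# What the Metric above does: sum both counts, then divide once in compute().
correct = sum(c for c, _ in batches)
total = sum(t for _, t in batches)
accumulated = correct / total

# The tempting alternative: average the per-batch accuracies.
mean_of_ratios = sum(c / t for c, t in batches) / len(batches)

print(f"accumulated counts: {accumulated:.4f}")    # 0.7500
print(f"mean of ratios:     {mean_of_ratios:.4f}")  # 0.6500
```

The two numbers disagree because the small last batch gets the same weight as the large ones when ratios are averaged. The same reasoning applies across DDP processes: summing "correct" and "total" with dist_reduce_fx="sum" and dividing once gives the overall accuracy even when ranks see different numbers of samples.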