@SkafteNicki pointed to the docs, so let me give a concrete end-to-end example that should cover most "I have a custom function, how do I wrap it?" cases (v1.9.0).

Say you have this custom accuracy function:

```python
def my_custom_accuracy(preds, target, top_k=3):
    """Top-k accuracy: correct if true label is in top-k predictions."""
    top_k_preds = preds.topk(top_k, dim=-1).indices
    correct = (top_k_preds == target.unsqueeze(-1)).any(dim=-1)
    return correct.sum(), target.shape[0]
```
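Returning raw counts `(correct, batch_size)` instead of a ratio is what makes the function easy to wrap. Across several batches, summing the counts and dividing once gives the exact dataset-level accuracy. Averaging per-batch ratios instead gives every batch equal weight regardless of its size. A quick check in plain PyTorch with made-up toy scores (`top_k=2`, so the 4-class example isn't trivially correct):

```python
import torch


def my_custom_accuracy(preds, target, top_k=3):
    """Top-k accuracy: correct if true label is in top-k predictions."""
    top_k_preds = preds.topk(top_k, dim=-1).indices
    correct = (top_k_preds == target.unsqueeze(-1)).any(dim=-1)
    return correct.sum(), target.shape[0]


# Two toy batches of 4-class scores (made-up values, batch sizes 2 and 1).
batches = [
    (torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.9, 0.05, 0.03, 0.02]]), torch.tensor([0, 0])),
    (torch.tensor([[0.4, 0.3, 0.2, 0.1]]), torch.tensor([3])),
]

correct_total, n_total = 0, 0
batch_ratios = []
for preds, target in batches:
    correct, n = my_custom_accuracy(preds, target, top_k=2)
    correct_total += int(correct)  # accumulate raw counts...
    n_total += n
    batch_ratios.append(int(correct) / n)  # ...versus per-batch ratios

accuracy = correct_total / n_total  # divide once at the end
mean_of_batch_ratios = sum(batch_ratios) / len(batch_ratios)
print(accuracy, mean_of_batch_ratios)
```

Here the count-based accuracy is 1/3 (one hit out of three samples), while averaging the two per-batch ratios gives 0.25.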

Wrap it as a TorchMetrics Metric:

```python
import torch
from torchmetrics import Metric


class MyTopKAccuracy(Metric):
    # Metadata: helps Lightning and logging tools
    is_differentiable = False
    higher_is_better = True
```

Answer selected by Borda