Why are F1, Recall, Precision, and Accuracy outputting the same thing in my implementation? #743
I have created a very barebones implementation, and F1, Recall, Precision, and Accuracy all output exactly the same value. As you can see, there is something fishy going on. Is this a bug, or am I doing something wrong?
Replies: 2 comments
Apparently someone has reproduced this recently: https://stackoverflow.com/questions/69139618/torchmetrics-does-not-work-with-pytorchlightning
@FeryET — this was a common issue with the old API and Lightning's `MetricCollection`. The root cause was almost always one of:

**Cause 1: Sharing metric instances across stages.**

```python
# WRONG — same metric object for train and val
self.metrics = MetricCollection({"acc": Accuracy(), "f1": F1()})
self.train_metrics = self.metrics
self.val_metrics = self.metrics  # shares state!
```

Fix:

```python
# RIGHT — .clone() creates independent copies
self.train_metrics = MetricCollection({"acc": Accuracy(), "f1": F1()})
self.val_metrics = self.train_metrics.clone(prefix="val_")
```

**Cause 2: Using the old unified API incorrectly.** With the old classes' default `average="micro"` on multiclass input, `Precision`, `Recall`, and `F1` pool TP/FP/FN across all classes; in single-label multiclass every error is simultaneously a false positive for one class and a false negative for another, so all three collapse to plain accuracy and the four metrics print identical values.

Here's the correct pattern today (v1.9.0):

```python
import lightning as L
import torch.nn.functional as F
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryF1Score, BinaryPrecision, BinaryRecall


class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        metrics = MetricCollection({
            "acc": BinaryAccuracy(),
            "f1": BinaryF1Score(),
            "precision": BinaryPrecision(),
            "recall": BinaryRecall(),
        })
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")

    def training_step(self, batch, batch_idx):
        x, y = batch
        probs = self(x).squeeze().sigmoid()
        loss = F.binary_cross_entropy(probs, y.float())
        self.log_dict(self.train_metrics(probs, y), on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        probs = self(x).squeeze().sigmoid()
        self.log_dict(self.val_metrics(probs, y))
```

With task-specific classes, each metric computes independently and you'll see distinct values. Docs: MetricCollection