Skip to content

Commit e4e68a6

Browse files
committed
improve: log early stopping best score
1 parent c321b93 commit e4e68a6

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

lantern/early_stopping.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import torch.utils.tensorboard
12
from typing import Optional
23

34
from lantern import FunctionalBase
@@ -6,9 +7,14 @@
67
class EarlyStopping(FunctionalBase):
78
"""Keeps track of the best score and how long ago it was calculated."""
89

10+
tensorboard_logger: torch.utils.tensorboard.SummaryWriter
911
best_score: Optional[float] = None
1012
scores_since_improvement: int = -1
1113

14+
class Config:
15+
arbitrary_types_allowed = True
16+
allow_mutation = False
17+
1218
def score(self, value):
1319
if self.best_score is None or value >= self.best_score:
1420
return self.replace(
@@ -31,5 +37,10 @@ def print(self):
3137
)
3238
return self
3339

34-
def log(self):
35-
pass
40+
def log(self, step):
41+
self.tensorboard_logger.add_scalar(
42+
"best_score",
43+
self.best_score,
44+
step,
45+
)
46+
return self

test/test_mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def test_mnist():
5252
)
5353

5454
tensorboard_logger = torch.utils.tensorboard.SummaryWriter()
55-
early_stopping = lantern.EarlyStopping()
55+
early_stopping = lantern.EarlyStopping(tensorboard_logger=tensorboard_logger)
5656
gradient_metrics = lantern.Metrics(
5757
name="gradient",
5858
tensorboard_logger=tensorboard_logger,
@@ -108,7 +108,7 @@ def test_mnist():
108108
torch.save(optimizer.state_dict(), "optimizer.pt")
109109
elif early_stopping.scores_since_improvement > 5:
110110
break
111-
early_stopping.print()
111+
early_stopping.log(epoch).print()
112112

113113

114114
if __name__ == "__main__":

0 commit comments

Comments
 (0)