File tree Expand file tree Collapse file tree 2 files changed +15
-4
lines changed Expand file tree Collapse file tree 2 files changed +15
-4
lines changed Original file line number Diff line number Diff line change
1
+ import torch .utils .tensorboard
1
2
from typing import Optional
2
3
3
4
from lantern import FunctionalBase
6
7
class EarlyStopping (FunctionalBase ):
7
8
"""Keeps track of the best score and how long ago it was calculated."""
8
9
10
+ tensorboard_logger : torch .utils .tensorboard .SummaryWriter
9
11
best_score : Optional [float ] = None
10
12
scores_since_improvement : int = - 1
11
13
14
+ class Config :
15
+ arbitrary_types_allowed = True
16
+ allow_mutation = False
17
+
12
18
def score (self , value ):
13
19
if self .best_score is None or value >= self .best_score :
14
20
return self .replace (
@@ -31,5 +37,10 @@ def print(self):
31
37
)
32
38
return self
33
39
34
- def log (self ):
35
- pass
40
+ def log (self , step ):
41
+ self .tensorboard_logger .add_scalar (
42
+ "best_score" ,
43
+ self .best_score ,
44
+ step ,
45
+ )
46
+ return self
Original file line number Diff line number Diff line change @@ -52,7 +52,7 @@ def test_mnist():
52
52
)
53
53
54
54
tensorboard_logger = torch .utils .tensorboard .SummaryWriter ()
55
- early_stopping = lantern .EarlyStopping ()
55
+ early_stopping = lantern .EarlyStopping (tensorboard_logger = tensorboard_logger )
56
56
gradient_metrics = lantern .Metrics (
57
57
name = "gradient" ,
58
58
tensorboard_logger = tensorboard_logger ,
@@ -108,7 +108,7 @@ def test_mnist():
108
108
torch .save (optimizer .state_dict (), "optimizer.pt" )
109
109
elif early_stopping .scores_since_improvement > 5 :
110
110
break
111
- early_stopping .print ()
111
+ early_stopping .log ( epoch ). print ()
112
112
113
113
114
114
if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments