11"""Logger hooks."""
2+
3+ # Copyright (C) 2023 Intel Corporation
4+ # SPDX-License-Identifier: Apache-2.0
5+
26from collections import defaultdict
37from typing import Any , Dict , Optional
48
@@ -29,6 +33,19 @@ def __repr__(self):
2933 points .append (f"({ x } ,{ y } )" )
3034 return "curve[" + "," .join (points ) + "]"
3135
36+ _TAGS_TO_SKIP = (
37+ "accuracy_top-1" ,
38+ "current_iters" ,
39+ "decode.acc_seg" ,
40+ "decode.loss_ce_ignore" ,
41+ )
42+
43+ _TAGS_TO_RENAME = {
44+ "train/time" : "train/time (sec/iter)" ,
45+ "train/data_time" : "train/data_time (sec/iter)" ,
46+ "val/accuracy" : "val/accuracy (%)" ,
47+ }
48+
3249 def __init__ (
3350 self ,
3451 curves : Optional [Dict [Any , Curve ]] = None ,
@@ -43,12 +60,13 @@ def __init__(
4360 @master_only
4461 def log (self , runner : BaseRunner ):
4562 """Log function for OTXLoggerHook."""
46- tags = self .get_loggable_tags (runner , allow_text = False , tags_to_skip = () )
63+ tags = self .get_loggable_tags (runner , allow_text = False , tags_to_skip = self . _TAGS_TO_SKIP )
4764 if runner .max_epochs is not None :
4865 normalized_iter = self .get_iter (runner ) / runner .max_iters * runner .max_epochs
4966 else :
5067 normalized_iter = self .get_iter (runner )
5168 for tag , value in tags .items ():
69+ tag = self ._TAGS_TO_RENAME .get (tag , tag )
5270 curve = self .curves [tag ]
5371 # Remove duplicates.
5472 if len (curve .x ) > 0 and curve .x [- 1 ] == normalized_iter :
@@ -57,6 +75,11 @@ def log(self, runner: BaseRunner):
5775 curve .x .append (normalized_iter )
5876 curve .y .append (value )
5977
78+ def before_run (self , runner : BaseRunner ):
79+ """Called before_run in OTXLoggerHook."""
80+ super ().before_run (runner )
81+ self .curves .clear ()
82+
6083 def after_train_epoch (self , runner : BaseRunner ):
6184 """Called after_train_epoch in OTXLoggerHook."""
6285 # Iteration counter is increased right after the last iteration in the epoch,
0 commit comments