11"""Monitor"""
2+
23import os
4+ from abc import ABC , abstractmethod
35from typing import List , Optional , Union
46
57import numpy as np
810from torch .utils .tensorboard import SummaryWriter
911
1012from trinity .common .config import Config
11- from trinity .common .constants import MonitorType
1213from trinity .utils .log import get_logger
14+ from trinity .utils .registry import Registry
15+
16+ MONITOR = Registry ("monitor" )
1317
1418
15- class Monitor :
19+ class Monitor ( ABC ) :
1620 """Monitor"""
1721
1822 def __init__ (
@@ -22,15 +26,25 @@ def __init__(
2226 role : str ,
2327 config : Config = None , # pass the global Config for recording
2428 ) -> None :
25- if config .monitor .monitor_type == MonitorType .WANDB :
26- self .logger = WandbLogger (project , name , role , config )
27- elif config .monitor .monitor_type == MonitorType .TENSORBOARD :
28- self .logger = TensorboardLogger (project , name , role , config )
29- else :
30- raise ValueError (f"Unknown monitor type: { config .monitor .monitor_type } " )
29+ self .project = project
30+ self .name = name
31+ self .role = role
32+ self .config = config
3133
@abstractmethod
def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
    """Record a table of experiences with the monitoring backend.

    Abstract hook: concrete monitors must override it. The base
    implementation does nothing.

    Args:
        table_name: Name under which the table is recorded.
        experiences_table: The table contents as a pandas DataFrame.
        step: Step index the table is associated with.
    """
37+
@abstractmethod
def log(self, data: dict, step: int, commit: bool = False) -> None:
    """Record a dictionary of metrics for a given step.

    Abstract hook: concrete monitors must override it. The base
    implementation does nothing.

    Args:
        data: Mapping of metric names to values.
        step: Step index the metrics belong to.
        commit: Backend-specific flag, ``False`` by default.
            NOTE(review): presumably forwarded to backends that buffer
            metrics until committed -- confirm against the subclasses.
    """
41+
@abstractmethod
def close(self) -> None:
    """Shut the monitor down and release its backend resources.

    Abstract hook: concrete monitors must override it. It is also the
    method invoked by the finalizer (``__del__``) of this class.
    """
45+
def __del__(self) -> None:
    """Best-effort finalizer: close the monitor when it is garbage-collected.

    Delegates to ``close()``. Python still runs ``__del__`` on an object
    whose ``__init__`` raised part-way through, so the instance may be
    only partially constructed when this is called.
    """
    try:
        self.close()
    except AttributeError:
        # The close() implementations in this file dereference
        # ``self.logger``. If construction failed before that attribute
        # was assigned there is nothing to release, so skip quietly
        # instead of emitting an "Exception ignored in __del__" warning.
        pass
3448
3549 def calculate_metrics (
3650 self , data : dict [str , Union [List [float ], float ]], prefix : Optional [str ] = None
@@ -51,15 +65,9 @@ def calculate_metrics(
5165 metrics [key ] = val
5266 return metrics
5367
54- def log (self , data : dict , step : int , commit : bool = False ) -> None :
55- """Log metrics."""
56- self .logger .log (data , step = step , commit = commit )
57-
58- def close (self ) -> None :
59- self .logger .close ()
60-
6168
62- class TensorboardLogger :
69+ @MONITOR .register_module ("tensorboard" )
70+ class TensorboardMonitor (Monitor ):
6371 def __init__ (self , project : str , name : str , role : str , config : Config = None ) -> None :
6472 self .tensorboard_dir = os .path .join (config .monitor .cache_dir , "tensorboard" )
6573 os .makedirs (self .tensorboard_dir , exist_ok = True )
@@ -77,11 +85,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
def close(self) -> None:
    """Close the writer stored in ``self.logger``.

    NOTE(review): ``self.logger`` is presumably the TensorBoard
    ``SummaryWriter`` created in ``__init__`` -- that assignment is not
    visible in this excerpt, confirm.
    """
    writer = self.logger
    writer.close()
7987
80- def __del__ (self ) -> None :
81- self .logger .close ()
82-
8388
84- class WandbLogger :
89+ @MONITOR .register_module ("wandb" )
90+ class WandbMonitor (Monitor ):
8591 def __init__ (self , project : str , name : str , role : str , config : Config = None ) -> None :
8692 self .logger = wandb .init (
8793 project = project ,
@@ -104,6 +110,3 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:
104110
def close(self) -> None:
    """Finish the run held in ``self.logger``.

    ``self.logger`` is the object returned by ``wandb.init`` in
    ``__init__``; calling its ``finish()`` ends that run.
    """
    run = self.logger
    run.finish()
107-
108- def __del__ (self ) -> None :
109- self .logger .finish ()
0 commit comments