11"""Monitor"""
2-
2+ import os
33from typing import Any , List , Optional , Union
44
55import numpy as np
66import pandas as pd
77import wandb
8+ from torch .utils .tensorboard import SummaryWriter
89
10+ from trinity .common .constants import MonitorType
911from trinity .utils .log import get_logger
1012
1113
@@ -19,19 +21,15 @@ def __init__(
1921 role : str ,
2022 config : Any = None ,
2123 ) -> None :
22- self .logger = wandb .init (
23- project = project ,
24- group = name ,
25- name = f"{ name } _{ role } " ,
26- tags = [role ],
27- config = config ,
28- save_code = False ,
29- )
30- self .console_logger = get_logger (__name__ )
24+ if config .monitor .monitor_type == MonitorType .WANDB :
25+ self .logger = WandbLogger (project , name , role , config )
26+ elif config .monitor .monitor_type == MonitorType .TENSORBOARD :
27+ self .logger = TensorboardLogger (project , name , role , config )
28+ else :
29+ raise ValueError (f"Unknown monitor type: { config .monitor .monitor_type } " )
3130
3231 def log_table (self , table_name : str , experiences_table : pd .DataFrame , step : int ):
33- experiences_table = wandb .Table (dataframe = experiences_table )
34- self .log (data = {table_name : experiences_table }, step = step )
32+ self .logger .log_table (table_name , experiences_table , step = step )
3533
3634 def calculate_metrics (
3735 self , data : dict [str , Union [List [float ], float ]], prefix : Optional [str ] = None
@@ -55,6 +53,46 @@ def calculate_metrics(
5553 def log (self , data : dict , step : int , commit : bool = False ) -> None :
5654 """Log metrics."""
5755 self .logger .log (data , step = step , commit = commit )
56+
57+
58+ class TensorboardLogger :
59+ def __init__ (self , project : str , name : str , role : str , config : Any = None ) -> None :
60+ self .tensorboard_dir = os .path .join (config .monitor .job_dir , "tensorboard" )
61+ os .makedirs (self .tensorboard_dir , exist_ok = True )
62+ self .logger = SummaryWriter (self .tensorboard_dir )
63+ self .console_logger = get_logger (__name__ )
64+
65+ def log_table (self , table_name : str , experiences_table : pd .DataFrame , step : int ):
66+ pass
67+
68+ def log (self , data : dict , step : int , commit : bool = False ) -> None :
69+ """Log metrics."""
70+ for key in data :
71+ self .logger .add_scalar (key , data [key ], step )
72+
73+ def __del__ (self ) -> None :
74+ self .logger .close ()
75+
76+
77+ class WandbLogger :
78+ def __init__ (self , project : str , name : str , role : str , config : Any = None ) -> None :
79+ self .logger = wandb .init (
80+ project = project ,
81+ group = name ,
82+ name = f"{ name } _{ role } " ,
83+ tags = [role ],
84+ config = config ,
85+ save_code = False ,
86+ )
87+ self .console_logger = get_logger (__name__ )
88+
89+ def log_table (self , table_name : str , experiences_table : pd .DataFrame , step : int ):
90+ experiences_table = wandb .Table (dataframe = experiences_table )
91+ self .log (data = {table_name : experiences_table }, step = step )
92+
93+ def log (self , data : dict , step : int , commit : bool = False ) -> None :
94+ """Log metrics."""
95+ self .logger .log (data , step = step , commit = commit )
5896 self .console_logger .info (f"Step { step } : { data } " )
5997
6098 def __del__ (self ) -> None :
0 commit comments