@@ -21,6 +21,9 @@ class WandbSummaryWriter(SummaryWriter):
2121 def __init__ (self , log_dir : str , flush_secs : int , cfg ):
2222 super ().__init__ (log_dir , flush_secs )
2323
24+ # Get the run name
25+ run_name = os .path .split (log_dir )[- 1 ]
26+
2427 try :
2528 project = cfg ["wandb_project" ]
2629 except KeyError :
@@ -29,35 +32,27 @@ def __init__(self, log_dir: str, flush_secs: int, cfg):
2932 try :
3033 entity = os .environ ["WANDB_USERNAME" ]
3134 except KeyError :
32- raise KeyError (
33- "Wandb username not found. Please run or add to ~/.bashrc: export WANDB_USERNAME=YOUR_USERNAME"
34- )
35+ entity = None
3536
36- wandb .init (project = project , entity = entity )
37+ # Initialize wandb
38+ wandb .init (project = project , entity = entity , name = run_name )
3739
38- # Change generated name to project-number format
39- wandb .run . name = project + wandb . run . name . split ( "-" )[ - 1 ]
40+ # Add log directory to wandb
41+ wandb .config . update ({ "log_dir" : log_dir })
4042
4143 self .name_map = {
4244 "Train/mean_reward/time" : "Train/mean_reward_time" ,
4345 "Train/mean_episode_length/time" : "Train/mean_episode_length_time" ,
4446 }
4547
46- run_name = os .path .split (log_dir )[- 1 ]
47-
48- wandb .log ({"log_dir" : run_name })
49-
5048 def store_config (self , env_cfg , runner_cfg , alg_cfg , policy_cfg ):
5149 wandb .config .update ({"runner_cfg" : runner_cfg })
5250 wandb .config .update ({"policy_cfg" : policy_cfg })
5351 wandb .config .update ({"alg_cfg" : alg_cfg })
54- wandb .config .update ({"env_cfg" : asdict (env_cfg )})
55-
56- def _map_path (self , path ):
57- if path in self .name_map :
58- return self .name_map [path ]
59- else :
60- return path
52+ try :
53+ wandb .config .update ({"env_cfg" : env_cfg .to_dict ()})
54+ except Exception :
55+ wandb .config .update ({"env_cfg" : asdict (env_cfg )})
6156
6257 def add_scalar (self , tag , scalar_value , global_step = None , walltime = None , new_style = False ):
6358 super ().add_scalar (
@@ -80,3 +75,13 @@ def save_model(self, model_path, iter):
8075
8176 def save_file (self , path , iter = None ):
8277 wandb .save (path , base_path = os .path .dirname (path ))
78+
79+ """
80+ Private methods.
81+ """
82+
83+ def _map_path (self , path ):
84+ if path in self .name_map :
85+ return self .name_map [path ]
86+ else :
87+ return path
0 commit comments