 from trinity.common.config import Config
 from trinity.common.experience import Experiences
 from trinity.trainer.trainer import TrainEngineWrapper
+from trinity.utils.log import get_logger
 from trinity.utils.monitor import MONITOR


@@ -146,13 +147,14 @@ def __init__(
             ray_worker_group_cls,
         )
         self.init_workers()
-        self.logger = MONITOR.get(global_config.monitor.monitor_type)(
+        self.monitor = MONITOR.get(global_config.monitor.monitor_type)(
             project=config.trainer.project_name,
             name=config.trainer.experiment_name,
             role=global_config.trainer.name,
             config=global_config,
         )
         self.reset_experiences_example_table()
+        self.logger = get_logger(__name__)

     def _validate_config(self):  # TODO
         algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type)
@@ -276,7 +278,7 @@ def prepare(self):
         if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
             val_metrics = self._validate()
             pprint(f"Initial validation metrics: {val_metrics}")
-            self.logger.log(data=val_metrics, step=self.global_steps)
+            self.monitor.log(data=val_metrics, step=self.global_steps)
             if self.config.trainer.get("val_only", False):
                 return

@@ -286,6 +288,7 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
         self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize

     def train_step(self) -> bool:  # noqa C901
+        self.logger.info(f"Training at step {self.global_steps + 1} started.")
         metrics = {}
         try:
             batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1)
@@ -294,6 +297,7 @@ def train_step(self) -> bool: # noqa C901
             print("No more data to train. Stop training.")
             return False
         self.global_steps += 1
+        self.logger.info(f"Sampling at step {self.global_steps} done.")
         timing_raw = {}
         algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps)
         algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type)
@@ -356,8 +360,10 @@ def train_step(self) -> bool: # noqa C901
                 self.config.trainer.save_freq > 0
                 and self.global_steps % self.config.trainer.save_freq == 0
             ):
+                self.logger.info(f"Saving at step {self.global_steps}.")
                 with _timer("save_checkpoint", timing_raw):
                     self._save_checkpoint()
+                self.logger.info(f"Saved at step {self.global_steps}.")

         # collect metrics
         if self.algorithm.use_advantage:  # TODO
@@ -372,16 +378,19 @@ def train_step(self) -> bool: # noqa C901
             self._log_experiences(exp_samples)

         # TODO: make a canonical logger that supports various backend
-        self.logger.log(data=metrics, step=self.global_steps)
+        self.monitor.log(data=metrics, step=self.global_steps)

         train_status = self.global_steps < self.total_training_steps
         if not train_status or self.algorithm_manager.need_save(self.global_steps):
             if (
                 self.config.trainer.save_freq == 0
                 or self.global_steps % self.config.trainer.save_freq != 0
             ):
+                self.logger.info(f"Saving at step {self.global_steps}.")
                 with _timer("save_checkpoint", timing_raw):
                     self._save_checkpoint()
+                self.logger.info(f"Saved at step {self.global_steps}.")
+        self.logger.info(f"Training at step {self.global_steps} finished.")
         return train_status

     def _log_single_experience(
@@ -412,7 +421,7 @@ def _log_single_experience(
     def _log_experiences(self, samples: List[Dict]) -> None:
         self.sample_exps_to_log.extend(samples)
         if self.global_steps % self.config.trainer.sync_freq == 0:
-            self.logger.log_table(
+            self.monitor.log_table(
                 "rollout_examples", pd.DataFrame(self.sample_exps_to_log), self.global_steps
             )
             self.reset_experiences_example_table()
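Context for the rename: before this commit, `self.logger` held the metrics monitor resolved from the `MONITOR` registry; the commit renames it to `self.monitor` and adds a separate text logger via `get_logger`, so `logger.info(...)` progress messages and `monitor.log(...)` metric uploads no longer share one name. A minimal sketch of that split, assuming `get_logger` wraps the standard `logging` module and `MONITOR` behaves like a name-to-class registry; `ConsoleMonitor` and the dict registry below are illustrative stand-ins, not trinity's actual implementation:

```python
import logging


def get_logger(name: str) -> logging.Logger:
    """Hypothetical stand-in for trinity.utils.log.get_logger."""
    logger = logging.getLogger(name)
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)
    return logger


class ConsoleMonitor:
    """Hypothetical metrics backend; trinity's real monitors differ."""

    def __init__(self, project, name, role, config):
        self.project, self.name, self.role, self.config = project, name, role, config

    def log(self, data: dict, step: int) -> None:
        # Forward scalar metrics to the experiment-tracking backend.
        print(f"[{self.role}] step={step} metrics={data}")


# Hypothetical name-to-class registry standing in for trinity.utils.monitor.MONITOR.
MONITOR = {"console": ConsoleMonitor}

# The trainer now keeps the two objects under distinct names:
monitor = MONITOR["console"](project="proj", name="exp", role="trainer", config=None)
logger = get_logger(__name__)

logger.info("Training at step 1 started.")   # human-readable progress logs
monitor.log(data={"reward": 0.5}, step=1)    # structured metrics for dashboards
```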