@@ -56,6 +56,7 @@ def __init__(self, config: Config) -> None:
5656 )
5757 self .save_interval = config .trainer .save_interval
5858 self .last_sync_step = None
59+ self .last_sync_time = None
5960 self .total_steps = config .trainer .total_steps or float ("inf" )
6061
6162 async def prepare (self ) -> None :
@@ -68,22 +69,20 @@ async def train(self) -> str:
6869 """Train the model."""
6970 while self .train_step_num < self .total_steps :
7071 try :
71- st = time . time ()
72+ metrics = {}
7273 # sample may be blocked due to explorer does not generate enough data
7374 self .logger .info (f"Sample data for step { self .train_step_num + 1 } started." )
7475 sample_task = asyncio .create_task (self ._sample_data ())
7576 while not sample_task .done ():
7677 # sync weight to make sure the explorer can continue to explore and generate enough data
7778 if await self .need_sync ():
78- # Currently, we do not record the metrics of sync_weight here
79- await self .sync_weight ()
79+ metrics .update (await self .sync_weight ())
8080 await asyncio .sleep (1 )
81- exps , metrics , repr_samples = await sample_task
81+ exps , sample_metrics , repr_samples = await sample_task
82+ metrics .update (sample_metrics )
8283 self .logger .info (f"Sample data for step { self .train_step_num + 1 } finished." )
8384 metrics .update (await self .train_step (exps ))
8485 if await self .need_sync ():
85- # Record the time: sample_experience + train_step (>=1)
86- metrics .update ({"time/trainer_sync_interval" : time .time () - st })
8786 metrics .update (await self .sync_weight ())
8887 if self .need_save ():
8988 metrics .update (self .save_checkpoint ())
@@ -126,7 +125,7 @@ async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]:
126125 List[Dict]: A list of representative samples for logging.
127126 """
128127 batch , metrics , repr_samples = await self .sample_strategy .sample (self .train_step_num + 1 )
129- metrics ["sample/task_count" ] = len (set (eid .task for eid in batch .eids ))
128+ metrics ["sample/task_count" ] = len (set (eid .tid for eid in batch .eids ))
130129 return batch , metrics , repr_samples
131130
132131 async def need_sync (self ) -> bool :
@@ -155,6 +154,8 @@ async def sync_weight(self) -> Dict:
155154 """Sync the model weight."""
156155 self .logger .info (f"Trainer sync_weights at step { self .train_step_num } started." )
157156 metrics = {}
157+ if self .last_sync_time is not None :
158+ metrics ["time/trainer_sync_interval" ] = time .time () - self .last_sync_time
158159 with Timer (metrics , "time/sync_weight" ):
159160 if self .config .synchronizer .sync_method == SyncMethod .NCCL :
160161 result = await self .synchronizer .ready_to_nccl_sync .remote (
@@ -170,6 +171,7 @@ async def sync_weight(self) -> Dict:
170171 elif self .config .synchronizer .sync_method == SyncMethod .MEMORY :
171172 self .engine .upload_state_dict ()
172173 self .last_sync_step = self .train_step_num
174+ self .last_sync_time = time .time ()
173175 await self .synchronizer .set_trainer_status .remote (RunningStatus .RUNNING )
174176 self .logger .info (f"Trainer sync_weights at step { self .train_step_num } finished." )
175177 return metrics
0 commit comments