66import os
77import time
88from collections import defaultdict
9- from typing import List , Optional , Tuple
9+ from typing import List , Optional
1010
1111import torch
1212
@@ -65,6 +65,7 @@ def __init__(self, config: Config):
6565 self .use_checkpoint_weights_update = (
6666 self .config .synchronizer .sync_method == SyncMethod .CHECKPOINT
6767 )
68+ self .eval_explore_step_num = None
6869
6970 # For checkpoint weights update
7071 # Use explorer to periodically load the latest model weights and
@@ -170,17 +171,26 @@ async def get_weight(self, name: str) -> torch.Tensor:
170171 return self .state_dict [name ]
171172
172173 async def explore (self ) -> str :
174+ """
175+ The dreaming loop for explorer and trainer.
176+ | <----------------------------------------- one period ----------------------------------------------> |
177+ explorer | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- eval --> | <-- [idle] --> | <-- sync --> |
178+ trainer | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- [idle] --> | <-- sync --> |
179+ """
180+ self .eval_explore_step_num = None
173181 while True :
174182 try :
183+ if (
184+ self .eval_explore_step_num is None
185+ and self .explore_step_num % self .config .explorer .eval_interval == 0
186+ ):
187+ self .eval_explore_step_num = self .explore_step_num
175188 explore_contionue = self .explore_step ()
176189 if not explore_contionue :
177190 break
178191 if self .need_sync ():
179192 self .wait_for_workflow_done ()
180193 await self .sync_weight ()
181- if self .explore_step_num % self .config .explorer .eval_interval == 0 :
182- self .wait_for_workflow_done ()
183- self .eval ()
184194 except Exception as e :
185195 self .logger .error (f"Error in Explorer: { e } " )
186196 break
@@ -216,16 +226,18 @@ def need_sync(self) -> bool:
216226 self .explore_step_num - self .config .synchronizer .sync_offset
217227 ) % self .config .synchronizer .sync_interval == 0
218228
219- def eval (self ) -> Tuple [ bool , int ] :
229+ def eval (self , eval_explore_step_num : int ) :
220230 """Evaluation on all evaluation data samples."""
221231 if len (self .config .buffer .explorer_input .eval_tasksets ) == 0 :
222232 self .logger .warning ("No evaluation data samples. Skip evaluation." )
223- return True , self . explore_step_num
224- self .logger .info ("Evaluation started." )
233+ return
234+ self .logger .info (f "Evaluation at step { eval_explore_step_num } started." )
225235 all_st = time .time ()
226236 log_metrics = {}
227237 for eval_taskset_config in self .config .buffer .explorer_input .eval_tasksets :
228- self .logger .info (f"Evaluation on { eval_taskset_config .name } started." )
238+ self .logger .info (
239+ f"Evaluation on { eval_taskset_config .name } at step { eval_explore_step_num } started."
240+ )
229241 eval_taskset = get_buffer_reader (eval_taskset_config , self .config .buffer )
230242 st = time .time ()
231243 all_metrics = defaultdict (list )
@@ -254,18 +266,19 @@ def wait():
254266 log_metrics .update (metrics )
255267 log_metrics [f"eval/{ eval_taskset .name } /time" ] = time .time () - st
256268 log_metrics ["eval/total_time" ] = time .time () - all_st
257- self .monitor .log (log_metrics , step = self .explore_step_num ) # type: ignore
258- self .logger .info ("Evaluation finished." )
259- return True , self .explore_step_num
269+ self .monitor .log (log_metrics , step = eval_explore_step_num ) # type: ignore
270+ self .logger .info (f"Evaluation at step { eval_explore_step_num } finished." )
260271
261272 async def benchmark (self ) -> bool :
262273 """Benchmark the model checkpoints."""
263274 # benchmark on the latest checkpoint
264275 if self .config .explorer .eval_on_latest_checkpoint :
265276 await self ._checkpoint_weights_update ()
266- self .eval ()
277+ self .eval (self . explore_step_num )
267278 return True
268279
280+ # benchmark on base model
281+ self .eval (0 )
269282 # benchmark on all checkpoints
270283 all_ckp_steps = sorted (
271284 [
@@ -276,9 +289,8 @@ async def benchmark(self) -> bool:
276289 ]
277290 )
278291 for step_num in all_ckp_steps :
279- self .explore_step_num = step_num
280292 await self ._checkpoint_weights_update (step_num = step_num )
281- self .eval ()
293+ self .eval (step_num )
282294 return True
283295
284296 def wait_for_workflow_done (self ) -> None :
@@ -302,6 +314,10 @@ def wait_for_workflow_done(self) -> None:
302314 else :
303315 for metric_name , metric_value in status .metric .items ():
304316 all_metrics [metric_name ].append (metric_value )
317+ # eval
318+ if self .eval_explore_step_num is not None :
319+ self .eval (self .eval_explore_step_num )
320+ self .eval_explore_step_num = None
305321 # calculate metrics
306322 log_metrics = self .monitor .calculate_metrics (all_metrics , prefix = "rollout" ) # type: ignore
307323 self .monitor .log (log_metrics , step = self .explore_step_num )
0 commit comments