1414from trinity .buffer import get_buffer_writer
1515from trinity .buffer .buffer import get_buffer_reader
1616from trinity .common .config import Config
17- from trinity .common .constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME , SyncMethod
17+ from trinity .common .constants import (
18+ EXPLORER_NAME ,
19+ ROLLOUT_WEIGHT_SYNC_GROUP_NAME ,
20+ RunningStatus ,
21+ SyncMethod ,
22+ )
1823from trinity .common .models import create_inference_models
1924from trinity .common .models .utils import (
2025 get_checkpoint_dir_with_step_num ,
@@ -50,7 +55,7 @@ def __init__(self, config: Config):
5055 self .monitor = MONITOR .get (self .config .monitor .monitor_type )(
5156 project = self .config .project ,
5257 name = self .config .name ,
53- role = "explorer" ,
58+ role = EXPLORER_NAME ,
5459 config = config ,
5560 )
5661 self .batch_size = config .buffer .batch_size
@@ -69,6 +74,7 @@ def __init__(self, config: Config):
6974 self .state_dict = {}
7075 else : # nccl mode
7176 self .state_dict_meta = []
77+ self .status = RunningStatus .RUNNING
7278 self .logger .info ("Finished initializing Explorer." )
7379
7480 async def setup_weight_sync_group (
@@ -162,35 +168,44 @@ async def get_weight(self, name: str) -> torch.Tensor:
162168 """Get the weight of the loaded model (For checkpoint weights update)."""
163169 return self .state_dict [name ]
164170
async def explore(self) -> str:
    """Run the exploration loop until tasks run out or a step fails.

    Each iteration calls `explore_step`. When it reports that exploration
    should stop, the loop exits immediately, before any weight sync or
    evaluation for that iteration. Otherwise weights are synchronized when
    `need_sync()` says so, and an evaluation is run whenever the step
    counter is a multiple of `config.explorer.eval_interval`.

    Returns:
        `EXPLORER_NAME`, both on normal completion and after a logged error.
    """
    # Local import so this block does not depend on the file's import section.
    import traceback

    while True:
        try:
            should_continue = self.explore_step()
            if not should_continue:
                break
            if self.need_sync():
                # Drain in-flight workflows before swapping model weights.
                self.wait_for_workflow_done()
                await self.sync_weight()
            if self.explore_step_num % self.config.explorer.eval_interval == 0:
                self.wait_for_workflow_done()
                self.eval()
        except Exception as e:
            # Keep the traceback: the message alone rarely identifies the
            # failing call inside a step / sync / eval.
            # NOTE(review): this path exits without the cleanup that
            # `explore_step` performs on StopIteration - confirm consumers
            # of the experience buffer cannot be left waiting.
            self.logger.error(f"Error in Explorer: {e}\n{traceback.format_exc()}")
            break
    self.logger.info("--------------------\n> Explorer finished.\n--------------------")
    return EXPLORER_NAME
181188
def explore_step(self) -> bool:
    """Submit one batch of tasks to the runner pool.

    The algorithm config is looked up for the *next* step number; the step
    counter itself is only advanced once the step has actually been
    accounted for (skipped as warmup, or submitted to the runner pool).

    Returns:
        True if exploration should continue, False once the taskset is
        exhausted.
    """
    upcoming_step = self.explore_step_num + 1
    algo_config = self.algorithm_manager.get_current_algorithm_config(upcoming_step)
    if algo_config.algorithm_type == "sft":
        # Warmup step: nothing to explore, just advance the counter.
        self.explore_step_num += 1
        return True

    try:
        tasks = self.taskset.read()
    except StopIteration:
        self.logger.warning("No more tasks to explore. Stop exploring.")
        # Persist progress using the number of steps completed so far.
        finished_steps = self.explore_step_num
        self.cache.save_explorer(
            current_step=finished_steps,
            current_task_index=finished_steps * self.config.buffer.batch_size,
        )
        self.status = RunningStatus.STOPPED
        # Let outstanding workflows finish before closing the buffer.
        self.wait_for_workflow_done()
        self.experience_buffer.finish()
        return False
    else:
        self.runner_pool.run_tasks(tasks)
        self.explore_step_num += 1
        return True
195210
196211 def need_sync (self ) -> bool :
@@ -278,20 +293,25 @@ def wait_for_workflow_done(self) -> None:
278293 if not status .ok :
279294 self .logger .error (f"Error when running task: { status .message } " )
280295 # submit another task to replace the failed task
281- self .runner_pool .run_tasks (self .taskset .read (batch_size = 1 ))
296+ try :
297+ tasks = self .taskset .read (batch_size = 1 )
298+ except StopIteration :
299+ self .logger .warning ("No more tasks in taskset. Stop retrying." )
300+ return
301+ self .runner_pool .run_tasks (tasks )
282302 else :
283303 for metric_name , metric_value in status .metric .items ():
284304 all_metrics [metric_name ].append (metric_value )
285305 # calculate metrics
286306 log_metrics = self .monitor .calculate_metrics (all_metrics , prefix = "rollout" ) # type: ignore
287307 self .monitor .log (log_metrics , step = self .explore_step_num )
288-
289308 self .logger .info (f"Explore step { self .explore_step_num } finished." )
290309
291310 async def sync_weight (self ) -> None :
292311 """Synchronize model weights."""
293312 # call this method before training start to load the latest model weights
294- self .logger .info (f"Explorer synchronizing weights at step { self .explore_step_num } ." )
313+ self .logger .info (f"Explorer sync weights at step { self .explore_step_num } ." )
314+ self .status = RunningStatus .WAITING_SYNC
295315 if self .use_checkpoint_weights_update :
296316 await self ._checkpoint_weights_update ()
297317 else : # nccl weights update
@@ -301,7 +321,11 @@ async def sync_weight(self) -> None:
301321 current_step = self .explore_step_num ,
302322 current_task_index = self .explore_step_num * self .config .buffer .batch_size ,
303323 )
304- self .logger .info (f"Explorer synchronizing at step { self .explore_step_num } finished" )
324+ self .status = RunningStatus .RUNNING
325+ self .logger .info (f"Explorer sync at step { self .explore_step_num } finished" )
326+
async def running_status(self) -> RunningStatus:
    """Return the explorer's current `RunningStatus` as stored in `self.status`."""
    current_status = self.status
    return current_status
305329
306330 def flush_log (self , step : int ) -> None :
307331 """Flush the log of the current step."""
0 commit comments