@@ -72,12 +72,8 @@ def __init__(self, config: Config):
7272 # For checkpoint weights update
7373 # Use explorer to periodically load the latest model weights and
7474 # broadcast to all rollout models
75- self .model_version = 0
76- if self .use_state_dict_weights_update :
77- self .old_checkpoint = None
78- self .state_dict = {}
79- else : # nccl mode
80- self .state_dict_meta = []
75+ self .model_version = - 1
76+ self .last_sync_successful = True
8177 self .logger .info ("Finished initializing Explorer." )
8278 self .collect_experiences = self .config .explorer .collect_experiences
8379 self .generated_experience_cnt = 0
@@ -102,7 +98,6 @@ async def setup_weight_sync_group(
10298 f"master_address={ master_address } , master_port={ master_port } , "
10399 f"world_size={ world_size } , rank_offset={ base_offset } "
104100 )
105- self .state_dict_meta = state_dict_meta
106101 # TODO: save state_dict in models
107102 refs = [
108103 model .init_process_group .remote (
@@ -130,21 +125,6 @@ def _init_scheduler(self) -> Scheduler:
130125 )
131126 return Scheduler (self .config , self .models , self .auxiliary_models )
132127
133- async def _update_model_weight (self , step_num : int , state_dict : dict ) -> None :
134- # TODO: update model weight
135- self .state_dict = state_dict
136- if self .state_dict_meta is None :
137- update_weight_args_list = []
138- for name , param in state_dict .items ():
139- update_weight_args_list .append ((name , str (param .dtype ), tuple (param .shape )))
140- self .state_dict_meta = update_weight_args_list
141- else :
142- update_weight_args_list = None
143- await asyncio .gather (
144- * [model .sync_model .remote (step_num , update_weight_args_list ) for model in self .models ]
145- )
146- self .state_dict .clear ()
147-
148128 async def _checkpoint_weights_update (self , step_num : Optional [int ] = None ) -> int :
149129 step_num = ray .get (self .synchronizer .set_model_state_dict_with_step_num .remote (step_num ))
150130 await asyncio .gather (* [model .sync_model .remote (step_num ) for model in self .models ])
@@ -156,41 +136,59 @@ async def _state_dict_update(self):
156136 self .synchronizer .wait_new_model_state_dict .remote (self .model_version )
157137 )
158138 if new_version > self .model_version :
159- self .logger .info (f"New model state dict version: { new_version } " )
160- await asyncio .gather (* [model .sync_model .remote (new_version ) for model in self .models ])
139+ if self .model_version != - 1 :
140+ self .logger .info (f"New model state dict version: { new_version } " )
141+ await asyncio .gather (
142+ * [model .sync_model .remote (new_version ) for model in self .models ]
143+ )
161144 self .model_version = new_version
145+ self .last_sync_step = self .explore_step_num
146+ ray .get (
147+ self .synchronizer .set_explorer_status .remote (
148+ RunningStatus .RUNNING , old_status = RunningStatus .WAITING_SYNC
149+ )
150+ )
151+ self .last_sync_successful = True
162152 else :
163153 self .logger .warning (
164154 f"No new model state dict found, current version: { self .model_version } "
165155 )
156+ self .last_sync_successful = False
166157
167158 async def _nccl_weights_update (self ):
168- assert self .state_dict_meta is not None
169159 new_version = ray .get (
170160 self .synchronizer .ready_to_nccl_sync .remote ("explorer" , self .model_version )
171161 )
172162 if new_version is None :
173163 self .logger .info ("Trainer is not ready to sync weight. Skipping sync weight." )
164+ self .last_sync_successful = False
174165 return
175166 self .model_version = new_version
176167 await asyncio .gather (
177- * [model .sync_model .remote (self .explore_step_num ) for model in self .models ]
168+ * [model .sync_model .remote (self .model_version ) for model in self .models ]
178169 )
170+ self .last_sync_step = self .explore_step_num
171+ ray .get (
172+ self .synchronizer .set_explorer_status .remote (
173+ RunningStatus .RUNNING , old_status = RunningStatus .WAITING_SYNC
174+ )
175+ )
176+ self .last_sync_successful = True
179177
180178 async def prepare (self ) -> None :
181179 """Preparation before running."""
180+ if self .experience_buffer :
181+ await self .experience_buffer .acquire ()
182182 futures = [asyncio .create_task (self .scheduler .start ())]
183183 if self .use_state_dict_weights_update :
184184 master_address , master_port = await self .models [0 ].get_available_address .remote ()
185185 futures .append (
186186 asyncio .create_task (self .setup_weight_sync_group (master_address , master_port ))
187187 )
188188 asyncio .gather (* futures , return_exceptions = True )
189- await self .synchronizer .set_explorer_status .remote (RunningStatus .REQUIRE_SYNC )
190- if self .experience_buffer :
191- await self .experience_buffer .acquire ()
192189 if self .config .explorer .eval_on_startup and self .explore_step_num == 0 :
193190 self .eval ()
191+ await self .synchronizer .set_explorer_status .remote (RunningStatus .REQUIRE_SYNC )
194192
195193 async def get_weight (self , name : str ) -> torch .Tensor :
196194 """Get the weight of the loaded model (For checkpoint weights update)."""
@@ -237,7 +235,10 @@ async def explore_step(self) -> bool:
237235 self .logger .warning ("No more tasks to explore. Stop exploring." )
238236 await self .save_checkpoint (sync_weight = False )
239237 await self .synchronizer .set_explorer_status .remote (
240- RunningStatus .STOPPED , old_status = RunningStatus .RUNNING
238+ RunningStatus .STOPPED ,
239+ old_status = RunningStatus .RUNNING
240+ if self .last_sync_successful
241+ else RunningStatus .REQUIRE_SYNC ,
241242 )
242243 await self .experience_buffer .release ()
243244 return False
@@ -249,7 +250,7 @@ def need_sync(self) -> bool:
249250 if self .config .synchronizer .sync_style == SyncStyle .FIXED :
250251 if self .explore_step_num <= self .config .synchronizer .sync_offset :
251252 return False
252- return (
253+ require_sync = (
253254 self .explore_step_num - self .config .synchronizer .sync_offset
254255 ) % self .config .synchronizer .sync_interval == 0
255256 else :
@@ -263,13 +264,13 @@ def need_sync(self) -> bool:
263264 ray .get (self .synchronizer .get_trainer_status .remote ())
264265 == RunningStatus .REQUIRE_SYNC
265266 )
266- if require_sync :
267- ray .get (
268- self .synchronizer .set_explorer_status .remote (
269- RunningStatus .REQUIRE_SYNC , old_status = RunningStatus .RUNNING
270- )
267+ if require_sync and self .last_sync_successful :
268+ ray .get (
269+ self .synchronizer .set_explorer_status .remote (
270+ RunningStatus .REQUIRE_SYNC , old_status = RunningStatus .RUNNING
271271 )
272- return require_sync
272+ )
273+ return require_sync
273274
274275 def need_eval (self ) -> bool :
275276 return self .explore_step_num % self .config .explorer .eval_interval == 0
@@ -338,8 +339,9 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None:
338339 await self ._state_dict_update ()
339340 else : # nccl weights update
340341 await self ._nccl_weights_update ()
341- self .last_sync_step = self .explore_step_num
342- self .logger .info (f"Explorer sync_weights at step { self .explore_step_num } finished" )
342+ self .logger .info (
343+ f"Explorer sync_weights at step { self .explore_step_num } finished, model version = { self .model_version } ."
344+ )
343345
344346 # overlay log and weight sync
345347 await log_task
@@ -354,11 +356,6 @@ async def sync_weight(self) -> None:
354356 """Synchronize model weights."""
355357 # call this method before training start to load the latest model weights
356358 await self .save_checkpoint (sync_weight = True )
357- ray .get (
358- self .synchronizer .set_explorer_status .remote (
359- RunningStatus .RUNNING , old_status = RunningStatus .WAITING_SYNC
360- )
361- )
362359
363360 async def _finish_steps (self , start_step : int , end_step : int , model_version : int ) -> None :
364361 for step in range (start_step , end_step + 1 ):
0 commit comments