Skip to content

Commit ac0b2f0

Browse files
committed
1. Bug fix in trainer_test and explorer_test
2. Fix shutdown in `both`. 3. Refactored the internal status transition logic of `Trainer` and `Explorer` in `Synchronizer`. 4. Avoid duplicate model saving. 5. Bug fix where model was exited before it was saved.
1 parent 65a27b7 commit ac0b2f0

File tree

11 files changed

+167
-135
lines changed

11 files changed

+167
-135
lines changed

tests/common/config_test.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@ def test_load_default_config(self):
3535
)
3636
self.assertEqual(config.model.model_path, config.model.critic_model_path)
3737
self.assertEqual(config.model.model_path, config.explorer.rollout_model.model_path)
38-
self.assertEqual(
39-
config.trainer.trainer_config.trainer.save_freq,
40-
config.synchronizer.sync_interval,
41-
)
4238

4339
def test_all_examples_are_valid(self):
4440
example_dir = os.path.join(os.path.dirname(__file__), "..", "..", "examples")

tests/tools.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
def get_template_config() -> Config:
2020
config_path = os.path.join(os.path.dirname(__file__), "template", "config.yaml")
21-
return load_config(config_path)
21+
config = load_config(config_path)
22+
config.ray_namespace = ray.get_runtime_context().namespace
23+
return config
2224

2325

2426
def get_model_path() -> str:

tests/trainer/trainer_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
)
2323
from trinity.cli.launcher import bench, both, explore, train
2424
from trinity.common.config import Config, StorageConfig
25-
from trinity.common.constants import StorageType, SyncMethod
25+
from trinity.common.constants import StorageType, SyncMethod, SyncStyle
2626
from trinity.common.models.utils import get_checkpoint_dir_with_step_num
2727
from trinity.manager.manager import CacheManager
2828

@@ -99,7 +99,7 @@ def test_trainer(self):
9999
self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0)
100100
self.assertEqual(step_num, 8)
101101
# TODO: Reinit will fail when using v1 engine, find a way to fix it
102-
ray.init(ignore_reinit_error=True)
102+
ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
103103
# test bench mode
104104
self.config.mode = "bench"
105105
self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
@@ -332,6 +332,7 @@ def test_fully_async_mode(self, name, use_priority_queue):
332332
use_priority_queue=use_priority_queue,
333333
)
334334
config.synchronizer.sync_method = SyncMethod.CHECKPOINT
335+
config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER
335336
config.synchronizer.sync_interval = 8
336337
config.monitor.monitor_type = "tensorboard"
337338
trainer_config = deepcopy(config)

trinity/cli/launcher.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,11 @@ def both(config: Config) -> None:
150150
"============================================================"
151151
)
152152
ray.wait(wait_ref, timeout=config.synchronizer.sync_timeout)
153-
explorer.shutdown.remote()
154-
trainer.shutdown.remote()
153+
ray.wait(
154+
[explorer.shutdown.remote(), trainer.shutdown.remote()],
155+
timeout=config.synchronizer.sync_timeout,
156+
num_returns=2,
157+
)
155158

156159

157160
def run(config_path: str, dlc: bool = False, plugin_dir: str = None):

trinity/common/synchronizer.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ class Synchronizer:
3333
def __init__(self, config: Config):
3434
self.logger = get_logger(__name__)
3535
self.config = config
36-
self.trainer_status = RunningStatus.RUNNING
37-
self.explorer_status_counter: Dict[RunningStatus, int] = {}
36+
self.trainer_status = RunningStatus.STOPPED
37+
self.explorer_status_counter: Dict[RunningStatus, int] = defaultdict(lambda: 0)
3838
self._ready_condition = asyncio.Condition()
3939
self.model_state_dict = None
4040
self.model_version = 0
@@ -62,9 +62,11 @@ def set_explorer_status(
6262
assert (
6363
old_status in self.explorer_status_counter
6464
), f"Invalid explorer status {old_status}"
65-
assert old_status != status
65+
assert old_status != status, f"Invalid status change from {old_status} to {status}"
6666
self.explorer_status_counter[old_status] -= 1
67-
assert self.explorer_status_counter[old_status] >= 0
67+
assert (
68+
self.explorer_status_counter[old_status] >= 0
69+
), f"Invalid status count {old_status} (new status {status})"
6870
if status not in self.explorer_status_counter:
6971
self.explorer_status_counter[status] = 0
7072
self.explorer_status_counter[status] += 1
@@ -88,9 +90,10 @@ async def set_model_state_dict_with_step_num(
8890
"""
8991
if world_size is not None: # Used when trainer updates the model
9092
assert step_num is not None
93+
assert self.checkpoint_shard_counter[step_num] < world_size, "World size mismatch!"
9194
self.checkpoint_shard_counter[step_num] += 1
9295
self.logger.info(
93-
f"Synchronizer received checkpoint {self.checkpoint_shard_counter[step_num]} of {world_size} shards"
96+
f"Synchronizer has received {self.checkpoint_shard_counter[step_num]} out of {world_size} shards from the checkpoint {step_num}."
9497
)
9598
if self.checkpoint_shard_counter[step_num] < world_size:
9699
return step_num
@@ -100,11 +103,14 @@ async def set_model_state_dict_with_step_num(
100103
trainer_type=self.config.trainer.trainer_type,
101104
step_num=step_num,
102105
)
103-
model_state_dict = load_state_dict(os.path.join(checkpoint_dir, "actor")) # TODO: to thread
104-
await self.set_model_state_dict(model_state_dict, checkpoint_step_num)
106+
if checkpoint_step_num != self.model_version:
107+
model_state_dict = load_state_dict(
108+
os.path.join(checkpoint_dir, "actor")
109+
) # TODO: to thread
110+
await self.set_model_state_dict(model_state_dict, checkpoint_step_num)
105111
return checkpoint_step_num
106112

107-
async def set_model_state_dict(self, model_state_dict, trainer_step):
113+
async def set_model_state_dict(self, model_state_dict: Union[dict, None], trainer_step: int):
108114
"""
109115
Set the new model state and update the version.
110116
@@ -152,7 +158,7 @@ async def setup_weight_sync_group(
152158
explorer = ray.get_actor(self.config.explorer_name)
153159
await explorer.setup_weight_sync_group.remote(master_address, master_port, state_dict_meta)
154160

155-
async def wait_new_model_state_dict(self, current_version: int) -> int:
161+
async def wait_new_model_state_dict(self, current_version: int, no_wait: bool = False) -> int:
156162
"""
157163
Wait until a new model state is available.
158164
@@ -163,14 +169,21 @@ async def wait_new_model_state_dict(self, current_version: int) -> int:
163169
The new model version after it has been updated.
164170
"""
165171
async with self._ready_condition:
166-
if self.model_version <= current_version:
172+
assert (
173+
self.model_version >= current_version
174+
), f"The model version in Synchronizer ({self.model_version}) should be greater than that in Explorer ({current_version})!"
175+
if self.model_version == current_version:
176+
if not no_wait and self.trainer_status != RunningStatus.STOPPED:
177+
# TODO: explorer needs to support `no_wait`
178+
# TODO: handle timeout
179+
await asyncio.wait_for(
180+
self._ready_condition.wait(),
181+
timeout=self.config.synchronizer.sync_timeout,
182+
)
183+
if self.model_version > current_version:
167184
self.set_explorer_status(
168185
RunningStatus.WAITING_SYNC, old_status=RunningStatus.REQUIRE_SYNC
169186
)
170-
await asyncio.wait_for(
171-
self._ready_condition.wait(),
172-
timeout=self.config.synchronizer.sync_timeout,
173-
)
174187
return self.model_version
175188

176189
async def ready_to_nccl_sync(
@@ -191,6 +204,29 @@ async def ready_to_nccl_sync(
191204
assert (
192205
sum(self.explorer_status_counter.values()) == 1
193206
), "NCCL sync is only supported for one explorer."
207+
208+
def sync_failed():
209+
if module == "explorer":
210+
another_module = "Trainer"
211+
self.set_explorer_status(
212+
RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.WAITING_SYNC
213+
)
214+
else:
215+
another_module = "Explorer"
216+
self.trainer_status = RunningStatus.REQUIRE_SYNC
217+
self.logger.error(
218+
f"{another_module} is not ready for model weight sync in {self.config.synchronizer.sync_timeout} seconds."
219+
)
220+
return None
221+
222+
non_stop_cnt = sum(
223+
value
224+
for key, value in self.explorer_status_counter.items()
225+
if key != RunningStatus.STOPPED
226+
)
227+
if non_stop_cnt == 0:
228+
return sync_failed()
229+
# for status in RunningStatus:
194230
async with self._ready_condition:
195231
try:
196232
if module == "trainer":
@@ -219,11 +255,7 @@ async def ready_to_nccl_sync(
219255
)
220256
return self.model_version
221257
except asyncio.TimeoutError:
222-
another_module = "Trainer" if module == "explorer" else "Explorer"
223-
self.logger.error(
224-
f"{another_module} is not ready for model weight sync in {self.config.synchronizer.sync_timeout} seconds."
225-
)
226-
return None
258+
return sync_failed()
227259

228260
@classmethod
229261
def get_actor(cls, config: Optional[Config] = None, namespace: Optional[str] = None):

trinity/explorer/explorer.py

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,8 @@ def __init__(self, config: Config):
7272
# For checkpoint weights update
7373
# Use explorer to periodically load the latest model weights and
7474
# broadcast to all rollout models
75-
self.model_version = 0
76-
if self.use_state_dict_weights_update:
77-
self.old_checkpoint = None
78-
self.state_dict = {}
79-
else: # nccl mode
80-
self.state_dict_meta = []
75+
self.model_version = -1
76+
self.last_sync_successful = True
8177
self.logger.info("Finished initializing Explorer.")
8278
self.collect_experiences = self.config.explorer.collect_experiences
8379
self.generated_experience_cnt = 0
@@ -102,7 +98,6 @@ async def setup_weight_sync_group(
10298
f"master_address={master_address}, master_port={master_port}, "
10399
f"world_size={world_size}, rank_offset={base_offset}"
104100
)
105-
self.state_dict_meta = state_dict_meta
106101
# TODO: save state_dict in models
107102
refs = [
108103
model.init_process_group.remote(
@@ -130,21 +125,6 @@ def _init_scheduler(self) -> Scheduler:
130125
)
131126
return Scheduler(self.config, self.models, self.auxiliary_models)
132127

133-
async def _update_model_weight(self, step_num: int, state_dict: dict) -> None:
134-
# TODO: update model weight
135-
self.state_dict = state_dict
136-
if self.state_dict_meta is None:
137-
update_weight_args_list = []
138-
for name, param in state_dict.items():
139-
update_weight_args_list.append((name, str(param.dtype), tuple(param.shape)))
140-
self.state_dict_meta = update_weight_args_list
141-
else:
142-
update_weight_args_list = None
143-
await asyncio.gather(
144-
*[model.sync_model.remote(step_num, update_weight_args_list) for model in self.models]
145-
)
146-
self.state_dict.clear()
147-
148128
async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int:
149129
step_num = ray.get(self.synchronizer.set_model_state_dict_with_step_num.remote(step_num))
150130
await asyncio.gather(*[model.sync_model.remote(step_num) for model in self.models])
@@ -156,41 +136,59 @@ async def _state_dict_update(self):
156136
self.synchronizer.wait_new_model_state_dict.remote(self.model_version)
157137
)
158138
if new_version > self.model_version:
159-
self.logger.info(f"New model state dict version: {new_version}")
160-
await asyncio.gather(*[model.sync_model.remote(new_version) for model in self.models])
139+
if self.model_version != -1:
140+
self.logger.info(f"New model state dict version: {new_version}")
141+
await asyncio.gather(
142+
*[model.sync_model.remote(new_version) for model in self.models]
143+
)
161144
self.model_version = new_version
145+
self.last_sync_step = self.explore_step_num
146+
ray.get(
147+
self.synchronizer.set_explorer_status.remote(
148+
RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC
149+
)
150+
)
151+
self.last_sync_successful = True
162152
else:
163153
self.logger.warning(
164154
f"No new model state dict found, current version: {self.model_version}"
165155
)
156+
self.last_sync_successful = False
166157

167158
async def _nccl_weights_update(self):
168-
assert self.state_dict_meta is not None
169159
new_version = ray.get(
170160
self.synchronizer.ready_to_nccl_sync.remote("explorer", self.model_version)
171161
)
172162
if new_version is None:
173163
self.logger.info("Trainer is not ready to sync weight. Skipping sync weight.")
164+
self.last_sync_successful = False
174165
return
175166
self.model_version = new_version
176167
await asyncio.gather(
177-
*[model.sync_model.remote(self.explore_step_num) for model in self.models]
168+
*[model.sync_model.remote(self.model_version) for model in self.models]
178169
)
170+
self.last_sync_step = self.explore_step_num
171+
ray.get(
172+
self.synchronizer.set_explorer_status.remote(
173+
RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC
174+
)
175+
)
176+
self.last_sync_successful = True
179177

180178
async def prepare(self) -> None:
181179
"""Preparation before running."""
180+
if self.experience_buffer:
181+
await self.experience_buffer.acquire()
182182
futures = [asyncio.create_task(self.scheduler.start())]
183183
if self.use_state_dict_weights_update:
184184
master_address, master_port = await self.models[0].get_available_address.remote()
185185
futures.append(
186186
asyncio.create_task(self.setup_weight_sync_group(master_address, master_port))
187187
)
188188
asyncio.gather(*futures, return_exceptions=True)
189-
await self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC)
190-
if self.experience_buffer:
191-
await self.experience_buffer.acquire()
192189
if self.config.explorer.eval_on_startup and self.explore_step_num == 0:
193190
self.eval()
191+
await self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC)
194192

195193
async def get_weight(self, name: str) -> torch.Tensor:
196194
"""Get the weight of the loaded model (For checkpoint weights update)."""
@@ -237,7 +235,10 @@ async def explore_step(self) -> bool:
237235
self.logger.warning("No more tasks to explore. Stop exploring.")
238236
await self.save_checkpoint(sync_weight=False)
239237
await self.synchronizer.set_explorer_status.remote(
240-
RunningStatus.STOPPED, old_status=RunningStatus.RUNNING
238+
RunningStatus.STOPPED,
239+
old_status=RunningStatus.RUNNING
240+
if self.last_sync_successful
241+
else RunningStatus.REQUIRE_SYNC,
241242
)
242243
await self.experience_buffer.release()
243244
return False
@@ -249,7 +250,7 @@ def need_sync(self) -> bool:
249250
if self.config.synchronizer.sync_style == SyncStyle.FIXED:
250251
if self.explore_step_num <= self.config.synchronizer.sync_offset:
251252
return False
252-
return (
253+
require_sync = (
253254
self.explore_step_num - self.config.synchronizer.sync_offset
254255
) % self.config.synchronizer.sync_interval == 0
255256
else:
@@ -263,13 +264,13 @@ def need_sync(self) -> bool:
263264
ray.get(self.synchronizer.get_trainer_status.remote())
264265
== RunningStatus.REQUIRE_SYNC
265266
)
266-
if require_sync:
267-
ray.get(
268-
self.synchronizer.set_explorer_status.remote(
269-
RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING
270-
)
267+
if require_sync and self.last_sync_successful:
268+
ray.get(
269+
self.synchronizer.set_explorer_status.remote(
270+
RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING
271271
)
272-
return require_sync
272+
)
273+
return require_sync
273274

274275
def need_eval(self) -> bool:
275276
return self.explore_step_num % self.config.explorer.eval_interval == 0
@@ -338,8 +339,9 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None:
338339
await self._state_dict_update()
339340
else: # nccl weights update
340341
await self._nccl_weights_update()
341-
self.last_sync_step = self.explore_step_num
342-
self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} finished")
342+
self.logger.info(
343+
f"Explorer sync_weights at step {self.explore_step_num} finished, model version = {self.model_version}."
344+
)
343345

344346
# overlay log and weight sync
345347
await log_task
@@ -354,11 +356,6 @@ async def sync_weight(self) -> None:
354356
"""Synchronize model weights."""
355357
# call this method before training start to load the latest model weights
356358
await self.save_checkpoint(sync_weight=True)
357-
ray.get(
358-
self.synchronizer.set_explorer_status.remote(
359-
RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC
360-
)
361-
)
362359

363360
async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None:
364361
for step in range(start_step, end_step + 1):

trinity/explorer/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def task_done_callback(self, async_task: asyncio.Task):
227227
if async_task.cancelled():
228228
return
229229
elif async_task.exception():
230-
self.logger.error(f"Task {task.task_id} failed: {async_task.exception()}")
230+
self.logger.error(f"Task {task.task.task_id} failed: {async_task.exception()}")
231231
return
232232
else:
233233
status, exps, runner_id = async_task.result()

0 commit comments

Comments
 (0)