From 6b3ad37dcf6323b6ffddf9dc88533b4bd87afa98 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 16 Jul 2025 17:20:14 +0800 Subject: [PATCH 01/16] add synchronize v2 --- trinity/common/constants.py | 7 ++++ trinity/common/synchronizer.py | 77 ++++++++++++++++++++++++++++++++++ trinity/explorer/explorer.py | 46 ++++++++++++-------- trinity/trainer/trainer.py | 29 ++++++++----- 4 files changed, 131 insertions(+), 28 deletions(-) create mode 100644 trinity/common/synchronizer.py diff --git a/trinity/common/constants.py b/trinity/common/constants.py index bac4941453..505eb15875 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -101,6 +101,7 @@ class RunningStatus(Enum): """Running status of explorer and trainer.""" RUNNING = "running" + WANT_SYNC = "want_sync" WAITING_SYNC = "waiting_sync" STOPPED = "stopped" @@ -119,3 +120,9 @@ class OpType(Enum): SUB = "sub" MUL = "mul" DIV = "div" + + +class SyncStyle(CaseInsensitiveEnum): + FIXED = "fixed" + DYNAMIC_BY_TRAINER = "dynamic_by_trainer" + DYNAMIC_BY_EXPLORER = "dynamic_by_explorer" diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py new file mode 100644 index 0000000000..bfb2277cbc --- /dev/null +++ b/trinity/common/synchronizer.py @@ -0,0 +1,77 @@ +""" """ + +import asyncio +from typing import List + +import ray + +from trinity.common.config import Config +from trinity.common.constants import RunningStatus, SyncStyle + + +class Synchronizer: + def __init__(self, config: Config): + self.config = config + self.trainer_status = RunningStatus.RUNNING + self.last_trainer_sync_step = 0 + self.explorer_status = RunningStatus.RUNNING + self.last_explorer_sync_step = 0 + self.ready_count = 0 + self._ready_condition = asyncio.Condition() + + def set_trainer_status(self, status: RunningStatus): + self.trainer_status = status + + def set_explorer_status(self, status: RunningStatus): + self.explorer_status = status + + def get_trainer_status(self) -> RunningStatus: + return self.trainer_status + + def get_explorer_status(self) -> RunningStatus: + return self.explorer_status + + async def setup_weight_sync_group( + self, master_address: str, master_port: int, state_dict_meta: List = None + ): + explorer = ray.get_actor(self.config.explorer_name) + await explorer.setup_weight_sync_group.remote(master_address, master_port, state_dict_meta) + + async def ready_to_sync(self, module: str): + async with self._ready_condition: + try: + if module == "trainer": + self.trainer_status = RunningStatus.WAITING_SYNC + self._ready_condition.notify_all() + if self.explorer_status != RunningStatus.WAITING_SYNC: + await asyncio.wait_for( + self._ready_condition.wait_for( + lambda: self.explorer_status == RunningStatus.WAITING_SYNC, + ), + timeout=self.config.synchronizer.sync_timeout, + ) + elif module == "explorer": + self.explorer_status = RunningStatus.WAITING_SYNC + self._ready_condition.notify_all() + if self.trainer_status != RunningStatus.WAITING_SYNC: + await asyncio.wait_for( + self._ready_condition.wait_for( + lambda: self.trainer_status == RunningStatus.WAITING_SYNC, + ), + timeout=self.config.synchronizer.sync_timeout, + ) + return True + except asyncio.TimeoutError: + another_module = "Trainer" if module == "explorer" else "Explorer" + self.logger.error( + f"{another_module} is not ready for model weight sync in {self.config.synchronizer.sync_timeout} seconds." 
+ ) + return False + + @classmethod + def get_actor(cls, config: Config): + return ( + ray.remote(cls) + .options(name="synchronizer", namespace=config.ray_namespace, get_if_exists=True) + .remote(config) + ) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index c22516be42..d9a77cfa84 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -9,6 +9,7 @@ from collections import deque from typing import List, Optional +import ray import torch from trinity.algorithm.algorithm_manager import AlgorithmManager @@ -19,12 +20,14 @@ ROLLOUT_WEIGHT_SYNC_GROUP_NAME, RunningStatus, SyncMethod, + SyncStyle, ) from trinity.common.models import create_inference_models from trinity.common.models.utils import ( get_checkpoint_dir_with_step_num, load_state_dict, ) +from trinity.common.synchronizer import Synchronizer from trinity.explorer.scheduler import Scheduler from trinity.manager.manager import CacheManager from trinity.utils.log import get_logger @@ -81,6 +84,9 @@ def __init__(self, config: Config): self.logger.info("Finished initializing Explorer.") self._ready_to_sync_condition = asyncio.Condition() + self.synchronizer = Synchronizer.get_actor(config) + self.last_explorer_sync_step = self.explore_step_num + async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None ): @@ -157,19 +163,10 @@ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> in async def _nccl_weights_update(self): assert self.state_dict_meta is not None - async with self._ready_to_sync_condition: - try: - await asyncio.wait_for( - self._ready_to_sync_condition.wait_for( - lambda: self.status == RunningStatus.WAITING_SYNC, - ), - timeout=self.config.synchronizer.sync_timeout, - ) - except asyncio.TimeoutError as e: - self.logger.error( - f"Trainer is not ready for model weight sync in {self.config.synchronizer.sync_timeout} seconds." - ) - raise e + status = ray.get(self.synchronizer.ready_to_sync.remote("explorer")) + if not status: + self.logger.info("Trainer is not ready to sync weight. 
Skipping sync weight.") + return await asyncio.gather( *[model.sync_model.remote(self.explore_step_num) for model in self.models] ) @@ -246,11 +243,23 @@ async def explore_step(self) -> bool: return True def need_sync(self) -> bool: - if self.explore_step_num <= self.config.synchronizer.sync_offset: - return False - return ( - self.explore_step_num - self.config.synchronizer.sync_offset - ) % self.config.synchronizer.sync_interval == 0 + if self.config.synchronizer.sync_style == SyncStyle.FIXED: + if self.explore_step_num <= self.config.synchronizer.sync_offset: + return False + return ( + self.explore_step_num - self.config.synchronizer.sync_offset + ) % self.config.synchronizer.sync_interval == 0 + else: + need_sync = False + if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_EXPLORER: + delta = self.explore_step_num - self.last_explorer_sync_step + if delta >= self.config.synchronizer.sync_interval: + need_sync = True + else: + need_sync = ray.get(self.synchronizer.get_trainer_status == RunningStatus.WANT_SYNC) + if need_sync: + ray.get(self.synchronizer.set_explorer_status.remote(RunningStatus.WANT_SYNC)) + return need_sync def need_eval(self) -> bool: return self.explore_step_num % self.config.explorer.eval_interval == 0 @@ -334,6 +343,7 @@ async def sync_weight(self) -> None: """Synchronize model weights.""" # call this method before training start to load the latest model weights await self.save_checkpoint(sync_weight=True) + self.last_explorer_sync_step = self.explore_step_num async def _log_metrics(self, start_step: int, end_step: int) -> None: for step in range(start_step, end_step + 1): diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 1378449cf2..c472c99823 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -11,7 +11,8 @@ import ray from trinity.common.config import Config -from trinity.common.constants import RunningStatus, SyncMethod +from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle +from trinity.common.synchronizer import Synchronizer from trinity.utils.log import get_logger @@ -22,11 +23,13 @@ def __init__(self, config: Config) -> None: self.config = config self.logger = get_logger(__name__) self.engine = get_trainer_wrapper(config) - self.explorer_ref = None + self.last_trainer_sync_step = 0 + self.synchronizer = Synchronizer.get_actor(config) def prepare(self) -> None: """Prepare the trainer.""" self.engine.prepare() + self.last_trainer_sync_step = self.engine.train_step_num def train(self) -> str: """Train the model.""" @@ -53,7 +56,17 @@ def train_step(self) -> bool: def need_sync(self) -> bool: """Whether to sync the model weight.""" - return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0 + if self.config.synchronizer.sync_style == SyncStyle.FIXED: + return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0 + else: + if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_TRAINER: + delta = self.engine.train_step_num - self.last_trainer_sync_step + if delta >= self.config.synchronizer.sync_interval: + ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.WANT_SYNC)) + return ( + ray.get(self.synchronizer.get_explorer_status.remote()) + == RunningStatus.WAITING_SYNC + ) def sync_weight(self) -> None: """Sync the model weight.""" @@ -61,17 +74,13 @@ def sync_weight(self) -> None: self.logger.info( f"Trainer synchronizing weights at step {self.engine.train_step_num} starting.." 
) - if self.explorer_ref is None: - self.explorer_ref = ray.get_actor(self.config.explorer.name) - explorer_status = ray.get(self.explorer_ref.running_status.remote()) - if explorer_status == RunningStatus.STOPPED: - self.logger.warning("Explorer has already stopped. Skipping sync weight.") - return - ray.get(self.explorer_ref.ready_to_sync.remote()) + assert ray.get(self.synchronizer.ready_to_sync.remote("trainer")) self.engine.sync_weight() self.logger.info( f"Trainer synchronizing weights at step {self.engine.train_step_num} end." ) + self.last_trainer_sync_step = self.engine.train_step_num + ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) def shutdown(self) -> None: # if checkpoint not saved, save the last checkpoint From e2b149fdcf966c15cf7139266bb5d5abf064348d Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 16 Jul 2025 17:23:26 +0800 Subject: [PATCH 02/16] add config --- trinity/common/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trinity/common/config.py b/trinity/common/config.py index 1e0bcc5e9d..d0907cc3e0 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -14,6 +14,7 @@ ReadStrategy, StorageType, SyncMethod, + SyncStyle, TaskType, ) from trinity.utils.log import get_logger @@ -370,6 +371,7 @@ class SynchronizerConfig: """Configs for model weight synchronization.""" sync_method: SyncMethod = SyncMethod.NCCL + sync_style: SyncStyle = SyncStyle.FIXED # sync weights every `sync_interval` steps sync_interval: int = 1 # allow explorer to run `sync_offset` steps before sync From 0b5a8f39a4ab0431a0c8e5202d622f9759349c5c Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 17 Jul 2025 12:11:32 +0800 Subject: [PATCH 03/16] rename `WANT_SYNC` to `REQUIRE_SYNC` --- trinity/common/constants.py | 2 +- trinity/common/synchronizer.py | 2 +- trinity/explorer/explorer.py | 23 ++++++++++------------- trinity/trainer/trainer.py | 9 +++++++-- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/trinity/common/constants.py b/trinity/common/constants.py index 505eb15875..e66077cb1d 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -101,7 +101,7 @@ class RunningStatus(Enum): """Running status of explorer and trainer.""" RUNNING = "running" - WANT_SYNC = "want_sync" + REQUIRE_SYNC = "require_sync" WAITING_SYNC = "waiting_sync" STOPPED = "stopped" diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py index bfb2277cbc..d8a742b60a 100644 --- a/trinity/common/synchronizer.py +++ b/trinity/common/synchronizer.py @@ -6,7 +6,7 @@ import ray from trinity.common.config import Config -from trinity.common.constants import RunningStatus, SyncStyle +from trinity.common.constants import RunningStatus class Synchronizer: diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index d9a77cfa84..786a0a46c2 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -80,7 +80,6 @@ def __init__(self, config: Config): self.state_dict = {} else: # nccl mode self.state_dict_meta = [] - self.status = RunningStatus.RUNNING self.logger.info("Finished initializing Explorer.") self._ready_to_sync_condition = asyncio.Condition() @@ -170,12 +169,6 @@ async def _nccl_weights_update(self): await asyncio.gather( *[model.sync_model.remote(self.explore_step_num) for model in self.models] ) - self.status = RunningStatus.RUNNING - - async def ready_to_sync(self): - async with self._ready_to_sync_condition: - self.status = 
RunningStatus.WAITING_SYNC - self._ready_to_sync_condition.notify_all() async def prepare(self) -> None: """Preparation before running.""" @@ -235,7 +228,6 @@ async def explore_step(self) -> bool: except StopIteration: self.logger.warning("No more tasks to explore. Stop exploring.") await self.save_checkpoint(sync_weight=False) - self.status = RunningStatus.STOPPED await self.experience_buffer.release() return False self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1) @@ -243,6 +235,11 @@ async def explore_step(self) -> bool: return True def need_sync(self) -> bool: + if self.use_checkpoint_weights_update: + pass + # need return in checkpoint mode + + # SyncMethod.NCCL if self.config.synchronizer.sync_style == SyncStyle.FIXED: if self.explore_step_num <= self.config.synchronizer.sync_offset: return False @@ -256,9 +253,11 @@ def need_sync(self) -> bool: if delta >= self.config.synchronizer.sync_interval: need_sync = True else: - need_sync = ray.get(self.synchronizer.get_trainer_status == RunningStatus.WANT_SYNC) + need_sync = ray.get( + self.synchronizer.get_trainer_status == RunningStatus.REQUIRE_SYNC + ) if need_sync: - ray.get(self.synchronizer.set_explorer_status.remote(RunningStatus.WANT_SYNC)) + ray.get(self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC)) return need_sync def need_eval(self) -> bool: @@ -343,6 +342,7 @@ async def sync_weight(self) -> None: """Synchronize model weights.""" # call this method before training start to load the latest model weights await self.save_checkpoint(sync_weight=True) + ray.get(self.synchronizer.set_explorer_status.remote(RunningStatus.RUNNING)) self.last_explorer_sync_step = self.explore_step_num async def _log_metrics(self, start_step: int, end_step: int) -> None: @@ -377,9 +377,6 @@ async def _log_eval_metrics(self, step: Optional[int] = None, prefix: str = "eva metric[f"{prefix}/total_time"] = time.time() - st self.monitor.log(metric, step) - async def running_status(self) -> RunningStatus: - return self.status - async def shutdown(self) -> None: self.monitor.close() await self.scheduler.stop() diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index c472c99823..99ad952220 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -56,13 +56,18 @@ def train_step(self) -> bool: def need_sync(self) -> bool: """Whether to sync the model weight.""" + if self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT: + ray.get() + return False + + # SyncMethod.NCCL if self.config.synchronizer.sync_style == SyncStyle.FIXED: return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0 else: if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_TRAINER: delta = self.engine.train_step_num - self.last_trainer_sync_step if delta >= self.config.synchronizer.sync_interval: - ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.WANT_SYNC)) + ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC)) return ( ray.get(self.synchronizer.get_explorer_status.remote()) == RunningStatus.WAITING_SYNC @@ -80,7 +85,7 @@ def sync_weight(self) -> None: f"Trainer synchronizing weights at step {self.engine.train_step_num} end." 
) self.last_trainer_sync_step = self.engine.train_step_num - ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) + ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) def shutdown(self) -> None: # if checkpoint not saved, save the last checkpoint From 6338b1e222864db79c62594f8a2cb47d61c547db Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Fri, 18 Jul 2025 15:00:31 +0800 Subject: [PATCH 04/16] refactor on checkpoint saving --- pyproject.toml | 2 +- trinity/common/config.py | 2 + trinity/common/constants.py | 1 + trinity/common/models/vllm_worker.py | 24 +- trinity/common/synchronizer.py | 62 ++++- trinity/explorer/explorer.py | 66 ++--- trinity/trainer/trainer.py | 10 +- .../trainer/verl/fsdp_checkpoint_manager.py | 226 ++++++++++++++++++ trinity/trainer/verl/fsdp_workers.py | 14 +- trinity/trainer/verl_trainer.py | 3 + 10 files changed, 344 insertions(+), 66 deletions(-) create mode 100644 trinity/trainer/verl/fsdp_checkpoint_manager.py diff --git a/pyproject.toml b/pyproject.toml index 61059b8c4e..1ec5da88d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ requires-python = ">=3.10" dependencies = [ "verl==0.4.0", "ray[default]>=2.45.0", - "vllm==0.9.1", + "vllm>=0.9.1", "tensordict==0.6.2", "wandb", "omegaconf", diff --git a/trinity/common/config.py b/trinity/common/config.py index d0907cc3e0..45b68549ce 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -383,6 +383,7 @@ class SynchronizerConfig: # ! DO NOT SET, automatically calculated explorer_world_size: Optional[int] = None + ray_namespace: str = "" @dataclass @@ -732,6 +733,7 @@ def check_and_update(self) -> None: # noqa: C901 self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens # check synchronizer + self.synchronizer.ray_namespace = self.ray_namespace self.synchronizer.explorer_world_size = ( self.explorer.rollout_model.engine_num * self.explorer.rollout_model.tensor_parallel_size diff --git a/trinity/common/constants.py b/trinity/common/constants.py index e66077cb1d..d4e0f549ee 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -95,6 +95,7 @@ class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta): NCCL = "nccl" CHECKPOINT = "checkpoint" + STATE_DICT = "state_dict" class RunningStatus(Enum): diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 7509942176..7cfecda76d 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -4,6 +4,7 @@ import torch import torch.distributed +from trinity.common.synchronizer import Synchronizer from trinity.utils.distributed import init_process_group from trinity.utils.log import get_logger @@ -28,7 +29,8 @@ def init_process_group( """Init torch process group for model weights update""" assert torch.distributed.is_initialized(), "default torch process group must be initialized" assert group_name != "", "group name must not be empty" - self.set_state_dict_meta(state_dict_meta) + # self.set_state_dict_meta(state_dict_meta) + self._state_dict_meta = state_dict_meta self._update_with_checkpoint = update_with_checkpoint self._weight_update_rank = torch.distributed.get_rank() + rank_offset logger.info( @@ -51,21 +53,21 @@ def init_process_group( logger.info("vLLM init_process_group finished.") self._explorer_name = explorer_name self._namespace = namespace - self._explorer_actor = None - - def set_state_dict_meta(self, state_dict_meta): - 
self._state_dict_meta = state_dict_meta + self.synchronizer = Synchronizer.get_actor(namespace=self._namespace) def update_weight(self): """Broadcast weight to all vllm workers from source rank 0 (actor model)""" - assert self._state_dict_meta is not None - if self._explorer_actor is None: - self._explorer_actor = ray.get_actor( - name=self._explorer_name, namespace=self._namespace - ) + if self._state_dict_meta is None: + import time + + time.sleep(20) + print(f"Waiting for state dict meta............") + self._state_dict_meta = ray.get(self.synchronizer.get_state_dict_meta.remote()) + if self._weight_update_rank == 0: + state_dict, _ = ray.get(self.synchronizer.get_model_state_dict.remote()) for name, dtype_str, shape in self._state_dict_meta: if self._weight_update_rank == 0: - weight = ray.get(self._explorer_actor.get_weight.remote(name)) + weight = state_dict[name] weight = weight.to(self.device) else: dtype = getattr(torch, dtype_str.split(".")[-1]) diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py index d8a742b60a..7a4e9fbe60 100644 --- a/trinity/common/synchronizer.py +++ b/trinity/common/synchronizer.py @@ -1,12 +1,17 @@ """ """ import asyncio -from typing import List +import os +from typing import List, Optional import ray from trinity.common.config import Config -from trinity.common.constants import RunningStatus +from trinity.common.constants import RunningStatus, SyncMethod +from trinity.common.models.utils import ( + get_checkpoint_dir_with_step_num, + load_state_dict, +) class Synchronizer: @@ -18,6 +23,8 @@ def __init__(self, config: Config): self.last_explorer_sync_step = 0 self.ready_count = 0 self._ready_condition = asyncio.Condition() + self.model_state_dict = None + self.state_dict_version = 0 def set_trainer_status(self, status: RunningStatus): self.trainer_status = status @@ -31,12 +38,49 @@ def get_trainer_status(self) -> RunningStatus: def get_explorer_status(self) -> RunningStatus: return self.explorer_status + async def set_model_state_dict_with_step_num(self, step_num: Optional[int] = None) -> int: + checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num( + checkpoint_root_path=self.config.checkpoint_job_dir, + trainer_type=self.config.trainer.trainer_type, + step_num=step_num, + ) + model_state_dict = load_state_dict(os.path.join(checkpoint_dir, "actor")) # TODO: async + await self.set_model_state_dict(model_state_dict, checkpoint_step_num) + return checkpoint_step_num + + async def set_model_state_dict(self, model_state_dict, trainer_step): + self.model_state_dict = model_state_dict + async with self._ready_condition: + self.state_dict_version = trainer_step + self._ready_condition.notify_all() + + def get_model_state_dict(self): + return self.model_state_dict, self.state_dict_version + + def get_state_dict_meta(self): + if self.model_state_dict is None: + return None + update_weight_args_list = [] + for name, param in self.model_state_dict.items(): + update_weight_args_list.append((name, str(param.dtype), tuple(param.shape))) + return update_weight_args_list + async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None ): explorer = ray.get_actor(self.config.explorer_name) await explorer.setup_weight_sync_group.remote(master_address, master_port, state_dict_meta) + async def wait_new_model_state_dict(self, current_version: int) -> int: + # wait for the new model state dict; return new version + # pass + async with self._ready_condition: + if self.state_dict_version <= 
current_version: + await self._ready_condition.wait_for( + lambda: self.state_dict_version > current_version + ) + return self.state_dict_version + async def ready_to_sync(self, module: str): async with self._ready_condition: try: @@ -69,9 +113,11 @@ async def ready_to_sync(self, module: str): return False @classmethod - def get_actor(cls, config: Config): - return ( - ray.remote(cls) - .options(name="synchronizer", namespace=config.ray_namespace, get_if_exists=True) - .remote(config) - ) + def get_actor(cls, config: Optional[Config] = None, namespace: Optional[str] = None): + if config is not None: + return ( + ray.remote(cls) + .options(name="synchronizer", namespace=config.ray_namespace, get_if_exists=True) + .remote(config) + ) + return ray.get_actor("synchronizer", namespace=namespace) diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 786a0a46c2..82cb0a226c 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -43,6 +43,7 @@ def __init__(self, config: Config): explorer_meta = self.cache.load_explorer() self.explore_step_num = explorer_meta.get("latest_iteration", 0) self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1 + self.synchronizer = Synchronizer.get_actor(config) self.config = config self.algorithm_manager = AlgorithmManager(config) self.models, self.auxiliary_models = create_inference_models(config) @@ -67,15 +68,14 @@ def __init__(self, config: Config): self.update_interval = ( self.config.synchronizer.sync_interval * self.config.buffer.batch_size ) - self.use_checkpoint_weights_update = ( - self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT - ) + self.use_state_dict_weights_update = self.config.synchronizer.sync_method != SyncMethod.NCCL self.pending_eval_tasks = deque() # For checkpoint weights update # Use explorer to periodically load the latest model weights and # boradcast to all rollout models - if self.use_checkpoint_weights_update: + self.state_dict_version = 0 + if self.use_state_dict_weights_update: self.old_checkpoint = None self.state_dict = {} else: # nccl mode @@ -83,14 +83,11 @@ def __init__(self, config: Config): self.logger.info("Finished initializing Explorer.") self._ready_to_sync_condition = asyncio.Condition() - self.synchronizer = Synchronizer.get_actor(config) - self.last_explorer_sync_step = self.explore_step_num - async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None ): # In checkpoint mode, we use explorer to store the model weights which has no rank - base_offset = 0 if self.use_checkpoint_weights_update else 1 + base_offset = 0 if self.use_state_dict_weights_update else 1 world_size = ( len(self.models) * self.config.explorer.rollout_model.tensor_parallel_size + base_offset ) @@ -111,7 +108,7 @@ async def setup_weight_sync_group( group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, explorer_name=self.config.explorer.name, timeout=self.config.synchronizer.sync_timeout, - update_with_checkpoint=self.use_checkpoint_weights_update, + update_with_checkpoint=self.use_state_dict_weights_update, state_dict_meta=state_dict_meta, ) for i, model in enumerate(self.models) @@ -143,22 +140,17 @@ async def _update_model_weight(self, step_num: int, state_dict: dict) -> None: self.state_dict.clear() async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int: - # TODO: support more checkpoint types - try: - checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num( - 
checkpoint_root_path=self.config.checkpoint_job_dir, - trainer_type=self.config.trainer.trainer_type, - step_num=step_num, - ) - if checkpoint_dir == self.old_checkpoint: - return checkpoint_step_num - model_weights = load_state_dict(os.path.join(checkpoint_dir, "actor")) - await self._update_model_weight(checkpoint_step_num, model_weights) - self.old_checkpoint = checkpoint_dir - return checkpoint_step_num - except Exception as e: - self.logger.warning(f"Fail to load checkpoint: {e}") - return 0 + step_num = ray.get(self.synchronizer.set_model_state_dict_with_step_num.remote(step_num)) + await asyncio.gather(*[model.sync_model.remote(step_num) for model in self.models]) + return step_num + + async def _state_dict_update(self): + self.state_dict_version = ray.get( + self.synchronizer.wait_new_model_state_dict.remote(self.state_dict_version) + ) + await asyncio.gather( + *[model.sync_model.remote(self.state_dict_version) for model in self.models] + ) async def _nccl_weights_update(self): assert self.state_dict_meta is not None @@ -173,7 +165,7 @@ async def _nccl_weights_update(self): async def prepare(self) -> None: """Preparation before running.""" futures = [asyncio.create_task(self.scheduler.start())] - if self.use_checkpoint_weights_update: + if self.use_state_dict_weights_update: master_address, master_port = await self.models[0].get_available_address.remote() futures.append( asyncio.create_task(self.setup_weight_sync_group(master_address, master_port)) @@ -235,11 +227,6 @@ async def explore_step(self) -> bool: return True def need_sync(self) -> bool: - if self.use_checkpoint_weights_update: - pass - # need return in checkpoint mode - - # SyncMethod.NCCL if self.config.synchronizer.sync_style == SyncStyle.FIXED: if self.explore_step_num <= self.config.synchronizer.sync_offset: return False @@ -247,18 +234,18 @@ def need_sync(self) -> bool: self.explore_step_num - self.config.synchronizer.sync_offset ) % self.config.synchronizer.sync_interval == 0 else: - need_sync = False + require_sync = False if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_EXPLORER: - delta = self.explore_step_num - self.last_explorer_sync_step + delta = self.explore_step_num - self.last_sync_step if delta >= self.config.synchronizer.sync_interval: - need_sync = True + require_sync = True else: - need_sync = ray.get( + require_sync = ray.get( self.synchronizer.get_trainer_status == RunningStatus.REQUIRE_SYNC ) - if need_sync: + if require_sync: ray.get(self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC)) - return need_sync + return require_sync def need_eval(self) -> bool: return self.explore_step_num % self.config.explorer.eval_interval == 0 @@ -322,8 +309,8 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None: if sync_weight: # sync weights self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.") - if self.use_checkpoint_weights_update: - await self._checkpoint_weights_update() + if self.use_state_dict_weights_update: + await self._state_dict_update() else: # nccl weights update await self._nccl_weights_update() self.last_sync_step = self.explore_step_num @@ -343,7 +330,6 @@ async def sync_weight(self) -> None: # call this method before training start to load the latest model weights await self.save_checkpoint(sync_weight=True) ray.get(self.synchronizer.set_explorer_status.remote(RunningStatus.RUNNING)) - self.last_explorer_sync_step = self.explore_step_num async def _log_metrics(self, start_step: int, end_step: int) -> None: for step 
in range(start_step, end_step + 1): diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 99ad952220..bc7529fc93 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -22,9 +22,9 @@ class Trainer: def __init__(self, config: Config) -> None: self.config = config self.logger = get_logger(__name__) + self.synchronizer = Synchronizer.get_actor(config) self.engine = get_trainer_wrapper(config) self.last_trainer_sync_step = 0 - self.synchronizer = Synchronizer.get_actor(config) def prepare(self) -> None: """Prepare the trainer.""" @@ -85,6 +85,10 @@ def sync_weight(self) -> None: f"Trainer synchronizing weights at step {self.engine.train_step_num} end." ) self.last_trainer_sync_step = self.engine.train_step_num + elif self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT: + pass + elif self.config.synchronizer.sync_method == SyncMethod.STATE_DICT: + self.engine.upload_state_dict() ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) def shutdown(self) -> None: @@ -120,6 +124,10 @@ def save_checkpoint(self) -> None: def sync_weight(self) -> None: """Sync the model weight.""" + @abstractmethod + def upload_state_dict(self) -> None: + """Upload the state dict to Synchronizer.""" + @abstractmethod def shutdown(self) -> None: """Shutdown the engine.""" diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py new file mode 100644 index 0000000000..05c6aff950 --- /dev/null +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -0,0 +1,226 @@ +import os +import threading +import warnings +from typing import Optional + +import ray +import torch +from accelerate import init_empty_weights +from torch.distributed.fsdp import FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ( + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + StateDictType, +) +from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin +from verl.utils.checkpoint.fsdp_checkpoint_manager import ( + FSDPCheckpointManager as OldFSDPCheckpointManager, +) +from verl.utils.device import is_cuda_available +from verl.utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx + +from trinity.common.synchronizer import Synchronizer + + +class FSDPCheckpointManager(OldFSDPCheckpointManager): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + config = kwargs.pop("config", None) + if config is not None: + self.synchronizer = Synchronizer.get_actor(namespace=config.ray_namespace) + else: + self.synchronizer = None + self._model_state_dict_thread = None + self._optimizer_state_dict_thread = None + self._extra_state_dict_thread = None + self._save_model_thread = None + + def upload_state_dict(self, trainer_step: int): + assert self.synchronizer is not None + state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with get_fsdp_state_ctx(self.model, StateDictType.FULL_STATE_DICT, state_dict_config, None): + state_dict = self.model.state_dict() + if self.rank == 0: + ray.get(self.synchronizer.set_model_state_dict.remote(state_dict, trainer_step)) + + def save_checkpoint( + self, + local_path: str, + hdfs_path: str = None, + global_step: int = 0, + max_ckpt_to_keep: Optional[int] = None, + model_state_dict_only: bool = False, + ): + """ + modified from verl.utils.checkpoint.fsdp_checkpoint_manager.py:save_checkpoint + """ + if local_path is None: + return + + # record the previous 
global step + self.previous_global_step = global_step + + # remove previous local_path + if ( + max_ckpt_to_keep + and isinstance(max_ckpt_to_keep, int) + and max_ckpt_to_keep > 0 + and len(self.previous_saved_paths) >= max_ckpt_to_keep + ): + keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 + self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) + self.previous_saved_paths = self.previous_saved_paths[keep_start:] + + local_path = self.local_mkdir(local_path) + torch.distributed.barrier() + + # every rank will save its own model and optim shard + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + with get_fsdp_state_ctx( + self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg + ): + model_state_dict = self.model.state_dict() + optimizer_state_dict = ( + self.optimizer.state_dict() + if self.optimizer is not None and not model_state_dict_only + else None + ) + lr_scheduler_state_dict = ( + self.lr_scheduler.state_dict() + if self.lr_scheduler is not None and not model_state_dict_only + else None + ) + + extra_state_dict = { + "lr_scheduler": lr_scheduler_state_dict, + "rng": self.get_rng_state(), + } + model_path = os.path.join( + local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt" + ) + optim_path = os.path.join( + local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt" + ) + extra_path = os.path.join( + local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt" + ) + + print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}") + # torch.save(model_state_dict, model_path) + # torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None + # torch.save(extra_state_dict, extra_path) + if self._model_state_dict_thread is not None: + self._model_state_dict_thread.join() + self._model_state_dict_thread = threading.Thread( + target=torch.save, + args=(model_state_dict, model_path), + ) + self._model_state_dict_thread.start() + + print(f"[rank-{self.rank}]: Saving optim to {os.path.abspath(optim_path)}") + if self._optimizer_state_dict_thread is not None: + self._optimizer_state_dict_thread.join() + self._optimizer_state_dict_thread = threading.Thread( + target=torch.save, + args=(optimizer_state_dict, optim_path), + ) + self._optimizer_state_dict_thread.start() + + print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}") + if self._extra_state_dict_thread is not None: + self._extra_state_dict_thread.join() + self._extra_state_dict_thread = threading.Thread( + target=torch.save, + args=(extra_state_dict, extra_path), + ) + self._extra_state_dict_thread.start() + + if self.rank == 0: + if fsdp_version(self.model) == 1: + unwrap_model = self.model._fsdp_wrapped_module + else: + unwrap_model = self.model + + model_config = unwrap_model.config + if ( + unwrap_model.can_generate() + and hasattr(model_config, "name_or_path") + and model_config.name_or_path + ): + # Some model's name_or_path is empty if not initialized from pretrained, + # in this cases, we don't save generation config. 
+ generation_config = GenerationConfig.from_pretrained(model_config.name_or_path) + generation_config.save_pretrained(local_path) + else: + generation_config = None + + model_config.save_pretrained(local_path) + self.processing_class.save_pretrained(local_path) + + # wait for everyone to dump to local + torch.distributed.barrier() + + if "hf_model" in self.checkpoint_contents: + hf_local_path = os.path.join(local_path, "huggingface") + os.makedirs(hf_local_path, exist_ok=True) + + # Only rank 0 will save hf model and, + # offload to cpu to save LLMs which may be too large to fit in one GPU + state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with get_fsdp_state_ctx( + self.model, StateDictType.FULL_STATE_DICT, state_dict_config, None + ): + state_dict = self.model.state_dict() + + if self.rank == 0: + if "ForTokenClassification" in model_config.architectures[0]: + from transformers import AutoModelForTokenClassification + + auto_model_cls = AutoModelForTokenClassification + elif "ForCausalLM" in model_config.architectures[0]: + from transformers import AutoModelForCausalLM + + auto_model_cls = AutoModelForCausalLM + elif "ForConditionalGeneration" in model_config.architectures[0]: + from transformers import AutoModelForVision2Seq + + auto_model_cls = AutoModelForVision2Seq + else: + raise NotImplementedError( + f"Unknown architecture {model_config['architectures']}" + ) + + with init_empty_weights(): + save_model = auto_model_cls.from_config( + model_config, torch_dtype=torch.bfloat16 + ) + save_model.to_empty(device="cpu") + + if save_model.can_generate(): + if generation_config is not None: + save_model.generation_config = generation_config + else: + print( + f"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found in, using a generation config created from the model config when saving hf_model." 
+ ) + + # save_model.save_pretrained(hf_local_path, state_dict=state_dict) + if self._save_model_thread is not None: + self._save_model_thread.join() + self._save_model_thread = threading.Thread( + target=save_model.save_pretrained, + args=(hf_local_path, state_dict), + ) + self._save_model_thread.start() + self.processing_class.save_pretrained(hf_local_path) + del state_dict + del save_model + + # wait for rank0 to dump hf_model to local + torch.distributed.barrier() + + self.previous_saved_paths.append(local_path) diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 8e0bfe314d..4d29ca6763 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -42,7 +42,8 @@ from verl.single_controller.base.decorator import Dispatch, register from verl.utils import hf_processor, hf_tokenizer from verl.utils.activation_offload import enable_activation_offloading -from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager + +# from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.utils.debug import log_gpu_memory_usage from verl.utils.device import get_torch_device, is_cuda_available from verl.utils.flops_counter import FlopsCounter @@ -73,6 +74,7 @@ from trinity.common.config import AlgorithmConfig from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod +from trinity.trainer.verl.fsdp_checkpoint_manager import FSDPCheckpointManager from trinity.utils.distributed import init_process_group logger = logging.getLogger(__file__) @@ -541,14 +543,12 @@ def init_model(self): lr_scheduler=self.actor_lr_scheduler, processing_class=self.processor if self.processor is not None else self.tokenizer, checkpoint_contents=self.config.actor.checkpoint.contents, + config=self.config.synchronizer, ) @register(dispatch_mode=Dispatch.ONE_TO_ALL) def setup_weight_sync_group(self): - if ( - hasattr(self.config, "synchronizer") - and getattr(self.config.synchronizer, "sync_method", None) == SyncMethod.NCCL - ): + if self.config.synchronizer.sync_method == SyncMethod.NCCL: model = self.actor_module_fsdp self.named_modules = [] self.state_dict_meta = [] @@ -609,6 +609,10 @@ def sync_weight(self): torch.distributed.barrier() torch.cuda.empty_cache() + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def upload_state_dict(self, trainer_step: int): + self.checkpoint_manager.upload_state_dict(trainer_step) + @register(dispatch_mode=Dispatch.ONE_TO_ALL) def set_algorithm(self, algo_config: AlgorithmConfig): self.actor.set_algorithm(algo_config) diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 2a8b1d0135..8b88d24288 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -287,6 +287,9 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl # TODO: compute total training steps self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize + def upload_state_dict(self): + self.actor_rollout_wg.upload_state_dict(self.global_steps) + def train_step(self) -> bool: # noqa C901 self.logger.info(f"Training at step {self.global_steps + 1} started.") metrics = {} From bbd8f70c99b29d78462fd8977e90fbf104fb3dcf Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 21 Jul 2025 11:48:56 +0800 Subject: [PATCH 05/16] Add `state_dict` and `checkpoint` update method to `synchronizer`; Add `group` to `config`. 
--- trinity/common/config.py | 11 +- trinity/common/models/vllm_model.py | 8 +- trinity/common/models/vllm_worker.py | 5 - trinity/common/synchronizer.py | 29 ++- trinity/common/verl_config.py | 2 + trinity/explorer/explorer.py | 23 +- trinity/trainer/trainer.py | 12 +- .../trainer/verl/fsdp_checkpoint_manager.py | 227 ++++++++++++------ trinity/trainer/verl/fsdp_workers.py | 13 +- trinity/trainer/verl_trainer.py | 48 ++-- trinity/utils/monitor.py | 12 +- 11 files changed, 248 insertions(+), 142 deletions(-) diff --git a/trinity/common/config.py b/trinity/common/config.py index e1de90ef51..850d959aae 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -398,6 +398,7 @@ class Config: mode: str = "both" # `explore`, `train`, `both` or `bench` project: str = "Trinity-RFT" + group: str = "" name: str = "rft" # the root dir for checkpoints checkpoint_root_dir: str = "" @@ -439,16 +440,6 @@ def _check_interval(self) -> None: f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.explorer.eval_interval}." ) - # check save_interval - if self.synchronizer.sync_method == SyncMethod.CHECKPOINT: - if self.trainer.save_interval != self.synchronizer.sync_interval: - logger.warning( - f"When `algorithm.algorithm_type` != `dpo` and `synchronizer.sync_method` == `checkpoint`, " - f"`trainer.save_interval` will be set to " - f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`." - ) - self.trainer.save_interval = self.synchronizer.sync_interval - def _check_buffer(self) -> None: # noqa: C901 # TODO: split this function into different buffer read/writer # check explorer_input diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 01b8135511..0144901e42 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -3,7 +3,7 @@ import os import re -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import aiohttp import ray @@ -263,12 +263,8 @@ async def _collective_rpc( method, timeout, args, kwargs ) - async def sync_model( - self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None - ) -> bool: + async def sync_model(self, model_version: int) -> bool: """Sync model weights to vLLM.""" - if update_weight_args_list is not None: - await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,)) await self._collective_rpc("update_weight") self.logger.info("Sync model weights to vLLM successfully.") self.model_version = model_version diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 7cfecda76d..0688544e81 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -29,7 +29,6 @@ def init_process_group( """Init torch process group for model weights update""" assert torch.distributed.is_initialized(), "default torch process group must be initialized" assert group_name != "", "group name must not be empty" - # self.set_state_dict_meta(state_dict_meta) self._state_dict_meta = state_dict_meta self._update_with_checkpoint = update_with_checkpoint self._weight_update_rank = torch.distributed.get_rank() + rank_offset @@ -58,10 +57,6 @@ def init_process_group( def update_weight(self): """Broadcast weight to all vllm workers from source rank 0 (actor model)""" if self._state_dict_meta is None: - import time - - time.sleep(20) - print(f"Waiting for state dict meta............") self._state_dict_meta = 
ray.get(self.synchronizer.get_state_dict_meta.remote()) if self._weight_update_rank == 0: state_dict, _ = ray.get(self.synchronizer.get_model_state_dict.remote()) diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py index 7a4e9fbe60..f801981902 100644 --- a/trinity/common/synchronizer.py +++ b/trinity/common/synchronizer.py @@ -2,29 +2,32 @@ import asyncio import os +from collections import defaultdict from typing import List, Optional import ray from trinity.common.config import Config -from trinity.common.constants import RunningStatus, SyncMethod +from trinity.common.constants import RunningStatus from trinity.common.models.utils import ( get_checkpoint_dir_with_step_num, load_state_dict, ) +from trinity.utils.log import get_logger class Synchronizer: def __init__(self, config: Config): + self.logger = get_logger(__name__) self.config = config self.trainer_status = RunningStatus.RUNNING self.last_trainer_sync_step = 0 self.explorer_status = RunningStatus.RUNNING self.last_explorer_sync_step = 0 - self.ready_count = 0 self._ready_condition = asyncio.Condition() self.model_state_dict = None self.state_dict_version = 0 + self.checkpoint_shard_count_dict = defaultdict(lambda: 0) def set_trainer_status(self, status: RunningStatus): self.trainer_status = status @@ -38,13 +41,24 @@ def get_trainer_status(self) -> RunningStatus: def get_explorer_status(self) -> RunningStatus: return self.explorer_status - async def set_model_state_dict_with_step_num(self, step_num: Optional[int] = None) -> int: + async def set_model_state_dict_with_step_num( + self, step_num: Optional[int] = None, world_size: Optional[int] = None + ) -> int: + if world_size is not None: # Used for trainer to update model + assert step_num is not None + self.checkpoint_shard_count_dict[step_num] += 1 + self.logger.info( + f"Synchronizer received checkpoint {self.checkpoint_shard_count_dict[step_num]} of {world_size} shards" + ) + if self.checkpoint_shard_count_dict[step_num] < world_size: + return step_num + checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num( checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type=self.config.trainer.trainer_type, step_num=step_num, ) - model_state_dict = load_state_dict(os.path.join(checkpoint_dir, "actor")) # TODO: async + model_state_dict = load_state_dict(os.path.join(checkpoint_dir, "actor")) # TODO: to thread await self.set_model_state_dict(model_state_dict, checkpoint_step_num) return checkpoint_step_num @@ -52,6 +66,7 @@ async def set_model_state_dict(self, model_state_dict, trainer_step): self.model_state_dict = model_state_dict async with self._ready_condition: self.state_dict_version = trainer_step + self.logger.info(f"Set model state dict version to {trainer_step}.") self._ready_condition.notify_all() def get_model_state_dict(self): @@ -73,11 +88,11 @@ async def setup_weight_sync_group( async def wait_new_model_state_dict(self, current_version: int) -> int: # wait for the new model state dict; return new version - # pass async with self._ready_condition: if self.state_dict_version <= current_version: - await self._ready_condition.wait_for( - lambda: self.state_dict_version > current_version + await asyncio.wait_for( + self._ready_condition.wait(), + timeout=self.config.synchronizer.sync_timeout, ) return self.state_dict_version diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 1bffa70635..35f46d3f31 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -224,6 +224,7 @@ 
class Trainer: total_epochs: int = 30 total_training_steps: Optional[int] = None project_name: str = "" + group_name: str = "" experiment_name: str = "" logger: List[str] = field(default_factory=list) val_generations_to_log_to_wandb: int = 0 @@ -300,6 +301,7 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.trainer.sync_freq = config.synchronizer.sync_interval self.trainer.save_freq = config.trainer.save_interval self.trainer.project_name = config.project + self.trainer.group_name = config.group self.trainer.experiment_name = config.name self.trainer.default_local_dir = config.checkpoint_job_dir self.trainer.sft_warmup_steps = config.buffer.trainer_input.sft_warmup_steps diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 82cb0a226c..40cf4a8738 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -23,10 +23,6 @@ SyncStyle, ) from trinity.common.models import create_inference_models -from trinity.common.models.utils import ( - get_checkpoint_dir_with_step_num, - load_state_dict, -) from trinity.common.synchronizer import Synchronizer from trinity.explorer.scheduler import Scheduler from trinity.manager.manager import CacheManager @@ -60,6 +56,7 @@ def __init__(self, config: Config): self.scheduler = self._init_scheduler() self.monitor = MONITOR.get(self.config.monitor.monitor_type)( project=self.config.project, + group=self.config.group, name=self.config.name, role=self.config.explorer.name, config=config, @@ -142,15 +139,23 @@ async def _update_model_weight(self, step_num: int, state_dict: dict) -> None: async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int: step_num = ray.get(self.synchronizer.set_model_state_dict_with_step_num.remote(step_num)) await asyncio.gather(*[model.sync_model.remote(step_num) for model in self.models]) - return step_num + return step_num # type: ignore async def _state_dict_update(self): - self.state_dict_version = ray.get( + self.logger.info("Start to update state dict.") + new_version = ray.get( self.synchronizer.wait_new_model_state_dict.remote(self.state_dict_version) ) - await asyncio.gather( - *[model.sync_model.remote(self.state_dict_version) for model in self.models] - ) + if new_version > self.state_dict_version: + self.logger.info(f"New model state dict version: {new_version}") + await asyncio.gather( + *[model.sync_model.remote(self.state_dict_version) for model in self.models] + ) + self.state_dict_version = new_version + else: + self.logger.warning( + f"No new model state dict found, current version: {self.state_dict_version}" + ) async def _nccl_weights_update(self): assert self.state_dict_meta is not None diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index bc7529fc93..db3cb6cd50 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -56,11 +56,6 @@ def train_step(self) -> bool: def need_sync(self) -> bool: """Whether to sync the model weight.""" - if self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT: - ray.get() - return False - - # SyncMethod.NCCL if self.config.synchronizer.sync_style == SyncStyle.FIXED: return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0 else: @@ -86,7 +81,8 @@ def sync_weight(self) -> None: ) self.last_trainer_sync_step = self.engine.train_step_num elif self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT: - pass + self.engine.save_state_dict() + # 
ray.get(self.synchronizer.set_model_state_dict_with_step_num.remote(self.engine.train_step_num)) elif self.config.synchronizer.sync_method == SyncMethod.STATE_DICT: self.engine.upload_state_dict() ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) @@ -128,6 +124,10 @@ def sync_weight(self) -> None: def upload_state_dict(self) -> None: """Upload the state dict to Synchronizer.""" + @abstractmethod + def save_state_dict(self) -> None: + """Only save the model state dict for Synchronizer.""" + @abstractmethod def shutdown(self) -> None: """Shutdown the engine.""" diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index 05c6aff950..e1d9ac5bea 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -1,24 +1,32 @@ +import json import os import threading import warnings +from dataclasses import asdict from typing import Optional import ray import torch from accelerate import init_empty_weights -from torch.distributed.fsdp import FullStateDictConfig -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import ( + FullStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig, StateDictType, ) -from transformers import GenerationConfig, PreTrainedTokenizer, ProcessorMixin +from transformers import GenerationConfig from verl.utils.checkpoint.fsdp_checkpoint_manager import ( FSDPCheckpointManager as OldFSDPCheckpointManager, ) +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPConfig, logger from verl.utils.device import is_cuda_available -from verl.utils.fsdp_utils import fsdp_version, get_fsdp_state_ctx +from verl.utils.fs import local_mkdir_safe +from verl.utils.fsdp_utils import ( + fsdp_version, + get_fsdp_full_state_dict, + get_fsdp_state_ctx, +) +from verl.utils.logger import log_with_rank from trinity.common.synchronizer import Synchronizer @@ -44,7 +52,7 @@ def upload_state_dict(self, trainer_step: int): if self.rank == 0: ray.get(self.synchronizer.set_model_state_dict.remote(state_dict, trainer_step)) - def save_checkpoint( + def save_checkpoint( # noqa: C901 self, local_path: str, hdfs_path: str = None, @@ -61,20 +69,31 @@ def save_checkpoint( # record the previous global step self.previous_global_step = global_step - # remove previous local_path + # remove previous local_path, only rank 0 should do this if ( - max_ckpt_to_keep + self.rank == 0 + and max_ckpt_to_keep and isinstance(max_ckpt_to_keep, int) and max_ckpt_to_keep > 0 - and len(self.previous_saved_paths) >= max_ckpt_to_keep + and len(self.previous_saved_paths) >= max_ckpt_to_keep # type: ignore ): - keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 - self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) - self.previous_saved_paths = self.previous_saved_paths[keep_start:] + keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 # type: ignore + self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) # type: ignore + self.previous_saved_paths = self.previous_saved_paths[keep_start:] # type: ignore - local_path = self.local_mkdir(local_path) + local_path = local_mkdir_safe(local_path) torch.distributed.barrier() + # check if the checkpoint_save_contents is valid + if self.should_save_model: + assert ( + self.model is not None + ), "model must be provided when checkpoint_contents.save includes ['model']" + if self.should_save_optimizer: + assert ( + 
self.optimizer is not None + ), "optimizer must be provided when checkpoint_contents.save includes ['optimizer']" + # every rank will save its own model and optim shard state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) @@ -83,22 +102,6 @@ def save_checkpoint( with get_fsdp_state_ctx( self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg ): - model_state_dict = self.model.state_dict() - optimizer_state_dict = ( - self.optimizer.state_dict() - if self.optimizer is not None and not model_state_dict_only - else None - ) - lr_scheduler_state_dict = ( - self.lr_scheduler.state_dict() - if self.lr_scheduler is not None and not model_state_dict_only - else None - ) - - extra_state_dict = { - "lr_scheduler": lr_scheduler_state_dict, - "rng": self.get_rng_state(), - } model_path = os.path.join( local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt" ) @@ -109,42 +112,94 @@ def save_checkpoint( local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt" ) - print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}") - # torch.save(model_state_dict, model_path) - # torch.save(optimizer_state_dict, optim_path) # TODO: address optimizer is None - # torch.save(extra_state_dict, extra_path) - if self._model_state_dict_thread is not None: - self._model_state_dict_thread.join() - self._model_state_dict_thread = threading.Thread( - target=torch.save, - args=(model_state_dict, model_path), - ) - self._model_state_dict_thread.start() - - print(f"[rank-{self.rank}]: Saving optim to {os.path.abspath(optim_path)}") - if self._optimizer_state_dict_thread is not None: - self._optimizer_state_dict_thread.join() - self._optimizer_state_dict_thread = threading.Thread( - target=torch.save, - args=(optimizer_state_dict, optim_path), - ) - self._optimizer_state_dict_thread.start() - - print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}") - if self._extra_state_dict_thread is not None: - self._extra_state_dict_thread.join() - self._extra_state_dict_thread = threading.Thread( - target=torch.save, - args=(extra_state_dict, extra_path), - ) - self._extra_state_dict_thread.start() + if self.should_save_model: + model_state_dict = self.model.state_dict() + # torch.save(model_state_dict, model_path) + if self._model_state_dict_thread is not None: + self._model_state_dict_thread.join() + + def _save_model_state_dict(): + torch.save(model_state_dict, model_path) + log_with_rank( + f"Saved model to {os.path.abspath(model_path)}", + rank=self.rank, + logger=logger, + ) + ray.get( + self.synchronizer.set_model_state_dict_with_step_num.remote( + global_step, self.world_size + ) + ) + + self._model_state_dict_thread = threading.Thread( + target=_save_model_state_dict, + # target=torch.save, + # args=(model_state_dict, model_path), + ) + self._model_state_dict_thread.start() + # log_with_rank(f"Saved model to {os.path.abspath(model_path)}", rank=self.rank, logger=logger) + + if self.should_save_optimizer and not model_state_dict_only: + optimizer_state_dict = self.optimizer.state_dict() + # torch.save(optimizer_state_dict, optim_path) + if self._optimizer_state_dict_thread is not None: + self._optimizer_state_dict_thread.join() + + def _save_optimizer_state_dict(): + torch.save(optimizer_state_dict, optim_path) + log_with_rank( + f"Saved optim to {os.path.abspath(optim_path)}", + rank=self.rank, + logger=logger, 
+ ) + + self._optimizer_state_dict_thread = threading.Thread( + target=_save_optimizer_state_dict, + # target=torch.save, + # args=(optimizer_state_dict, optim_path), + ) + self._optimizer_state_dict_thread.start() + # log_with_rank(f"Saved optim to {os.path.abspath(optim_path)}", rank=self.rank, logger=logger) + + if self.should_save_extra and not model_state_dict_only: + lr_scheduler_state_dict = ( + self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None + ) + extra_state_dict = { + "lr_scheduler": lr_scheduler_state_dict, + "rng": self.get_rng_state(), + } + # torch.save(extra_state_dict, extra_path) + if self._extra_state_dict_thread is not None: + self._extra_state_dict_thread.join() + + def _save_extra_state_dict(): + torch.save(extra_state_dict, extra_path) + log_with_rank( + f"Saved extra_state to {os.path.abspath(extra_path)}", + rank=self.rank, + logger=logger, + ) + + self._extra_state_dict_thread = threading.Thread( + target=_save_extra_state_dict, + # target=torch.save, + # args=(extra_state_dict, extra_path), + ) + self._extra_state_dict_thread.start() + # log_with_rank(f"Saved extra_state to {os.path.abspath(extra_path)}", rank=self.rank, logger=logger) if self.rank == 0: + # Save HF tokenizer/processor and model config on rank 0 to huggingface/ directory, no matter whether + # huggingface model is requested to be saved or not. + if fsdp_version(self.model) == 1: unwrap_model = self.model._fsdp_wrapped_module else: unwrap_model = self.model + hf_config_tokenizer_path = os.path.join(local_path, "huggingface") + local_mkdir_safe(hf_config_tokenizer_path) model_config = unwrap_model.config if ( unwrap_model.can_generate() @@ -154,29 +209,40 @@ def save_checkpoint( # Some model's name_or_path is empty if not initialized from pretrained, # in this cases, we don't save generation config. 
generation_config = GenerationConfig.from_pretrained(model_config.name_or_path) - generation_config.save_pretrained(local_path) + generation_config.save_pretrained(hf_config_tokenizer_path) else: generation_config = None - model_config.save_pretrained(local_path) - self.processing_class.save_pretrained(local_path) + model_config.save_pretrained(hf_config_tokenizer_path) + self.processing_class.save_pretrained(hf_config_tokenizer_path) + log_with_rank( + f"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + + # Also save runtime FSDP config + fsdp_config_path = os.path.join(local_path, "fsdp_config.json") + fsdp_config = FSDPConfig( + FSDP_version=fsdp_version(self.model), + world_size=self.world_size, + ) + with open(fsdp_config_path, "w") as f: + json.dump(asdict(fsdp_config), f, indent=4) # wait for everyone to dump to local torch.distributed.barrier() - if "hf_model" in self.checkpoint_contents: - hf_local_path = os.path.join(local_path, "huggingface") - os.makedirs(hf_local_path, exist_ok=True) - + if self.should_save_hf_model and not model_state_dict_only: # Only rank 0 will save hf model and, # offload to cpu to save LLMs which may be too large to fit in one GPU - state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - with get_fsdp_state_ctx( - self.model, StateDictType.FULL_STATE_DICT, state_dict_config, None - ): - state_dict = self.model.state_dict() + state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True) if self.rank == 0: + hf_local_path = os.path.join(local_path, "huggingface") + os.makedirs(hf_local_path, exist_ok=True) + if "ForTokenClassification" in model_config.architectures[0]: from transformers import AutoModelForTokenClassification @@ -211,16 +277,29 @@ def save_checkpoint( # save_model.save_pretrained(hf_local_path, state_dict=state_dict) if self._save_model_thread is not None: self._save_model_thread.join() + + def _save_model(): + save_model.save_pretrained(hf_local_path, state_dict=state_dict) + log_with_rank( + f"Saved hf_model to {os.path.abspath(hf_local_path)}", + rank=self.rank, + logger=logger, + log_only_rank_0=True, + ) + self._save_model_thread = threading.Thread( - target=save_model.save_pretrained, - args=(hf_local_path, state_dict), + target=_save_model, + # target=save_model.save_pretrained, + # args=(hf_local_path, state_dict), ) self._save_model_thread.start() self.processing_class.save_pretrained(hf_local_path) - del state_dict - del save_model + # log_with_rank(f"Saved hf_model to {os.path.abspath(hf_local_path)}", rank=self.rank, logger=logger, log_only_rank_0=True) + # del state_dict + # del save_model # wait for rank0 to dump hf_model to local torch.distributed.barrier() - self.previous_saved_paths.append(local_path) + if not model_state_dict_only: + self.previous_saved_paths.append(local_path) diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index ad715e5d2e..d01866374e 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -27,8 +27,7 @@ import torch import torch.distributed import torch.distributed as dist - -# import vllm # noqa: F401 ; import vllm to set NCCL_CUMEM_ENABLE automatically. +import vllm # noqa: F401 ; import vllm to set NCCL_CUMEM_ENABLE automatically. 
from codetiming import Timer from omegaconf import DictConfig, OmegaConf, open_dict from peft import LoraConfig, TaskType, get_peft_model @@ -767,7 +766,14 @@ def compute_ref_log_prob(self, data: DataProto): return output @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None): + def save_checkpoint( + self, + local_path, + hdfs_path=None, + global_step=0, + max_ckpt_to_keep=None, + model_state_dict_only=False, + ): from verl.utils.logger import log_with_rank # only support save and load ckpt for actor @@ -781,6 +787,7 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to hdfs_path=hdfs_path, global_step=global_step, max_ckpt_to_keep=max_ckpt_to_keep, + model_state_dict_only=model_state_dict_only, ) dist.barrier() diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 38eef75fe3..b9ff7df2d2 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -149,6 +149,7 @@ def __init__( self.init_workers() self.monitor = MONITOR.get(global_config.monitor.monitor_type)( project=config.trainer.project_name, + group=config.trainer.group_name, name=config.trainer.experiment_name, role=global_config.trainer.name, config=global_config, @@ -287,7 +288,22 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl # TODO: compute total training steps self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize - def upload_state_dict(self): + def save_state_dict(self): # checkpoint sync + local_global_step_folder = os.path.join( + self.config.trainer.default_local_dir, f"global_step_{self.global_steps}" + ) + + actor_local_path = os.path.join(local_global_step_folder, "actor") + if not os.path.exists(actor_local_path): + self.actor_rollout_wg.save_checkpoint( + actor_local_path, + None, + self.global_steps, + max_ckpt_to_keep=None, + model_state_dict_only=True, + ) + + def upload_state_dict(self): # state dict sync self.actor_rollout_wg.upload_state_dict(self.global_steps) def train_step(self) -> bool: # noqa C901 @@ -366,15 +382,6 @@ def train_step(self) -> bool: # noqa C901 actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) - if ( - self.config.trainer.save_freq > 0 - and self.global_steps % self.config.trainer.save_freq == 0 - ): - self.logger.info(f"Saving at step {self.global_steps}.") - with marked_timer("save_checkpoint", timing_raw): - self._save_checkpoint() - self.logger.info(f"Saved at step {self.global_steps}.") - # collect metrics if self.algorithm.use_advantage: # TODO metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) @@ -391,15 +398,18 @@ def train_step(self) -> bool: # noqa C901 self.monitor.log(data=metrics, step=self.global_steps) train_status = self.global_steps < self.total_training_steps - if not train_status or self.algorithm_manager.need_save(self.global_steps): - if ( - self.config.trainer.save_freq == 0 - or self.global_steps % self.config.trainer.save_freq != 0 - ): - self.logger.info(f"Saving at step {self.global_steps}.") - with marked_timer("save_checkpoint", timing_raw): - self._save_checkpoint() - self.logger.info(f"Saved at step {self.global_steps}.") + if ( + not train_status + or self.algorithm_manager.need_save(self.global_steps) + or ( + self.config.trainer.save_freq > 0 + and self.global_steps % self.config.trainer.save_freq == 0 + ) + ): + self.logger.info(f"Saving at 
step {self.global_steps}.") + with marked_timer("save_checkpoint", timing_raw): + self._save_checkpoint() + self.logger.info(f"Saved at step {self.global_steps}.") self.logger.info(f"Training at step {self.global_steps} finished.") return train_status diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 5896fc110d..6ba4d7482d 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -80,7 +80,9 @@ def calculate_metrics( @MONITOR.register_module("tensorboard") class TensorboardMonitor(Monitor): - def __init__(self, project: str, name: str, role: str, config: Config = None) -> None: + def __init__( + self, project: str, group: str, name: str, role: str, config: Config = None + ) -> None: self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard", role) os.makedirs(self.tensorboard_dir, exist_ok=True) self.logger = SummaryWriter(self.tensorboard_dir) @@ -101,10 +103,14 @@ def close(self) -> None: @MONITOR.register_module("wandb") class WandbMonitor(Monitor): - def __init__(self, project: str, name: str, role: str, config: Config = None) -> None: + def __init__( + self, project: str, group: str, name: str, role: str, config: Config = None + ) -> None: + if not group: + group = name self.logger = wandb.init( project=project, - group=name, + group=group, name=f"{name}_{role}", tags=[role], config=config, From 1623c97f6099097c1a719d7ae7a855744336677d Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 22 Jul 2025 12:21:53 +0800 Subject: [PATCH 06/16] 1. Add `group` into `checkpoint_job_dir` and change `group` in `wandb.init` to `f"{group}_{role}"`. 2. Rename `SyncMethod.STATE_DICT` to `SyncMethod.MEMORY`. 3. Add `wait_for_saving` when trainer shutdown. 4. Refactor `explorer_status` to `explorer_status_counter` in `Synchronizer` for multi explorer. 5. add "rollout/model_version" to monitor. 6. apply some suggestions made by gemini. 
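
For item 4, a minimal standalone sketch of the per-status counting that replaces the single explorer status field (illustrative only: the real logic lives in the `Synchronizer` Ray actor shown in the diff below, and the enum members/values here are trimmed and assumed for the example):

    from collections import defaultdict
    from enum import Enum

    class RunningStatus(Enum):  # trimmed, values assumed for illustration
        RUNNING = "running"
        REQUIRE_SYNC = "require_sync"
        WAITING_SYNC = "waiting_sync"
        STOPPED = "stopped"

    # One Synchronizer serves several explorers; it only needs to know how many
    # explorers are currently in each status, not which explorer is which.
    explorer_status_counter = defaultdict(int)

    def set_explorer_status(status, old_status=None):
        if old_status is not None:
            explorer_status_counter[old_status] -= 1
        explorer_status_counter[status] += 1

    set_explorer_status(RunningStatus.RUNNING)   # explorer A starts
    set_explorer_status(RunningStatus.RUNNING)   # explorer B starts
    set_explorer_status(RunningStatus.WAITING_SYNC, old_status=RunningStatus.RUNNING)
    # The trainer can now ask "is any explorer waiting for new weights?"
    assert explorer_status_counter[RunningStatus.WAITING_SYNC] == 1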
--- trinity/common/config.py | 4 +- trinity/common/constants.py | 2 +- trinity/common/synchronizer.py | 71 ++++++++++++------- trinity/explorer/explorer.py | 52 +++++++++----- trinity/trainer/trainer.py | 15 ++-- .../trainer/verl/fsdp_checkpoint_manager.py | 45 ++++++------ trinity/trainer/verl/fsdp_workers.py | 10 ++- trinity/trainer/verl_trainer.py | 4 +- trinity/utils/monitor.py | 4 +- 9 files changed, 131 insertions(+), 76 deletions(-) diff --git a/trinity/common/config.py b/trinity/common/config.py index 850d959aae..499a0c034c 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -711,7 +711,9 @@ def check_and_update(self) -> None: # noqa: C901 if not os.path.isabs(self.checkpoint_root_dir): self.checkpoint_root_dir = os.path.join(os.getcwd(), self.checkpoint_root_dir) # create a job dir at checkpoint_root_dir/project/name - self.checkpoint_job_dir = os.path.join(self.checkpoint_root_dir, self.project, self.name) + self.checkpoint_job_dir = os.path.join( + self.checkpoint_root_dir, self.project, self.group, self.name + ) # rename the experiment when necessary if not self.continue_from_checkpoint and ( os.path.exists(self.checkpoint_job_dir) and os.listdir(self.checkpoint_job_dir) diff --git a/trinity/common/constants.py b/trinity/common/constants.py index d4e0f549ee..392f2dc553 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -95,7 +95,7 @@ class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta): NCCL = "nccl" CHECKPOINT = "checkpoint" - STATE_DICT = "state_dict" + MEMORY = "memory" class RunningStatus(Enum): diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py index f801981902..261b02c04a 100644 --- a/trinity/common/synchronizer.py +++ b/trinity/common/synchronizer.py @@ -1,9 +1,9 @@ -""" """ +"""A centralized synchronizer for coordinating explorer and trainer.""" import asyncio import os from collections import defaultdict -from typing import List, Optional +from typing import Dict, List, Optional, Union import ray @@ -22,35 +22,46 @@ def __init__(self, config: Config): self.config = config self.trainer_status = RunningStatus.RUNNING self.last_trainer_sync_step = 0 - self.explorer_status = RunningStatus.RUNNING + self.explorer_status_counter: Dict[RunningStatus, int] = {} self.last_explorer_sync_step = 0 self._ready_condition = asyncio.Condition() self.model_state_dict = None - self.state_dict_version = 0 - self.checkpoint_shard_count_dict = defaultdict(lambda: 0) + self.model_version = 0 + self.checkpoint_shard_counter = defaultdict(lambda: 0) def set_trainer_status(self, status: RunningStatus): self.trainer_status = status - def set_explorer_status(self, status: RunningStatus): - self.explorer_status = status - def get_trainer_status(self) -> RunningStatus: return self.trainer_status - def get_explorer_status(self) -> RunningStatus: - return self.explorer_status + def set_explorer_status( + self, status: RunningStatus, old_status: Optional[RunningStatus] = None + ): + if old_status is not None: + assert ( + old_status in self.explorer_status_counter + ), f"Invalid explorer status {old_status}" + assert old_status != status + self.explorer_status_counter[old_status] -= 1 + assert self.explorer_status_counter[old_status] >= 0 + if status not in self.explorer_status_counter: + self.explorer_status_counter[status] = 0 + self.explorer_status_counter[status] += 1 + + def get_explorer_status_counter(self) -> Dict[RunningStatus, int]: + return self.explorer_status_counter async def 
set_model_state_dict_with_step_num( self, step_num: Optional[int] = None, world_size: Optional[int] = None ) -> int: if world_size is not None: # Used for trainer to update model assert step_num is not None - self.checkpoint_shard_count_dict[step_num] += 1 + self.checkpoint_shard_counter[step_num] += 1 self.logger.info( - f"Synchronizer received checkpoint {self.checkpoint_shard_count_dict[step_num]} of {world_size} shards" + f"Synchronizer received checkpoint {self.checkpoint_shard_counter[step_num]} of {world_size} shards" ) - if self.checkpoint_shard_count_dict[step_num] < world_size: + if self.checkpoint_shard_counter[step_num] < world_size: return step_num checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num( @@ -65,12 +76,12 @@ async def set_model_state_dict_with_step_num( async def set_model_state_dict(self, model_state_dict, trainer_step): self.model_state_dict = model_state_dict async with self._ready_condition: - self.state_dict_version = trainer_step + self.model_version = trainer_step self.logger.info(f"Set model state dict version to {trainer_step}.") self._ready_condition.notify_all() def get_model_state_dict(self): - return self.model_state_dict, self.state_dict_version + return self.model_state_dict, self.model_version def get_state_dict_meta(self): if self.model_state_dict is None: @@ -89,28 +100,40 @@ async def setup_weight_sync_group( async def wait_new_model_state_dict(self, current_version: int) -> int: # wait for the new model state dict; return new version async with self._ready_condition: - if self.state_dict_version <= current_version: + if self.model_version <= current_version: + self.set_explorer_status( + RunningStatus.WAITING_SYNC, old_status=RunningStatus.REQUIRE_SYNC + ) await asyncio.wait_for( self._ready_condition.wait(), timeout=self.config.synchronizer.sync_timeout, ) - return self.state_dict_version - - async def ready_to_sync(self, module: str): + return self.model_version + + async def ready_to_nccl_sync( + self, module: str, trainer_step: Optional[int] = None + ) -> Union[int, None]: + assert ( + sum(self.explorer_status_counter.values()) == 1 + ), "NCCL sync is only supported for one explorer." async with self._ready_condition: try: if module == "trainer": + self.model_version = trainer_step self.trainer_status = RunningStatus.WAITING_SYNC self._ready_condition.notify_all() - if self.explorer_status != RunningStatus.WAITING_SYNC: + if self.explorer_status_counter[RunningStatus.WAITING_SYNC] != 1: await asyncio.wait_for( self._ready_condition.wait_for( - lambda: self.explorer_status == RunningStatus.WAITING_SYNC, + lambda: self.explorer_status_counter[RunningStatus.WAITING_SYNC] + == 1, ), timeout=self.config.synchronizer.sync_timeout, ) elif module == "explorer": - self.explorer_status = RunningStatus.WAITING_SYNC + self.set_explorer_status( + RunningStatus.WAITING_SYNC, old_status=RunningStatus.REQUIRE_SYNC + ) self._ready_condition.notify_all() if self.trainer_status != RunningStatus.WAITING_SYNC: await asyncio.wait_for( @@ -119,13 +142,13 @@ async def ready_to_sync(self, module: str): ), timeout=self.config.synchronizer.sync_timeout, ) - return True + return self.model_version except asyncio.TimeoutError: another_module = "Trainer" if module == "explorer" else "Explorer" self.logger.error( f"{another_module} is not ready for model weight sync in {self.config.synchronizer.sync_timeout} seconds." 
) - return False + return None @classmethod def get_actor(cls, config: Optional[Config] = None, namespace: Optional[str] = None): diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 40cf4a8738..eec4f7d812 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -71,14 +71,13 @@ def __init__(self, config: Config): # For checkpoint weights update # Use explorer to periodically load the latest model weights and # boradcast to all rollout models - self.state_dict_version = 0 + self.model_version = 0 if self.use_state_dict_weights_update: self.old_checkpoint = None self.state_dict = {} else: # nccl mode self.state_dict_meta = [] self.logger.info("Finished initializing Explorer.") - self._ready_to_sync_condition = asyncio.Condition() async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None @@ -144,25 +143,26 @@ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> in async def _state_dict_update(self): self.logger.info("Start to update state dict.") new_version = ray.get( - self.synchronizer.wait_new_model_state_dict.remote(self.state_dict_version) + self.synchronizer.wait_new_model_state_dict.remote(self.model_version) ) - if new_version > self.state_dict_version: + if new_version > self.model_version: self.logger.info(f"New model state dict version: {new_version}") - await asyncio.gather( - *[model.sync_model.remote(self.state_dict_version) for model in self.models] - ) - self.state_dict_version = new_version + await asyncio.gather(*[model.sync_model.remote(new_version) for model in self.models]) + self.model_version = new_version else: self.logger.warning( - f"No new model state dict found, current version: {self.state_dict_version}" + f"No new model state dict found, current version: {self.model_version}" ) async def _nccl_weights_update(self): assert self.state_dict_meta is not None - status = ray.get(self.synchronizer.ready_to_sync.remote("explorer")) - if not status: + new_version = ray.get( + self.synchronizer.ready_to_nccl_sync.remote("explorer", self.model_version) + ) + if new_version is None: self.logger.info("Trainer is not ready to sync weight. Skipping sync weight.") return + self.model_version = new_version await asyncio.gather( *[model.sync_model.remote(self.explore_step_num) for model in self.models] ) @@ -176,6 +176,7 @@ async def prepare(self) -> None: asyncio.create_task(self.setup_weight_sync_group(master_address, master_port)) ) asyncio.gather(*futures, return_exceptions=True) + await self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC) if self.experience_buffer: await self.experience_buffer.acquire() if self.config.explorer.eval_on_startup and self.explore_step_num == 0: @@ -225,6 +226,9 @@ async def explore_step(self) -> bool: except StopIteration: self.logger.warning("No more tasks to explore. 
Stop exploring.") await self.save_checkpoint(sync_weight=False) + await self.synchronizer.set_explorer_status.remote( + RunningStatus.STOPPED, old_status=RunningStatus.RUNNING + ) await self.experience_buffer.release() return False self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1) @@ -245,11 +249,16 @@ def need_sync(self) -> bool: if delta >= self.config.synchronizer.sync_interval: require_sync = True else: - require_sync = ray.get( - self.synchronizer.get_trainer_status == RunningStatus.REQUIRE_SYNC + require_sync = ( + ray.get(self.synchronizer.get_trainer_status.remote()) + == RunningStatus.REQUIRE_SYNC ) if require_sync: - ray.get(self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC)) + ray.get( + self.synchronizer.set_explorer_status.remote( + RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING + ) + ) return require_sync def need_eval(self) -> bool: @@ -308,7 +317,7 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None: await self.scheduler.wait_all() self.logger.info(f"All tasks before step {self.explore_step_num} have completed.") log_task = asyncio.create_task( - self._log_metrics(self.last_sync_step + 1, self.explore_step_num) + self._log_metrics(self.last_sync_step + 1, self.explore_step_num, self.model_version) ) if sync_weight: @@ -334,18 +343,23 @@ async def sync_weight(self) -> None: """Synchronize model weights.""" # call this method before training start to load the latest model weights await self.save_checkpoint(sync_weight=True) - ray.get(self.synchronizer.set_explorer_status.remote(RunningStatus.RUNNING)) + ray.get( + self.synchronizer.set_explorer_status.remote( + RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC + ) + ) - async def _log_metrics(self, start_step: int, end_step: int) -> None: + async def _log_metrics(self, start_step: int, end_step: int, model_version: int) -> None: for step in range(start_step, end_step + 1): self.logger.info(f"Log metrics of step {step}") - await self._log_explore_metrics(step=step) + await self._log_explore_metrics(step=step, model_version=model_version) await self._log_eval_metrics(step=step) - async def _log_explore_metrics(self, step: int) -> None: + async def _log_explore_metrics(self, step: int, model_version: int) -> None: results = await self.scheduler.get_results(batch_id=step) if results: metric = gather_metrics([status.metric for status in results], "rollout") + metric["rollout/model_version"] = model_version self.monitor.log(metric, step=step) async def _log_eval_metrics(self, step: Optional[int] = None, prefix: str = "eval") -> None: diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index db3cb6cd50..3195532efb 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -64,8 +64,10 @@ def need_sync(self) -> bool: if delta >= self.config.synchronizer.sync_interval: ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC)) return ( - ray.get(self.synchronizer.get_explorer_status.remote()) - == RunningStatus.WAITING_SYNC + ray.get(self.synchronizer.get_explorer_status_counter.remote())[ + RunningStatus.WAITING_SYNC + ] + > 0 ) def sync_weight(self) -> None: @@ -74,7 +76,12 @@ def sync_weight(self) -> None: self.logger.info( f"Trainer synchronizing weights at step {self.engine.train_step_num} starting.." 
) - assert ray.get(self.synchronizer.ready_to_sync.remote("trainer")) + result = ray.get( + self.synchronizer.ready_to_nccl_sync.remote("trainer", self.engine.train_step_num) + ) + if result is None: + self.logger.error("Trainer synchronizing weights failed.") + raise Exception self.engine.sync_weight() self.logger.info( f"Trainer synchronizing weights at step {self.engine.train_step_num} end." @@ -83,7 +90,7 @@ def sync_weight(self) -> None: elif self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT: self.engine.save_state_dict() # ray.get(self.synchronizer.set_model_state_dict_with_step_num.remote(self.engine.train_step_num)) - elif self.config.synchronizer.sync_method == SyncMethod.STATE_DICT: + elif self.config.synchronizer.sync_method == SyncMethod.MEMORY: self.engine.upload_state_dict() ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index e1d9ac5bea..1ed644cc93 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -28,6 +28,7 @@ ) from verl.utils.logger import log_with_rank +from trinity.common.constants import SyncMethod from trinity.common.synchronizer import Synchronizer @@ -35,6 +36,7 @@ class FSDPCheckpointManager(OldFSDPCheckpointManager): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) config = kwargs.pop("config", None) + self.synchronizer_config = config if config is not None: self.synchronizer = Synchronizer.get_actor(namespace=config.ray_namespace) else: @@ -63,6 +65,9 @@ def save_checkpoint( # noqa: C901 """ modified from verl.utils.checkpoint.fsdp_checkpoint_manager.py:save_checkpoint """ + if global_step == 0 and model_state_dict_only: + ray.get(self.synchronizer.set_model_state_dict.remote(None, global_step)) + return if local_path is None: return @@ -114,7 +119,6 @@ def save_checkpoint( # noqa: C901 if self.should_save_model: model_state_dict = self.model.state_dict() - # torch.save(model_state_dict, model_path) if self._model_state_dict_thread is not None: self._model_state_dict_thread.join() @@ -125,23 +129,23 @@ def _save_model_state_dict(): rank=self.rank, logger=logger, ) - ray.get( - self.synchronizer.set_model_state_dict_with_step_num.remote( - global_step, self.world_size + if ( + self.synchronizer_config is not None + and self.synchronizer_config.sync_method == SyncMethod.CHECKPOINT + ): + ray.get( + self.synchronizer.set_model_state_dict_with_step_num.remote( + global_step, self.world_size + ) ) - ) self._model_state_dict_thread = threading.Thread( target=_save_model_state_dict, - # target=torch.save, - # args=(model_state_dict, model_path), ) self._model_state_dict_thread.start() - # log_with_rank(f"Saved model to {os.path.abspath(model_path)}", rank=self.rank, logger=logger) if self.should_save_optimizer and not model_state_dict_only: optimizer_state_dict = self.optimizer.state_dict() - # torch.save(optimizer_state_dict, optim_path) if self._optimizer_state_dict_thread is not None: self._optimizer_state_dict_thread.join() @@ -155,11 +159,8 @@ def _save_optimizer_state_dict(): self._optimizer_state_dict_thread = threading.Thread( target=_save_optimizer_state_dict, - # target=torch.save, - # args=(optimizer_state_dict, optim_path), ) self._optimizer_state_dict_thread.start() - # log_with_rank(f"Saved optim to {os.path.abspath(optim_path)}", rank=self.rank, logger=logger) if self.should_save_extra and not model_state_dict_only: 
lr_scheduler_state_dict = ( @@ -169,7 +170,6 @@ def _save_optimizer_state_dict(): "lr_scheduler": lr_scheduler_state_dict, "rng": self.get_rng_state(), } - # torch.save(extra_state_dict, extra_path) if self._extra_state_dict_thread is not None: self._extra_state_dict_thread.join() @@ -183,11 +183,8 @@ def _save_extra_state_dict(): self._extra_state_dict_thread = threading.Thread( target=_save_extra_state_dict, - # target=torch.save, - # args=(extra_state_dict, extra_path), ) self._extra_state_dict_thread.start() - # log_with_rank(f"Saved extra_state to {os.path.abspath(extra_path)}", rank=self.rank, logger=logger) if self.rank == 0: # Save HF tokenizer/processor and model config on rank 0 to huggingface/ directory, no matter whether @@ -274,7 +271,6 @@ def _save_extra_state_dict(): f"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found in, using a generation config created from the model config when saving hf_model." ) - # save_model.save_pretrained(hf_local_path, state_dict=state_dict) if self._save_model_thread is not None: self._save_model_thread.join() @@ -289,17 +285,22 @@ def _save_model(): self._save_model_thread = threading.Thread( target=_save_model, - # target=save_model.save_pretrained, - # args=(hf_local_path, state_dict), ) self._save_model_thread.start() self.processing_class.save_pretrained(hf_local_path) - # log_with_rank(f"Saved hf_model to {os.path.abspath(hf_local_path)}", rank=self.rank, logger=logger, log_only_rank_0=True) - # del state_dict - # del save_model # wait for rank0 to dump hf_model to local torch.distributed.barrier() if not model_state_dict_only: self.previous_saved_paths.append(local_path) + + def wait_for_saving(self) -> None: + if self._model_state_dict_thread is not None: + self._model_state_dict_thread.join() + if self._optimizer_state_dict_thread is not None: + self._optimizer_state_dict_thread.join() + if self._extra_state_dict_thread is not None: + self._extra_state_dict_thread.join() + if self._save_model_thread is not None: + self._save_model_thread.join() diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index d01866374e..5e858aacd7 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -42,8 +42,6 @@ from verl.single_controller.base.decorator import Dispatch, register from verl.utils import hf_processor, hf_tokenizer from verl.utils.activation_offload import enable_activation_offloading - -# from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.utils.debug import log_gpu_memory_usage from verl.utils.device import ( get_device_id, @@ -863,6 +861,10 @@ def clear_optimizer_state(self): if self._is_offload_optimizer: offload_fsdp_optimizer(self.actor_optimizer) + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def wait_for_saving(self) -> None: + self.checkpoint_manager.wait_for_saving() + class CriticWorker(Worker): def __init__(self, config): @@ -1288,3 +1290,7 @@ def clear_optimizer_state(self): self.critic_optimizer.zero_grad() if self._is_offload_optimizer: offload_fsdp_optimizer(self.critic_optimizer) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def wait_for_saving(self) -> None: + self.checkpoint_manager.wait_for_saving() diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index b9ff7df2d2..8204dd4ced 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -496,4 +496,6 @@ def sft_to_rft(self) -> None: print("sft to rft finished") def 
shutdown(self) -> None: - pass + self.actor_rollout_wg.wait_for_saving() + if self.algorithm.use_critic: + self.critic_wg.wait_for_saving() diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 6ba4d7482d..ab363e5199 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -110,8 +110,8 @@ def __init__( group = name self.logger = wandb.init( project=project, - group=group, - name=f"{name}_{role}", + group=f"{group}_{role}", + name=name, tags=[role], config=config, save_code=False, From a68926230d62931301c476f42e466341c803a9ef Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 22 Jul 2025 15:59:46 +0800 Subject: [PATCH 07/16] add doc string for `Synchronizer` and `FSDPCheckpointManager` --- trinity/common/synchronizer.py | 93 ++++++++++++++++++- .../trainer/verl/fsdp_checkpoint_manager.py | 11 +++ 2 files changed, 100 insertions(+), 4 deletions(-) diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py index 261b02c04a..a6eb06a954 100644 --- a/trinity/common/synchronizer.py +++ b/trinity/common/synchronizer.py @@ -17,27 +17,47 @@ class Synchronizer: + """ + A central component to manage synchronization of models and states between + the trainer and one or more explorers in a distributed training setup. + + Attributes: + trainer_status: Current status of the trainer (e.g., running, waiting). + explorer_status_counter: Dictionary tracking the number of explorers in each status. + _ready_condition: Async condition variable for signaling state changes. + model_state_dict: The latest model weights. + model_version: Version number of the current model. + checkpoint_shard_counter: Tracks how many shards are received from trainer for a specific train step. + """ + def __init__(self, config: Config): self.logger = get_logger(__name__) self.config = config self.trainer_status = RunningStatus.RUNNING - self.last_trainer_sync_step = 0 self.explorer_status_counter: Dict[RunningStatus, int] = {} - self.last_explorer_sync_step = 0 self._ready_condition = asyncio.Condition() self.model_state_dict = None self.model_version = 0 self.checkpoint_shard_counter = defaultdict(lambda: 0) def set_trainer_status(self, status: RunningStatus): + """Update the status of the trainer.""" self.trainer_status = status def get_trainer_status(self) -> RunningStatus: + """Get the current status of the trainer.""" return self.trainer_status def set_explorer_status( self, status: RunningStatus, old_status: Optional[RunningStatus] = None ): + """ + Update the status count for an explorer. + + Args: + status: New status of the explorer. + old_status: Previous status if changing from one to another. + """ if old_status is not None: assert ( old_status in self.explorer_status_counter @@ -50,12 +70,23 @@ def set_explorer_status( self.explorer_status_counter[status] += 1 def get_explorer_status_counter(self) -> Dict[RunningStatus, int]: + """Return the current status counts for all explorers.""" return self.explorer_status_counter async def set_model_state_dict_with_step_num( self, step_num: Optional[int] = None, world_size: Optional[int] = None ) -> int: - if world_size is not None: # Used for trainer to update model + """ + Load and set the model state dictionary from a checkpoint at a specific step. + + Args: + step_num: Training step number corresponding to the checkpoint. + world_size: Number of shards expected for this checkpoint. + + Returns: + The updated model version (step number). 
+ """ + if world_size is not None: # Used when trainer updates the model assert step_num is not None self.checkpoint_shard_counter[step_num] += 1 self.logger.info( @@ -74,6 +105,13 @@ async def set_model_state_dict_with_step_num( return checkpoint_step_num async def set_model_state_dict(self, model_state_dict, trainer_step): + """ + Set the new model state and update the version. + + Args: + model_state_dict: The PyTorch model state dictionary. + trainer_step: Step number associated with this model version. + """ self.model_state_dict = model_state_dict async with self._ready_condition: self.model_version = trainer_step @@ -81,9 +119,16 @@ async def set_model_state_dict(self, model_state_dict, trainer_step): self._ready_condition.notify_all() def get_model_state_dict(self): + """Return the current model state and its version.""" return self.model_state_dict, self.model_version def get_state_dict_meta(self): + """ + Return metadata about the model state (names, data types, shapes). + + Returns: + List of tuples: (name, dtype, shape). + """ if self.model_state_dict is None: return None update_weight_args_list = [] @@ -94,11 +139,29 @@ def get_state_dict_meta(self): async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None ): + """ + Notify the explorer actor to setup weight sync group. + + This is used to initialize NCCL-based synchronization for distributed training. + + Args: + master_address: IP address of the master node. + master_port: Port used for synchronization. + state_dict_meta: Metadata of the model parameters. + """ explorer = ray.get_actor(self.config.explorer_name) await explorer.setup_weight_sync_group.remote(master_address, master_port, state_dict_meta) async def wait_new_model_state_dict(self, current_version: int) -> int: - # wait for the new model state dict; return new version + """ + Wait until a new model state is available. + + Args: + current_version: Current model version known to one explorer. + + Returns: + The new model version after it has been updated. + """ async with self._ready_condition: if self.model_version <= current_version: self.set_explorer_status( @@ -113,6 +176,18 @@ async def wait_new_model_state_dict(self, current_version: int) -> int: async def ready_to_nccl_sync( self, module: str, trainer_step: Optional[int] = None ) -> Union[int, None]: + """ + Prepare for NCCL-based synchronization between modules. + + Only supports one explorer currently. + + Args: + module: Either 'trainer' or 'explorer'. + trainer_step: Optional step number from the trainer. + + Returns: + The model version if both sides are ready; otherwise None. + """ assert ( sum(self.explorer_status_counter.values()) == 1 ), "NCCL sync is only supported for one explorer." @@ -152,6 +227,16 @@ async def ready_to_nccl_sync( @classmethod def get_actor(cls, config: Optional[Config] = None, namespace: Optional[str] = None): + """ + Get or create a remote Ray actor for the Synchronizer. + + Args: + config: Optional configuration to use for creating the actor. + namespace: Optional Ray namespace for the actor. + + Returns: + A reference to the Synchronizer actor. 
+ """ if config is not None: return ( ray.remote(cls) diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index 1ed644cc93..9f9ca02e0d 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -47,6 +47,12 @@ def __init__(self, *args, **kwargs): self._save_model_thread = None def upload_state_dict(self, trainer_step: int): + """ + Uploads the full model state dictionary to the synchronizer actor for remote access. + + Args: + trainer_step (int): The current training step number. + """ assert self.synchronizer is not None state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with get_fsdp_state_ctx(self.model, StateDictType.FULL_STATE_DICT, state_dict_config, None): @@ -64,6 +70,8 @@ def save_checkpoint( # noqa: C901 ): """ modified from verl.utils.checkpoint.fsdp_checkpoint_manager.py:save_checkpoint + + The main improvement is using separate thread to save checkpoint and the implementation of checkpoint sync method. """ if global_step == 0 and model_state_dict_only: ray.get(self.synchronizer.set_model_state_dict.remote(None, global_step)) @@ -296,6 +304,9 @@ def _save_model(): self.previous_saved_paths.append(local_path) def wait_for_saving(self) -> None: + """ + Wait for all background saving threads to complete. + """ if self._model_state_dict_thread is not None: self._model_state_dict_thread.join() if self._optimizer_state_dict_thread is not None: From ac0b2f0ca2d957ab89a12dd699f10ed17869a8cc Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 24 Jul 2025 16:19:21 +0800 Subject: [PATCH 08/16] 1. Bug fix in `trainer_test` and `explorer_test` 2. Fix shutdown in `both` 3. Refactored the internal status transition logic of `Trainer` and `Explorer` in`Synchronizer`. 4. Avoid duplicate model saving. 5. Bug fix where model was exited before it was saved. 
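
Item 3 in effect gives each explorer the cycle sketched below. This is a rough, runnable illustration of the intended happy-path transitions only; the actual changes go through `Synchronizer.set_explorer_status` via Ray and include failure paths that fall back to REQUIRE_SYNC:

    # RUNNING --need_sync()--> REQUIRE_SYNC --weights ready--> WAITING_SYNC
    #   --weights applied--> RUNNING, until exploration ends in STOPPED.
    NEXT_STATUS = {
        "RUNNING": "REQUIRE_SYNC",       # need_sync() decided a sync is due
        "REQUIRE_SYNC": "WAITING_SYNC",  # wait_new_model_state_dict / ready_to_nccl_sync
        "WAITING_SYNC": "RUNNING",       # sync_model finished, last_sync_successful = True
    }

    status = "RUNNING"
    for _ in range(3):                   # one full sync round
        status = NEXT_STATUS[status]
    assert status == "RUNNING"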
--- tests/common/config_test.py | 4 - tests/tools.py | 4 +- tests/trainer/trainer_test.py | 5 +- trinity/cli/launcher.py | 7 +- trinity/common/synchronizer.py | 70 ++++++++++----- trinity/explorer/explorer.py | 85 +++++++++---------- trinity/explorer/scheduler.py | 2 +- trinity/trainer/trainer.py | 38 ++++----- .../trainer/verl/fsdp_checkpoint_manager.py | 68 ++++++++------- trinity/trainer/verl_trainer.py | 15 ++-- trinity/utils/monitor.py | 4 +- 11 files changed, 167 insertions(+), 135 deletions(-) diff --git a/tests/common/config_test.py b/tests/common/config_test.py index db7190856f..f3eea28740 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -35,10 +35,6 @@ def test_load_default_config(self): ) self.assertEqual(config.model.model_path, config.model.critic_model_path) self.assertEqual(config.model.model_path, config.explorer.rollout_model.model_path) - self.assertEqual( - config.trainer.trainer_config.trainer.save_freq, - config.synchronizer.sync_interval, - ) def test_all_examples_are_valid(self): example_dir = os.path.join(os.path.dirname(__file__), "..", "..", "examples") diff --git a/tests/tools.py b/tests/tools.py index 60df1122f2..b89607138b 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -18,7 +18,9 @@ def get_template_config() -> Config: config_path = os.path.join(os.path.dirname(__file__), "template", "config.yaml") - return load_config(config_path) + config = load_config(config_path) + config.ray_namespace = ray.get_runtime_context().namespace + return config def get_model_path() -> str: diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index a697539141..980ba3d56e 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -22,7 +22,7 @@ ) from trinity.cli.launcher import bench, both, explore, train from trinity.common.config import Config, StorageConfig -from trinity.common.constants import StorageType, SyncMethod +from trinity.common.constants import StorageType, SyncMethod, SyncStyle from trinity.common.models.utils import get_checkpoint_dir_with_step_num from trinity.manager.manager import CacheManager @@ -99,7 +99,7 @@ def test_trainer(self): self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0) self.assertEqual(step_num, 8) # TODO: Reinit will fail when using v1 engine, find a way to fix it - ray.init(ignore_reinit_error=True) + ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace) # test bench mode self.config.mode = "bench" self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT @@ -332,6 +332,7 @@ def test_fully_async_mode(self, name, use_priority_queue): use_priority_queue=use_priority_queue, ) config.synchronizer.sync_method = SyncMethod.CHECKPOINT + config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER config.synchronizer.sync_interval = 8 config.monitor.monitor_type = "tensorboard" trainer_config = deepcopy(config) diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 76830a125f..66419b2cec 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -150,8 +150,11 @@ def both(config: Config) -> None: "============================================================" ) ray.wait(wait_ref, timeout=config.synchronizer.sync_timeout) - explorer.shutdown.remote() - trainer.shutdown.remote() + ray.wait( + [explorer.shutdown.remote(), trainer.shutdown.remote()], + timeout=config.synchronizer.sync_timeout, + num_returns=2, + ) def run(config_path: str, dlc: bool = False, plugin_dir: str = None): diff --git 
a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py index a6eb06a954..144edb5298 100644 --- a/trinity/common/synchronizer.py +++ b/trinity/common/synchronizer.py @@ -33,8 +33,8 @@ class Synchronizer: def __init__(self, config: Config): self.logger = get_logger(__name__) self.config = config - self.trainer_status = RunningStatus.RUNNING - self.explorer_status_counter: Dict[RunningStatus, int] = {} + self.trainer_status = RunningStatus.STOPPED + self.explorer_status_counter: Dict[RunningStatus, int] = defaultdict(lambda: 0) self._ready_condition = asyncio.Condition() self.model_state_dict = None self.model_version = 0 @@ -62,9 +62,11 @@ def set_explorer_status( assert ( old_status in self.explorer_status_counter ), f"Invalid explorer status {old_status}" - assert old_status != status + assert old_status != status, f"Invalid status change from {old_status} to {status}" self.explorer_status_counter[old_status] -= 1 - assert self.explorer_status_counter[old_status] >= 0 + assert ( + self.explorer_status_counter[old_status] >= 0 + ), f"Invalid status count {old_status} (new status {status})" if status not in self.explorer_status_counter: self.explorer_status_counter[status] = 0 self.explorer_status_counter[status] += 1 @@ -88,9 +90,10 @@ async def set_model_state_dict_with_step_num( """ if world_size is not None: # Used when trainer updates the model assert step_num is not None + assert self.checkpoint_shard_counter[step_num] < world_size, "World size mismatch!" self.checkpoint_shard_counter[step_num] += 1 self.logger.info( - f"Synchronizer received checkpoint {self.checkpoint_shard_counter[step_num]} of {world_size} shards" + f"Synchronizer has received {self.checkpoint_shard_counter[step_num]} out of {world_size} shards from the checkpoint {step_num}." ) if self.checkpoint_shard_counter[step_num] < world_size: return step_num @@ -100,11 +103,14 @@ async def set_model_state_dict_with_step_num( trainer_type=self.config.trainer.trainer_type, step_num=step_num, ) - model_state_dict = load_state_dict(os.path.join(checkpoint_dir, "actor")) # TODO: to thread - await self.set_model_state_dict(model_state_dict, checkpoint_step_num) + if checkpoint_step_num != self.model_version: + model_state_dict = load_state_dict( + os.path.join(checkpoint_dir, "actor") + ) # TODO: to thread + await self.set_model_state_dict(model_state_dict, checkpoint_step_num) return checkpoint_step_num - async def set_model_state_dict(self, model_state_dict, trainer_step): + async def set_model_state_dict(self, model_state_dict: Union[dict, None], trainer_step: int): """ Set the new model state and update the version. @@ -152,7 +158,7 @@ async def setup_weight_sync_group( explorer = ray.get_actor(self.config.explorer_name) await explorer.setup_weight_sync_group.remote(master_address, master_port, state_dict_meta) - async def wait_new_model_state_dict(self, current_version: int) -> int: + async def wait_new_model_state_dict(self, current_version: int, no_wait: bool = False) -> int: """ Wait until a new model state is available. @@ -163,14 +169,21 @@ async def wait_new_model_state_dict(self, current_version: int) -> int: The new model version after it has been updated. """ async with self._ready_condition: - if self.model_version <= current_version: + assert ( + self.model_version >= current_version + ), f"The model version in Synchronizer ({self.model_version}) should be greater than that in Explorer ({current_version})!" 
+ if self.model_version == current_version: + if not no_wait and self.trainer_status != RunningStatus.STOPPED: + # TODO: explorer need support no wait + # TODO: handle timeout + await asyncio.wait_for( + self._ready_condition.wait(), + timeout=self.config.synchronizer.sync_timeout, + ) + if self.model_version > current_version: self.set_explorer_status( RunningStatus.WAITING_SYNC, old_status=RunningStatus.REQUIRE_SYNC ) - await asyncio.wait_for( - self._ready_condition.wait(), - timeout=self.config.synchronizer.sync_timeout, - ) return self.model_version async def ready_to_nccl_sync( @@ -191,6 +204,29 @@ async def ready_to_nccl_sync( assert ( sum(self.explorer_status_counter.values()) == 1 ), "NCCL sync is only supported for one explorer." + + def sync_failed(): + if module == "explorer": + another_module = "Trainer" + self.set_explorer_status( + RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.WAITING_SYNC + ) + else: + another_module = "Explorer" + self.trainer_status = RunningStatus.REQUIRE_SYNC + self.logger.error( + f"{another_module} is not ready for model weight sync in {self.config.synchronizer.sync_timeout} seconds." + ) + return None + + non_stop_cnt = sum( + value + for key, value in self.explorer_status_counter.items() + if key != RunningStatus.STOPPED + ) + if non_stop_cnt == 0: + return sync_failed() + # for status in RunningStatus: async with self._ready_condition: try: if module == "trainer": @@ -219,11 +255,7 @@ async def ready_to_nccl_sync( ) return self.model_version except asyncio.TimeoutError: - another_module = "Trainer" if module == "explorer" else "Explorer" - self.logger.error( - f"{another_module} is not ready for model weight sync in {self.config.synchronizer.sync_timeout} seconds." - ) - return None + return sync_failed() @classmethod def get_actor(cls, config: Optional[Config] = None, namespace: Optional[str] = None): diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 2ce4f54876..7caa555f5c 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -72,12 +72,8 @@ def __init__(self, config: Config): # For checkpoint weights update # Use explorer to periodically load the latest model weights and # boradcast to all rollout models - self.model_version = 0 - if self.use_state_dict_weights_update: - self.old_checkpoint = None - self.state_dict = {} - else: # nccl mode - self.state_dict_meta = [] + self.model_version = -1 + self.last_sync_successful = True self.logger.info("Finished initializing Explorer.") self.collect_experiences = self.config.explorer.collect_experiences self.generated_experience_cnt = 0 @@ -102,7 +98,6 @@ async def setup_weight_sync_group( f"master_address={master_address}, master_port={master_port}, " f"world_size={world_size}, rank_offset={base_offset}" ) - self.state_dict_meta = state_dict_meta # TODO: save state_dict in models refs = [ model.init_process_group.remote( @@ -130,21 +125,6 @@ def _init_scheduler(self) -> Scheduler: ) return Scheduler(self.config, self.models, self.auxiliary_models) - async def _update_model_weight(self, step_num: int, state_dict: dict) -> None: - # TODO: update model weight - self.state_dict = state_dict - if self.state_dict_meta is None: - update_weight_args_list = [] - for name, param in state_dict.items(): - update_weight_args_list.append((name, str(param.dtype), tuple(param.shape))) - self.state_dict_meta = update_weight_args_list - else: - update_weight_args_list = None - await asyncio.gather( - *[model.sync_model.remote(step_num, update_weight_args_list) 
for model in self.models] - ) - self.state_dict.clear() - async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int: step_num = ray.get(self.synchronizer.set_model_state_dict_with_step_num.remote(step_num)) await asyncio.gather(*[model.sync_model.remote(step_num) for model in self.models]) @@ -156,29 +136,49 @@ async def _state_dict_update(self): self.synchronizer.wait_new_model_state_dict.remote(self.model_version) ) if new_version > self.model_version: - self.logger.info(f"New model state dict version: {new_version}") - await asyncio.gather(*[model.sync_model.remote(new_version) for model in self.models]) + if self.model_version != -1: + self.logger.info(f"New model state dict version: {new_version}") + await asyncio.gather( + *[model.sync_model.remote(new_version) for model in self.models] + ) self.model_version = new_version + self.last_sync_step = self.explore_step_num + ray.get( + self.synchronizer.set_explorer_status.remote( + RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC + ) + ) + self.last_sync_successful = True else: self.logger.warning( f"No new model state dict found, current version: {self.model_version}" ) + self.last_sync_successful = False async def _nccl_weights_update(self): - assert self.state_dict_meta is not None new_version = ray.get( self.synchronizer.ready_to_nccl_sync.remote("explorer", self.model_version) ) if new_version is None: self.logger.info("Trainer is not ready to sync weight. Skipping sync weight.") + self.last_sync_successful = False return self.model_version = new_version await asyncio.gather( - *[model.sync_model.remote(self.explore_step_num) for model in self.models] + *[model.sync_model.remote(self.model_version) for model in self.models] ) + self.last_sync_step = self.explore_step_num + ray.get( + self.synchronizer.set_explorer_status.remote( + RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC + ) + ) + self.last_sync_successful = True async def prepare(self) -> None: """Preparation before running.""" + if self.experience_buffer: + await self.experience_buffer.acquire() futures = [asyncio.create_task(self.scheduler.start())] if self.use_state_dict_weights_update: master_address, master_port = await self.models[0].get_available_address.remote() @@ -186,11 +186,9 @@ async def prepare(self) -> None: asyncio.create_task(self.setup_weight_sync_group(master_address, master_port)) ) asyncio.gather(*futures, return_exceptions=True) - await self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC) - if self.experience_buffer: - await self.experience_buffer.acquire() if self.config.explorer.eval_on_startup and self.explore_step_num == 0: self.eval() + await self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC) async def get_weight(self, name: str) -> torch.Tensor: """Get the weight of the loaded model (For checkpoint weights update).""" @@ -237,7 +235,10 @@ async def explore_step(self) -> bool: self.logger.warning("No more tasks to explore. 
Stop exploring.") await self.save_checkpoint(sync_weight=False) await self.synchronizer.set_explorer_status.remote( - RunningStatus.STOPPED, old_status=RunningStatus.RUNNING + RunningStatus.STOPPED, + old_status=RunningStatus.RUNNING + if self.last_sync_successful + else RunningStatus.REQUIRE_SYNC, ) await self.experience_buffer.release() return False @@ -249,7 +250,7 @@ def need_sync(self) -> bool: if self.config.synchronizer.sync_style == SyncStyle.FIXED: if self.explore_step_num <= self.config.synchronizer.sync_offset: return False - return ( + require_sync = ( self.explore_step_num - self.config.synchronizer.sync_offset ) % self.config.synchronizer.sync_interval == 0 else: @@ -263,13 +264,13 @@ def need_sync(self) -> bool: ray.get(self.synchronizer.get_trainer_status.remote()) == RunningStatus.REQUIRE_SYNC ) - if require_sync: - ray.get( - self.synchronizer.set_explorer_status.remote( - RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING - ) + if require_sync and self.last_sync_successful: + ray.get( + self.synchronizer.set_explorer_status.remote( + RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING ) - return require_sync + ) + return require_sync def need_eval(self) -> bool: return self.explore_step_num % self.config.explorer.eval_interval == 0 @@ -338,8 +339,9 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None: await self._state_dict_update() else: # nccl weights update await self._nccl_weights_update() - self.last_sync_step = self.explore_step_num - self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} finished") + self.logger.info( + f"Explorer sync_weights at step {self.explore_step_num} finished, model version = {self.model_version}." + ) # overlay log and weight sync await log_task @@ -354,11 +356,6 @@ async def sync_weight(self) -> None: """Synchronize model weights.""" # call this method before training start to load the latest model weights await self.save_checkpoint(sync_weight=True) - ray.get( - self.synchronizer.set_explorer_status.remote( - RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC - ) - ) async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None: for step in range(start_step, end_step + 1): diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 5580d20fa0..2c3cb0c5fb 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -227,7 +227,7 @@ def task_done_callback(self, async_task: asyncio.Task): if async_task.cancelled(): return elif async_task.exception(): - self.logger.error(f"Task {task.task_id} failed: {async_task.exception()}") + self.logger.error(f"Task {task.task.task_id} failed: {async_task.exception()}") return else: status, exps, runner_id = async_task.result() diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 3195532efb..ff1248f8d6 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -4,7 +4,6 @@ """ from __future__ import annotations -import os import traceback from abc import ABC, abstractmethod @@ -30,6 +29,7 @@ def prepare(self) -> None: """Prepare the trainer.""" self.engine.prepare() self.last_trainer_sync_step = self.engine.train_step_num + ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) def train(self) -> str: """Train the model.""" @@ -63,44 +63,40 @@ def need_sync(self) -> bool: delta = self.engine.train_step_num - self.last_trainer_sync_step if delta >= self.config.synchronizer.sync_interval: 
ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC)) - return ( - ray.get(self.synchronizer.get_explorer_status_counter.remote())[ - RunningStatus.WAITING_SYNC - ] - > 0 + explorer_status_counter = ray.get( + self.synchronizer.get_explorer_status_counter.remote() ) + if self.config.synchronizer.sync_method == SyncMethod.NCCL: + return explorer_status_counter[RunningStatus.WAITING_SYNC] > 0 + else: # memory & checkpoint + return explorer_status_counter[RunningStatus.REQUIRE_SYNC] > 0 def sync_weight(self) -> None: """Sync the model weight.""" + self.logger.info( + f"Trainer synchronizing weights at step {self.engine.train_step_num} starting.." + ) if self.config.synchronizer.sync_method == SyncMethod.NCCL: - self.logger.info( - f"Trainer synchronizing weights at step {self.engine.train_step_num} starting.." - ) result = ray.get( self.synchronizer.ready_to_nccl_sync.remote("trainer", self.engine.train_step_num) ) if result is None: self.logger.error("Trainer synchronizing weights failed.") - raise Exception - self.engine.sync_weight() - self.logger.info( - f"Trainer synchronizing weights at step {self.engine.train_step_num} end." - ) - self.last_trainer_sync_step = self.engine.train_step_num + else: + self.engine.sync_weight() + self.last_trainer_sync_step = self.engine.train_step_num elif self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT: self.engine.save_state_dict() - # ray.get(self.synchronizer.set_model_state_dict_with_step_num.remote(self.engine.train_step_num)) elif self.config.synchronizer.sync_method == SyncMethod.MEMORY: self.engine.upload_state_dict() + self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num} end.") ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) def shutdown(self) -> None: - # if checkpoint not saved, save the last checkpoint - step_num = self.engine.train_step_num - path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{step_num}") - if not os.path.isdir(path) or len(os.listdir(path)) == 0: - self.engine.save_checkpoint() + ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED)) + self.engine.save_checkpoint() self.engine.monitor.close() + self.engine.shutdown() class TrainEngineWrapper(ABC): diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index 9f9ca02e0d..5f70fb7d6c 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -3,7 +3,7 @@ import threading import warnings from dataclasses import asdict -from typing import Optional +from typing import Optional, Union import ray import torch @@ -46,19 +46,30 @@ def __init__(self, *args, **kwargs): self._extra_state_dict_thread = None self._save_model_thread = None - def upload_state_dict(self, trainer_step: int): + def _notify_synchronizer_with_step_num(self, global_step): + if getattr(self.synchronizer_config, "sync_method", None) == SyncMethod.CHECKPOINT: + ray.get( + self.synchronizer.set_model_state_dict_with_step_num.remote( + global_step, self.world_size + ) + ) + + def _upload_state_dict(self, state_dict: Union[dict, None], global_step: int): + if self.rank == 0: + ray.get(self.synchronizer.set_model_state_dict.remote(state_dict, global_step)) + + def upload_state_dict(self, global_step: int): """ Uploads the full model state dictionary to the synchronizer actor for remote access. Args: - trainer_step (int): The current training step number. 
+ global_step (int): The current training step number. """ assert self.synchronizer is not None state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with get_fsdp_state_ctx(self.model, StateDictType.FULL_STATE_DICT, state_dict_config, None): state_dict = self.model.state_dict() - if self.rank == 0: - ray.get(self.synchronizer.set_model_state_dict.remote(state_dict, trainer_step)) + self._upload_state_dict(state_dict, global_step) def save_checkpoint( # noqa: C901 self, @@ -74,7 +85,7 @@ def save_checkpoint( # noqa: C901 The main improvement is using separate thread to save checkpoint and the implementation of checkpoint sync method. """ if global_step == 0 and model_state_dict_only: - ray.get(self.synchronizer.set_model_state_dict.remote(None, global_step)) + self._upload_state_dict(None, global_step) return if local_path is None: return @@ -125,32 +136,29 @@ def save_checkpoint( # noqa: C901 local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt" ) - if self.should_save_model: - model_state_dict = self.model.state_dict() - if self._model_state_dict_thread is not None: - self._model_state_dict_thread.join() - - def _save_model_state_dict(): - torch.save(model_state_dict, model_path) - log_with_rank( - f"Saved model to {os.path.abspath(model_path)}", - rank=self.rank, - logger=logger, - ) - if ( - self.synchronizer_config is not None - and self.synchronizer_config.sync_method == SyncMethod.CHECKPOINT - ): - ray.get( - self.synchronizer.set_model_state_dict_with_step_num.remote( - global_step, self.world_size - ) + if self.should_save_model or model_state_dict_only: + if os.path.exists(model_path): + if self._model_state_dict_thread is None: + # resume from checkpoint, so we can directly notify synchronizer. + self._notify_synchronizer_with_step_num(global_step) + else: + model_state_dict = self.model.state_dict() + if self._model_state_dict_thread is not None: + self._model_state_dict_thread.join() + + def _save_model_state_dict(): + torch.save(model_state_dict, model_path) + log_with_rank( + f"Saved model to {os.path.abspath(model_path)}", + rank=self.rank, + logger=logger, ) + self._notify_synchronizer_with_step_num(global_step) - self._model_state_dict_thread = threading.Thread( - target=_save_model_state_dict, - ) - self._model_state_dict_thread.start() + self._model_state_dict_thread = threading.Thread( + target=_save_model_state_dict, + ) + self._model_state_dict_thread.start() if self.should_save_optimizer and not model_state_dict_only: optimizer_state_dict = self.optimizer.state_dict() diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 8204dd4ced..d00fdf4c9d 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -156,6 +156,7 @@ def __init__( ) self.reset_experiences_example_table() self.logger = get_logger(__name__) + self.last_full_save_step = None def _validate_config(self): # TODO algorithm = ALGORITHM_TYPE.get(self.algorithm_config.algorithm_type) @@ -314,13 +315,6 @@ def train_step(self) -> bool: # noqa C901 prefix_metrics(sample_metrics, "sample", metrics) except StopIteration: print("No more data to train. 
Stop training.") - if ( - self.config.trainer.save_freq == 0 - or self.global_steps % self.config.trainer.save_freq != 0 - ): - self.logger.info(f"Saving at step {self.global_steps}.") - self._save_checkpoint() - self.logger.info(f"Saved at step {self.global_steps}.") return False self.global_steps += 1 self.logger.info(f"Sampling at step {self.global_steps} done.") @@ -408,7 +402,7 @@ def train_step(self) -> bool: # noqa C901 ): self.logger.info(f"Saving at step {self.global_steps}.") with marked_timer("save_checkpoint", timing_raw): - self._save_checkpoint() + self.save_checkpoint() self.logger.info(f"Saved at step {self.global_steps}.") self.logger.info(f"Training at step {self.global_steps} finished.") return train_status @@ -447,7 +441,9 @@ def _log_experiences(self, samples: List[Dict]) -> None: self.reset_experiences_example_table() def save_checkpoint(self) -> None: - self._save_checkpoint() + if self.last_full_save_step != self.global_steps: + self.last_full_save_step = self.global_steps + self._save_checkpoint() def sync_weight(self) -> None: self.actor_rollout_wg.sync_weight() @@ -464,6 +460,7 @@ def sft_to_rft(self) -> None: global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest # find global_step_folder + self.actor_rollout_wg.wait_for_saving() if self.config.trainer.resume_mode == "auto": if global_step_folder is None: print("Training from scratch") diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index ab363e5199..6ba4d7482d 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -110,8 +110,8 @@ def __init__( group = name self.logger = wandb.init( project=project, - group=f"{group}_{role}", - name=name, + group=group, + name=f"{name}_{role}", tags=[role], config=config, save_code=False, From c7e1e4c84b452899153a3ed32aef20f96a21fd75 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 24 Jul 2025 17:16:33 +0800 Subject: [PATCH 09/16] 1. Remove `update_with_checkpoint` 2. Rename `use_state_dict_weights_update` to `use_nccl_sync` 3. Add more explanation to `fsdp_checkpoint_manager`. 
--- tests/explorer/scheduler_test.py | 2 - trinity/common/models/vllm_model.py | 2 - trinity/common/models/vllm_worker.py | 4 +- trinity/explorer/explorer.py | 15 +++---- .../trainer/verl/fsdp_checkpoint_manager.py | 44 +++++++++++++++++-- 5 files changed, 49 insertions(+), 18 deletions(-) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 6a88e2a125..678be80a28 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -69,7 +69,6 @@ def init_process_group( group_name: str, backend: str = "nccl", timeout: int = 1200, - update_with_checkpoint: bool = True, ) -> None: pass @@ -91,7 +90,6 @@ def init_process_group( group_name: str, backend: str = "nccl", timeout: int = 1200, - update_with_checkpoint: bool = True, ) -> None: pass diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 8ab301efa1..6fdbbc208f 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -278,7 +278,6 @@ async def init_process_group( explorer_name: str, backend: str = "nccl", timeout: int = 1200, - update_with_checkpoint: bool = True, state_dict_meta: dict = None, ): return await self._collective_rpc( @@ -291,7 +290,6 @@ async def init_process_group( group_name, backend, timeout, - update_with_checkpoint, state_dict_meta, explorer_name, ray.get_runtime_context().namespace, diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 0688544e81..9835cd6d15 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -21,7 +21,6 @@ def init_process_group( group_name: str, backend: str = "nccl", timeout: int = 1200, - update_with_checkpoint: bool = True, state_dict_meta: list = None, explorer_name: str = None, namespace: str = None, @@ -30,10 +29,9 @@ def init_process_group( assert torch.distributed.is_initialized(), "default torch process group must be initialized" assert group_name != "", "group name must not be empty" self._state_dict_meta = state_dict_meta - self._update_with_checkpoint = update_with_checkpoint self._weight_update_rank = torch.distributed.get_rank() + rank_offset logger.info( - f"vLLM starting init_process_group ({'checkpoint' if self._update_with_checkpoint else 'nccl'}):\n" + f"vLLM starting init_process_group:\n" f" > address={master_address}:{master_port}\n" f" > rank={torch.distributed.get_rank()}\n" f" > rank_offset={rank_offset}\n" diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 7caa555f5c..313bec40aa 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -66,7 +66,7 @@ def __init__(self, config: Config): self.update_interval = ( self.config.synchronizer.sync_interval * self.config.buffer.batch_size ) - self.use_state_dict_weights_update = self.config.synchronizer.sync_method != SyncMethod.NCCL + self.use_nccl_sync = self.config.synchronizer.sync_method == SyncMethod.NCCL self.pending_eval_tasks = deque() # For checkpoint weights update @@ -89,7 +89,7 @@ async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None ): # In checkpoint mode, we use explorer to store the model weights which has no rank - base_offset = 0 if self.use_state_dict_weights_update else 1 + base_offset = 1 if self.use_nccl_sync else 0 world_size = ( len(self.models) * self.config.explorer.rollout_model.tensor_parallel_size + base_offset ) @@ -109,7 +109,6 @@ async def setup_weight_sync_group( 
group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, explorer_name=self.config.explorer.name, timeout=self.config.synchronizer.sync_timeout, - update_with_checkpoint=self.use_state_dict_weights_update, state_dict_meta=state_dict_meta, ) for i, model in enumerate(self.models) @@ -130,7 +129,7 @@ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> in await asyncio.gather(*[model.sync_model.remote(step_num) for model in self.models]) return step_num # type: ignore - async def _state_dict_update(self): + async def _pull_latest_weights(self): self.logger.info("Start to update state dict.") new_version = ray.get( self.synchronizer.wait_new_model_state_dict.remote(self.model_version) @@ -180,7 +179,7 @@ async def prepare(self) -> None: if self.experience_buffer: await self.experience_buffer.acquire() futures = [asyncio.create_task(self.scheduler.start())] - if self.use_state_dict_weights_update: + if not self.use_nccl_sync: master_address, master_port = await self.models[0].get_available_address.remote() futures.append( asyncio.create_task(self.setup_weight_sync_group(master_address, master_port)) @@ -335,10 +334,10 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None: if sync_weight: # sync weights self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.") - if self.use_state_dict_weights_update: - await self._state_dict_update() - else: # nccl weights update + if self.use_nccl_sync: await self._nccl_weights_update() + else: # pull weights from Synchronizer + await self._pull_latest_weights() self.logger.info( f"Explorer sync_weights at step {self.explore_step_num} finished, model version = {self.model_version}." ) diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index 5f70fb7d6c..e558b025b2 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -33,20 +33,39 @@ class FSDPCheckpointManager(OldFSDPCheckpointManager): + """ + An enhanced version of the original FSDP checkpoint manager that: + + 1. Uploads model state dicts to a remote Synchronizer actor (either directly or via checkpoints). + 2. Offloads saving operations (model, optimizer, extra states) into background threads to avoid blocking the training loop. + + This class is useful in distributed training scenarios where synchronization and non-blocking I/O are important. + """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) config = kwargs.pop("config", None) self.synchronizer_config = config if config is not None: + # Retrieve the remote Synchronizer actor using the provided namespace self.synchronizer = Synchronizer.get_actor(namespace=config.ray_namespace) else: self.synchronizer = None + + # Threads for asynchronous saving of different components self._model_state_dict_thread = None self._optimizer_state_dict_thread = None self._extra_state_dict_thread = None self._save_model_thread = None def _notify_synchronizer_with_step_num(self, global_step): + """ + Notifies the Synchronizer actor about the current training step number, + used when SyncMethod is CHECKPOINT. + + Args: + global_step (int): The current global training step. 
+ """ if getattr(self.synchronizer_config, "sync_method", None) == SyncMethod.CHECKPOINT: ray.get( self.synchronizer.set_model_state_dict_with_step_num.remote( @@ -55,6 +74,13 @@ def _notify_synchronizer_with_step_num(self, global_step): ) def _upload_state_dict(self, state_dict: Union[dict, None], global_step: int): + """ + Internal method to upload a state dict to the Synchronizer actor. + + Args: + state_dict (dict or None): The model state dictionary to upload. + global_step (int): The current training step number. + """ if self.rank == 0: ray.get(self.synchronizer.set_model_state_dict.remote(state_dict, global_step)) @@ -80,9 +106,21 @@ def save_checkpoint( # noqa: C901 model_state_dict_only: bool = False, ): """ - modified from verl.utils.checkpoint.fsdp_checkpoint_manager.py:save_checkpoint + Modified from verl.utils.checkpoint.fsdp_checkpoint_manager.py:save_checkpoint + + Saves the model checkpoint to disk, optionally uploads it to a remote Synchronizer, + and uses background threads to prevent blocking the main training loop. - The main improvement is using separate thread to save checkpoint and the implementation of checkpoint sync method. + Main improvements over the base class: + - Uses separate threads for saving model/optimizer/extras. + - Implements synchronization with a remote actor. + + Args: + local_path (str): Local directory path to save the checkpoint. + hdfs_path (str, optional): HDFS path for saving the checkpoint (not implemented here). + global_step (int): Current training step. + max_ckpt_to_keep (int, optional): Maximum number of checkpoints to keep locally. + model_state_dict_only (bool): Whether to only save the model state dict (no optimizer, etc.). """ if global_step == 0 and model_state_dict_only: self._upload_state_dict(None, global_step) @@ -139,7 +177,7 @@ def save_checkpoint( # noqa: C901 if self.should_save_model or model_state_dict_only: if os.path.exists(model_path): if self._model_state_dict_thread is None: - # resume from checkpoint, so we can directly notify synchronizer. + # If resuming from a checkpoint, notify synchronizer immediately self._notify_synchronizer_with_step_num(global_step) else: model_state_dict = self.model.state_dict() From a0b8320994761c81b4d957d929c60eace432cdcc Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 24 Jul 2025 18:25:54 +0800 Subject: [PATCH 10/16] 1. Add `block_until_saved` for checkpoint saving. 2. Rename `wait_for_saving` to `wait_on_save_thread`. 
--- trinity/trainer/trainer.py | 11 +++-------- trinity/trainer/verl/fsdp_checkpoint_manager.py | 2 +- trinity/trainer/verl/fsdp_workers.py | 8 ++++---- trinity/trainer/verl_trainer.py | 13 ++++++------- 4 files changed, 14 insertions(+), 20 deletions(-) diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index ff1248f8d6..21c9d44f12 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -43,6 +43,8 @@ def train(self) -> str: except Exception: self.logger.error(f"Error in Trainer:\n{traceback.format_exc()}") break + ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED)) + self.engine.save_checkpoint(block_until_saved=True) self.logger.info("--------------------\n> Trainer finished.\n--------------------") return self.config.trainer.name @@ -93,10 +95,7 @@ def sync_weight(self) -> None: ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) def shutdown(self) -> None: - ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED)) - self.engine.save_checkpoint() self.engine.monitor.close() - self.engine.shutdown() class TrainEngineWrapper(ABC): @@ -116,7 +115,7 @@ def train_step(self) -> bool: """Training.""" @abstractmethod - def save_checkpoint(self) -> None: + def save_checkpoint(self, block_until_saved: bool = False) -> None: """Save the checkpoint.""" @abstractmethod @@ -131,10 +130,6 @@ def upload_state_dict(self) -> None: def save_state_dict(self) -> None: """Only save the model state dict for Synchronizer.""" - @abstractmethod - def shutdown(self) -> None: - """Shutdown the engine.""" - def get_trainer_wrapper(config: Config) -> TrainEngineWrapper: """Get a trainer wrapper.""" diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index e558b025b2..2e9b65312c 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -349,7 +349,7 @@ def _save_model(): if not model_state_dict_only: self.previous_saved_paths.append(local_path) - def wait_for_saving(self) -> None: + def wait_on_save_thread(self) -> None: """ Wait for all background saving threads to complete. 
""" diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 5e858aacd7..71f45b391f 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -862,8 +862,8 @@ def clear_optimizer_state(self): offload_fsdp_optimizer(self.actor_optimizer) @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def wait_for_saving(self) -> None: - self.checkpoint_manager.wait_for_saving() + def wait_on_save_thread(self) -> None: + self.checkpoint_manager.wait_on_save_thread() class CriticWorker(Worker): @@ -1292,5 +1292,5 @@ def clear_optimizer_state(self): offload_fsdp_optimizer(self.critic_optimizer) @register(dispatch_mode=Dispatch.ONE_TO_ALL) - def wait_for_saving(self) -> None: - self.checkpoint_manager.wait_for_saving() + def wait_on_save_thread(self) -> None: + self.checkpoint_manager.wait_on_save_thread() diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index d00fdf4c9d..9114018036 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -440,10 +440,14 @@ def _log_experiences(self, samples: List[Dict]) -> None: ) self.reset_experiences_example_table() - def save_checkpoint(self) -> None: + def save_checkpoint(self, block_until_saved: bool = False) -> None: if self.last_full_save_step != self.global_steps: self.last_full_save_step = self.global_steps self._save_checkpoint() + if block_until_saved: + self.actor_rollout_wg.wait_on_save_thread() + if self.algorithm.use_critic: + self.critic_wg.wait_on_save_thread() def sync_weight(self) -> None: self.actor_rollout_wg.sync_weight() @@ -460,7 +464,7 @@ def sft_to_rft(self) -> None: global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest # find global_step_folder - self.actor_rollout_wg.wait_for_saving() + self.actor_rollout_wg.wait_on_save_thread() if self.config.trainer.resume_mode == "auto": if global_step_folder is None: print("Training from scratch") @@ -491,8 +495,3 @@ def sft_to_rft(self) -> None: if self.use_critic: self.critic_wg.clear_optimizer_state() print("sft to rft finished") - - def shutdown(self) -> None: - self.actor_rollout_wg.wait_for_saving() - if self.algorithm.use_critic: - self.critic_wg.wait_for_saving() From efcaf63e0f8ec5d4d6a094874fd7882be8259371 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Fri, 25 Jul 2025 17:56:33 +0800 Subject: [PATCH 11/16] Add synchronizer test --- tests/common/synchronizer_test.py | 255 +++++++++++++++++++++++++++ trinity/common/config.py | 2 +- trinity/common/synchronizer.py | 20 ++- trinity/explorer/explorer.py | 11 +- trinity/trainer/verl/fsdp_workers.py | 5 +- 5 files changed, 276 insertions(+), 17 deletions(-) create mode 100644 tests/common/synchronizer_test.py diff --git a/tests/common/synchronizer_test.py b/tests/common/synchronizer_test.py new file mode 100644 index 0000000000..eca06f3382 --- /dev/null +++ b/tests/common/synchronizer_test.py @@ -0,0 +1,255 @@ +# -*- coding: utf-8 -*- +"""Test cases for Synchronizer modules.""" + +import asyncio +import multiprocessing +import os +import shutil +import time +import unittest +from copy import deepcopy +from datetime import datetime +from typing import List + +import ray + +from tests.tools import ( + TensorBoardParser, + get_checkpoint_path, + get_model_path, + get_template_config, + get_unittest_dataset_config, +) +from trinity.algorithm.algorithm import ALGORITHM_TYPE +from trinity.cli.launcher import both, explore, train +from trinity.common.config import Config, StorageConfig 
+from trinity.common.constants import StorageType, SyncMethod, SyncStyle +from trinity.explorer.explorer import Explorer +from trinity.trainer.trainer import Trainer +from trinity.utils.log import get_logger + +logger = get_logger(__name__) +CHECKPOINT_ROOT_DIR = os.path.join(os.path.dirname(__file__), "temp_checkpoint_dir") + + +def trainer_monkey_patch(config: Config, max_steps: int, intervals: List[int]): + def new_train_step(self): + self.engine.algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type) + self.engine.global_steps += 1 + self.logger.info(f"Training at step {self.engine.global_steps} started.") + time.sleep(intervals[self.engine.global_steps - 1]) + metrics = {"actor/step": self.engine.global_steps} + self.engine.monitor.log(data=metrics, step=self.engine.global_steps) + self.logger.info(f"Training at step {self.engine.global_steps} finished.") + return self.engine.global_steps < max_steps + + Trainer.train_step = new_train_step + + +def explorer_monkey_patch(config: Config, max_steps: int, intervals: List[int]): + async def new_explore_step(self): + self.explore_step_num += 1 + return self.explore_step_num <= max_steps + + def wrapper(old_save_checkpoint): + async def new_save_checkpoint(self, sync_weight: bool = False): + await asyncio.sleep(intervals.pop(0)) + await old_save_checkpoint(self, sync_weight) + + return new_save_checkpoint + + Explorer.explore_step = new_explore_step + Explorer.save_checkpoint = wrapper(Explorer.save_checkpoint) + + +def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None: + ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) + trainer_monkey_patch(config, max_steps, intervals) + train(config) + + +def run_explorer(config: Config, max_steps: int, intervals: List[int]) -> None: + ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) + explorer_monkey_patch(config, max_steps, intervals) + explore(config) + + +def run_both( + config: Config, max_steps: int, trainer_intervals: List[int], explorer_intervals: List[int] +) -> None: + ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) + trainer_monkey_patch(config, max_steps, trainer_intervals) + explorer_monkey_patch(config, max_steps, explorer_intervals) + both(config) + + +class TestSynchronizer(unittest.TestCase): + def setUp(self): + if multiprocessing.get_start_method(allow_none=True) != "spawn": + multiprocessing.set_start_method("spawn", force=True) + + def test_checkpoint_method_fixed_style(self): + config = get_template_config() + config.project = "unittest" + config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" + config.checkpoint_root_dir = get_checkpoint_path() + config.buffer.total_epochs = 1 + config.buffer.batch_size = 4 + config.cluster.gpu_per_node = 2 + config.cluster.node_num = 1 + config.model.model_path = get_model_path() + config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + config.buffer.trainer_input.experience_buffer = StorageConfig( + name="exp_buffer", + storage_type=StorageType.QUEUE, + wrap_in_ray=True, + ) + config.synchronizer.sync_method = SyncMethod.CHECKPOINT + config.synchronizer.sync_style = SyncStyle.FIXED + config.synchronizer.sync_interval = 2 + config.trainer.save_interval = 100 + config.monitor.monitor_type = "tensorboard" + trainer_config = deepcopy(config) + trainer_config.mode = "train" + trainer_config.check_and_update() + + explorer1_config = deepcopy(config) + explorer1_config.mode = "explore" + explorer1_config.explorer.name = 
"explorer1" + config.cluster.gpu_per_node = 1 + config.cluster.node_num = 1 + explorer1_config.explorer.rollout_model.engine_num = 1 + explorer1_config.explorer.rollout_model.tensor_parallel_size = 1 + explorer1_config.buffer.explorer_output = StorageConfig( + name="exp_buffer", + storage_type=StorageType.QUEUE, + wrap_in_ray=True, + ) + explorer2_config = deepcopy(explorer1_config) + explorer2_config.explorer.name = "explorer2" + explorer1_config.check_and_update() + explorer2_config.check_and_update() + + trainer_process = multiprocessing.Process( + target=run_trainer, args=(trainer_config, 8, [2, 1, 2, 1, 2, 1, 2, 1]) + ) + trainer_process.start() + explorer_process_1 = multiprocessing.Process( + target=run_explorer, args=(explorer1_config, 8, [0, 2.5, 2.5, 2.5, 2.5]) + ) + explorer_process_1.start() + explorer_process_2 = multiprocessing.Process( + target=run_explorer, args=(explorer2_config, 8, [0, 0.5, 0.5, 0.5, 0.5]) + ) + explorer_process_2.start() + + explorer_process_1.join(timeout=200) + explorer_process_2.join(timeout=200) + trainer_process.join(timeout=200) + + # check the tensorboard + parser = TensorBoardParser( + os.path.join(trainer_config.monitor.cache_dir, "tensorboard", "trainer") + ) + actor_metrics = parser.metric_list("actor") + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) + parser = TensorBoardParser( + os.path.join(explorer1_config.monitor.cache_dir, "tensorboard", "explorer1") + ) + rollout_metrics = parser.metric_list("rollout") + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + parser = TensorBoardParser( + os.path.join(explorer2_config.monitor.cache_dir, "tensorboard", "explorer2") + ) + rollout_metrics = parser.metric_list("rollout") + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + + def test_nccl_method_fixed_style(self): + config = get_template_config() + config.project = "unittest" + config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" + config.checkpoint_root_dir = get_checkpoint_path() + config.buffer.total_epochs = 1 + config.buffer.batch_size = 4 + config.cluster.gpu_per_node = 4 + config.cluster.node_num = 1 + config.model.model_path = get_model_path() + config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + config.buffer.trainer_input.experience_buffer = StorageConfig( + name="exp_buffer", + storage_type=StorageType.QUEUE, + wrap_in_ray=True, + ) + config.synchronizer.sync_method = SyncMethod.NCCL + config.synchronizer.sync_style = SyncStyle.FIXED + config.synchronizer.sync_interval = 2 + config.trainer.save_interval = 100 + config.explorer.rollout_model.engine_num = 2 + config.explorer.rollout_model.tensor_parallel_size = 1 + config.monitor.monitor_type = "tensorboard" + config.mode = "both" + config.check_and_update() + + # TODO: test more interval cases + both_process = multiprocessing.Process( + target=run_both, args=(config, 8, [2, 1, 2, 1, 2, 1, 2, 1], [0, 2.5, 2.5, 2.5, 2.5]) + ) + both_process.start() + both_process.join(timeout=200) + + # check the tensorboard + parser = TensorBoardParser(os.path.join(config.monitor.cache_dir, "tensorboard", "trainer")) + actor_metrics = parser.metric_list("actor") + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) + parser = TensorBoardParser( + os.path.join(config.monitor.cache_dir, "tensorboard", "explorer") + ) + rollout_metrics = parser.metric_list("rollout") + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + + def test_nccl_method_dynamic_by_explorer_style(self): + config 
= get_template_config() + config.project = "unittest" + config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" + config.checkpoint_root_dir = get_checkpoint_path() + config.buffer.total_epochs = 1 + config.buffer.batch_size = 4 + config.cluster.gpu_per_node = 4 + config.cluster.node_num = 1 + config.model.model_path = get_model_path() + config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + config.buffer.trainer_input.experience_buffer = StorageConfig( + name="exp_buffer", + storage_type=StorageType.QUEUE, + wrap_in_ray=True, + ) + config.synchronizer.sync_method = SyncMethod.NCCL + config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER + config.synchronizer.sync_interval = 2 + config.trainer.save_interval = 100 + config.explorer.rollout_model.engine_num = 2 + config.explorer.rollout_model.tensor_parallel_size = 1 + config.monitor.monitor_type = "tensorboard" + config.mode = "both" + config.check_and_update() + + # TODO: test more interval cases + both_process = multiprocessing.Process( + target=run_both, args=(config, 8, [2, 1, 2, 1, 2, 1, 2, 1], [0, 0.5, 0.5, 0.5, 0.5]) + ) + both_process.start() + both_process.join(timeout=200) + + # check the tensorboard + parser = TensorBoardParser(os.path.join(config.monitor.cache_dir, "tensorboard", "trainer")) + actor_metrics = parser.metric_list("actor") + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) + parser = TensorBoardParser( + os.path.join(config.monitor.cache_dir, "tensorboard", "explorer") + ) + rollout_metrics = parser.metric_list("rollout") + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + + def tearDown(self): + if os.path.exists(CHECKPOINT_ROOT_DIR): + shutil.rmtree(CHECKPOINT_ROOT_DIR) diff --git a/trinity/common/config.py b/trinity/common/config.py index 9227ee7c15..c752e5bdc9 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -745,7 +745,7 @@ def check_and_update(self) -> None: # noqa: C901 ) if ( self.mode in ["train", "explore", "bench"] - and self.synchronizer.sync_method != SyncMethod.CHECKPOINT + and self.synchronizer.sync_method == SyncMethod.NCCL ): self.synchronizer.sync_method = SyncMethod.CHECKPOINT logger.warning( diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py index 144edb5298..17c870a72d 100644 --- a/trinity/common/synchronizer.py +++ b/trinity/common/synchronizer.py @@ -40,11 +40,14 @@ def __init__(self, config: Config): self.model_version = 0 self.checkpoint_shard_counter = defaultdict(lambda: 0) - def set_trainer_status(self, status: RunningStatus): + async def set_trainer_status(self, status: RunningStatus): """Update the status of the trainer.""" - self.trainer_status = status + async with self._ready_condition: + self.trainer_status = status + if status == RunningStatus.STOPPED: + self._ready_condition.notify_all() - def get_trainer_status(self) -> RunningStatus: + async def get_trainer_status(self) -> RunningStatus: """Get the current status of the trainer.""" return self.trainer_status @@ -155,7 +158,7 @@ async def setup_weight_sync_group( master_port: Port used for synchronization. state_dict_meta: Metadata of the model parameters. 
""" - explorer = ray.get_actor(self.config.explorer_name) + explorer = ray.get_actor(self.config.explorer.name, namespace=self.config.ray_namespace) await explorer.setup_weight_sync_group.remote(master_address, master_port, state_dict_meta) async def wait_new_model_state_dict(self, current_version: int, no_wait: bool = False) -> int: @@ -214,9 +217,7 @@ def sync_failed(): else: another_module = "Explorer" self.trainer_status = RunningStatus.REQUIRE_SYNC - self.logger.error( - f"{another_module} is not ready for model weight sync in {self.config.synchronizer.sync_timeout} seconds." - ) + self.logger.error(f"{another_module} is not ready for model weight sync.") return None non_stop_cnt = sum( @@ -249,10 +250,13 @@ def sync_failed(): if self.trainer_status != RunningStatus.WAITING_SYNC: await asyncio.wait_for( self._ready_condition.wait_for( - lambda: self.trainer_status == RunningStatus.WAITING_SYNC, + lambda: self.trainer_status + in {RunningStatus.WAITING_SYNC, RunningStatus.STOPPED}, ), timeout=self.config.synchronizer.sync_timeout, ) + if self.trainer_status == RunningStatus.STOPPED: + return sync_failed() return self.model_version except asyncio.TimeoutError: return sync_failed() diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 313bec40aa..d52dc9052e 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -130,13 +130,13 @@ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> in return step_num # type: ignore async def _pull_latest_weights(self): - self.logger.info("Start to update state dict.") + self.logger.info("Start to pull latest model weights.") new_version = ray.get( self.synchronizer.wait_new_model_state_dict.remote(self.model_version) ) if new_version > self.model_version: if self.model_version != -1: - self.logger.info(f"New model state dict version: {new_version}") + self.logger.info(f"New model weights version: {new_version}") await asyncio.gather( *[model.sync_model.remote(new_version) for model in self.models] ) @@ -150,7 +150,7 @@ async def _pull_latest_weights(self): self.last_sync_successful = True else: self.logger.warning( - f"No new model state dict found, current version: {self.model_version}" + f"No new model weights found, current version: {self.model_version}" ) self.last_sync_successful = False @@ -364,15 +364,14 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int async def _finish_explore_step(self, step: int, model_version: int) -> None: statuses, exps = await self.scheduler.get_results(batch_id=step) - metric = {} + metric = {"rollout/model_version": model_version} if self.config.explorer.collect_experiences: exp_cnt = await self.add_strategy.add(exps, step) self.generated_experience_cnt += exp_cnt metric["rollout/experience_count"] = exp_cnt if statuses: metric.update(gather_metrics([status.metric for status in statuses], "rollout")) - metric["rollout/model_version"] = model_version - self.monitor.log(metric, step=step) + self.monitor.log(metric, step=step) async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None: if not self.pending_eval_tasks: diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 71f45b391f..e85c8ef540 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -77,6 +77,7 @@ from trinity.common.config import AlgorithmConfig from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod +from 
trinity.common.synchronizer import Synchronizer from trinity.trainer.verl.fsdp_checkpoint_manager import FSDPCheckpointManager from trinity.utils.distributed import init_process_group @@ -592,8 +593,8 @@ def setup_weight_sync_group(self): master_address, master_port = self.get_availale_master_addr_port() world_size = self.config.synchronizer.explorer_world_size + 1 print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).") - explorer = ray.get_actor(self.config.explorer_name) - setup_ref = explorer.setup_weight_sync_group.remote( + synchronizer = Synchronizer.get_actor(self.config.synchronizer) + setup_ref = synchronizer.setup_weight_sync_group.remote( master_address, master_port, self.state_dict_meta ) timeout = self.config.synchronizer.sync_timeout From 3724919ec077d9299b5b6b19f831fc32867eb385 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 29 Jul 2025 10:16:12 +0800 Subject: [PATCH 12/16] bug fix for queue and sync test --- tests/common/synchronizer_test.py | 5 +++++ trinity/buffer/queue.py | 4 ++++ trinity/explorer/explorer.py | 2 +- trinity/trainer/verl_trainer.py | 12 ++++++------ 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/common/synchronizer_test.py b/tests/common/synchronizer_test.py index eca06f3382..c1b9b8e447 100644 --- a/tests/common/synchronizer_test.py +++ b/tests/common/synchronizer_test.py @@ -58,8 +58,13 @@ async def new_save_checkpoint(self, sync_weight: bool = False): return new_save_checkpoint + async def new_finish_explore_step(self, step: int, model_version: int) -> None: + metric = {"rollout/model_version": model_version} + self.monitor.log(metric, step=step) + Explorer.explore_step = new_explore_step Explorer.save_checkpoint = wrapper(Explorer.save_checkpoint) + Explorer._finish_explore_step = new_finish_explore_step def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None: diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index e5d76e57d4..e28af726a8 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -71,6 +71,10 @@ def __init__(self, capacity: int): async def close(self) -> None: """Close the queue.""" self._closed = True + for getter in self._getters: + if not getter.done(): + getter.set_exception(StopAsyncIteration()) + self._getters.clear() def stopped(self) -> bool: """Check if there is no more data to read.""" diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index d52dc9052e..9b244fb30a 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -371,7 +371,7 @@ async def _finish_explore_step(self, step: int, model_version: int) -> None: metric["rollout/experience_count"] = exp_cnt if statuses: metric.update(gather_metrics([status.metric for status in statuses], "rollout")) - self.monitor.log(metric, step=step) + self.monitor.log(metric, step=step) async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None: if not self.pending_eval_tasks: diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 9114018036..16bcd63870 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -314,7 +314,7 @@ def train_step(self) -> bool: # noqa C901 batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1) prefix_metrics(sample_metrics, "sample", metrics) except StopIteration: - print("No more data to train. Stop training.") + self.logger.info("No more data to train. 
Stop training.") return False self.global_steps += 1 self.logger.info(f"Sampling at step {self.global_steps} done.") @@ -467,7 +467,7 @@ def sft_to_rft(self) -> None: self.actor_rollout_wg.wait_on_save_thread() if self.config.trainer.resume_mode == "auto": if global_step_folder is None: - print("Training from scratch") + self.logger.info("Training from scratch") return else: if not (self.config.trainer.resume_from_path and global_step_folder is not None): @@ -481,17 +481,17 @@ def sft_to_rft(self) -> None: if not os.path.isabs(global_step_folder): working_dir = os.getcwd() global_step_folder = os.path.join(working_dir, global_step_folder) - print(f"Load from checkpoint folder: {global_step_folder}") + self.logger.info(f"Load from checkpoint folder: {global_step_folder}") # set global step global_steps = int(global_step_folder.split("global_step_")[-1]) assert self.global_steps == global_steps + 1 - print(f"Resuming from {global_step_folder}") + self.logger.info(f"Resuming from {global_step_folder}") actor_path = os.path.join(global_step_folder, "actor") - print(f"Loading actor from {actor_path} to ref_policy_wg") + self.logger.info(f"Loading actor from {actor_path} to ref_policy_wg") self.ref_policy_wg.load_checkpoint(actor_path, del_local_after_load=False) self.actor_rollout_wg.clear_optimizer_state() if self.use_critic: self.critic_wg.clear_optimizer_state() - print("sft to rft finished") + self.logger.info("sft to rft finished") From 1327873fa0f5d44e85ea3d7cd151912eb287c6d9 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 29 Jul 2025 17:21:11 +0800 Subject: [PATCH 13/16] add `lifetime="detached"` to Synchronizer --- tests/common/synchronizer_test.py | 44 +++++++++++++++++++++++++------ trinity/common/synchronizer.py | 15 ++++++++++- trinity/explorer/explorer.py | 8 +++++- trinity/trainer/trainer.py | 7 ++++- 4 files changed, 63 insertions(+), 11 deletions(-) diff --git a/tests/common/synchronizer_test.py b/tests/common/synchronizer_test.py index c1b9b8e447..3db087ce9b 100644 --- a/tests/common/synchronizer_test.py +++ b/tests/common/synchronizer_test.py @@ -12,6 +12,7 @@ from typing import List import ray +from parameterized import parameterized from tests.tools import ( TensorBoardParser, @@ -39,7 +40,7 @@ def new_train_step(self): self.logger.info(f"Training at step {self.engine.global_steps} started.") time.sleep(intervals[self.engine.global_steps - 1]) metrics = {"actor/step": self.engine.global_steps} - self.engine.monitor.log(data=metrics, step=self.engine.global_steps) + self.monitor.log(data=metrics, step=self.engine.global_steps) self.logger.info(f"Training at step {self.engine.global_steps} finished.") return self.engine.global_steps < max_steps @@ -48,6 +49,8 @@ def new_train_step(self): def explorer_monkey_patch(config: Config, max_steps: int, intervals: List[int]): async def new_explore_step(self): + if self.explore_step_num == max_steps: + await self.save_checkpoint(sync_weight=False) self.explore_step_num += 1 return self.explore_step_num <= max_steps @@ -93,7 +96,31 @@ def setUp(self): if multiprocessing.get_start_method(allow_none=True) != "spawn": multiprocessing.set_start_method("spawn", force=True) - def test_checkpoint_method_fixed_style(self): + @parameterized.expand( + [ + ( + "checkpoint_fixed", + SyncMethod.CHECKPOINT, + SyncStyle.FIXED, + ), + ( + "checkpoint_dynamic_by_explorer", + SyncMethod.CHECKPOINT, + SyncStyle.DYNAMIC_BY_EXPLORER, + ), + ( + "memory_fixed", + SyncMethod.MEMORY, + SyncStyle.FIXED, + ), + ( + 
"memory_dynamic_by_explorer", + SyncMethod.MEMORY, + SyncStyle.DYNAMIC_BY_EXPLORER, + ), + ] + ) + def test_state_dict_based_sync(self, name, sync_method, sync_style): config = get_template_config() config.project = "unittest" config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" @@ -109,8 +136,8 @@ def test_checkpoint_method_fixed_style(self): storage_type=StorageType.QUEUE, wrap_in_ray=True, ) - config.synchronizer.sync_method = SyncMethod.CHECKPOINT - config.synchronizer.sync_style = SyncStyle.FIXED + config.synchronizer.sync_method = sync_method + config.synchronizer.sync_style = sync_style config.synchronizer.sync_interval = 2 config.trainer.save_interval = 100 config.monitor.monitor_type = "tensorboard" @@ -140,11 +167,12 @@ def test_checkpoint_method_fixed_style(self): ) trainer_process.start() explorer_process_1 = multiprocessing.Process( - target=run_explorer, args=(explorer1_config, 8, [0, 2.5, 2.5, 2.5, 2.5]) + target=run_explorer, + args=(explorer1_config, 8, [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5]), ) explorer_process_1.start() explorer_process_2 = multiprocessing.Process( - target=run_explorer, args=(explorer2_config, 8, [0, 0.5, 0.5, 0.5, 0.5]) + target=run_explorer, args=(explorer2_config, 8, [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) ) explorer_process_2.start() @@ -197,7 +225,7 @@ def test_nccl_method_fixed_style(self): # TODO: test more interval cases both_process = multiprocessing.Process( - target=run_both, args=(config, 8, [2, 1, 2, 1, 2, 1, 2, 1], [0, 2.5, 2.5, 2.5, 2.5]) + target=run_both, args=(config, 8, [2, 1, 2, 1, 2, 1, 2, 1], [0, 2.5, 2.5, 2.5, 2.5, 0]) ) both_process.start() both_process.join(timeout=200) @@ -240,7 +268,7 @@ def test_nccl_method_dynamic_by_explorer_style(self): # TODO: test more interval cases both_process = multiprocessing.Process( - target=run_both, args=(config, 8, [2, 1, 2, 1, 2, 1, 2, 1], [0, 0.5, 0.5, 0.5, 0.5]) + target=run_both, args=(config, 8, [2, 1, 2, 1, 2, 1, 2, 1], [0, 0.5, 0.5, 0.5, 0.5, 0]) ) both_process.start() both_process.join(timeout=200) diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py index 17c870a72d..c8cb827f46 100644 --- a/trinity/common/synchronizer.py +++ b/trinity/common/synchronizer.py @@ -39,6 +39,7 @@ def __init__(self, config: Config): self.model_state_dict = None self.model_version = 0 self.checkpoint_shard_counter = defaultdict(lambda: 0) + self.ref_count = 0 async def set_trainer_status(self, status: RunningStatus): """Update the status of the trainer.""" @@ -276,7 +277,19 @@ def get_actor(cls, config: Optional[Config] = None, namespace: Optional[str] = N if config is not None: return ( ray.remote(cls) - .options(name="synchronizer", namespace=config.ray_namespace, get_if_exists=True) + .options( + name="synchronizer", + namespace=config.ray_namespace, + get_if_exists=True, + lifetime="detached", + ) .remote(config) ) return ray.get_actor("synchronizer", namespace=namespace) + + def acquire(self): + self.ref_count += 1 + + def release(self): + self.ref_count -= 1 + return self.ref_count diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 9b244fb30a..65fd7a15f6 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -41,6 +41,7 @@ def __init__(self, config: Config): self.explore_step_num = explorer_meta.get("latest_iteration", 0) self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1 self.synchronizer = Synchronizer.get_actor(config) + ray.get(self.synchronizer.acquire.remote()) 
self.config = config self.algorithm_manager = AlgorithmManager(config) self.models, self.auxiliary_models = create_inference_models(config) @@ -219,7 +220,9 @@ async def explore(self) -> str: except Exception: self.logger.error(f"Error in Explorer: {traceback.format_exc()}") break - self.logger.info("--------------------\n> Explorer finished.\n--------------------") + self.logger.info( + f"--------------------\n> Explorer ({self.config.explorer.name}) finished.\n--------------------" + ) return self.config.explorer.name async def explore_step(self) -> bool: @@ -395,4 +398,7 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva async def shutdown(self) -> None: self.monitor.close() + if ray.get(self.synchronizer.release.remote()) == 0: + ray.kill(self.synchronizer) + self.logger.info("Synchronizer stopped.") await self.scheduler.stop() diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index fdb3d195c1..c4673c2023 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -28,10 +28,12 @@ def __init__(self, config: Config) -> None: self.config = config self.logger = get_logger(__name__) self.synchronizer = Synchronizer.get_actor(config) + ray.get(self.synchronizer.acquire.remote()) self.engine = get_trainer_wrapper(config) self.last_trainer_sync_step = 0 self.monitor = MONITOR.get(config.monitor.monitor_type)( project=config.project, + group=self.config.group, name=config.name, role=config.trainer.name, config=config, @@ -139,7 +141,10 @@ def _log_experiences(self, samples: List[Dict]) -> None: self._sample_exps_to_log.clear() def shutdown(self) -> None: - self.engine.monitor.close() + self.monitor.close() + if ray.get(self.synchronizer.release.remote()) == 0: + ray.kill(self.synchronizer) + self.logger.info("Synchronizer stopped.") @property def train_step_num(self) -> int: From 7f2ebc3d4dbe468a86eb33494f9fd7df60fc45b0 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 29 Jul 2025 19:15:07 +0800 Subject: [PATCH 14/16] doc fix and fix test --- tests/common/synchronizer_test.py | 157 +++++++++--------- trinity/common/synchronizer.py | 32 ++-- trinity/trainer/trainer.py | 8 +- .../trainer/verl/fsdp_checkpoint_manager.py | 2 +- 4 files changed, 100 insertions(+), 99 deletions(-) diff --git a/tests/common/synchronizer_test.py b/tests/common/synchronizer_test.py index 3db087ce9b..efcf44ac8e 100644 --- a/tests/common/synchronizer_test.py +++ b/tests/common/synchronizer_test.py @@ -12,7 +12,7 @@ from typing import List import ray -from parameterized import parameterized +from parameterized import parameterized_class from tests.tools import ( TensorBoardParser, @@ -91,36 +91,62 @@ def run_both( both(config) -class TestSynchronizer(unittest.TestCase): +class BaseTestSynchronizer(unittest.TestCase): def setUp(self): if multiprocessing.get_start_method(allow_none=True) != "spawn": multiprocessing.set_start_method("spawn", force=True) - @parameterized.expand( - [ - ( - "checkpoint_fixed", - SyncMethod.CHECKPOINT, - SyncStyle.FIXED, - ), - ( - "checkpoint_dynamic_by_explorer", - SyncMethod.CHECKPOINT, - SyncStyle.DYNAMIC_BY_EXPLORER, - ), - ( - "memory_fixed", - SyncMethod.MEMORY, - SyncStyle.FIXED, - ), - ( - "memory_dynamic_by_explorer", - SyncMethod.MEMORY, - SyncStyle.DYNAMIC_BY_EXPLORER, - ), - ] - ) - def test_state_dict_based_sync(self, name, sync_method, sync_style): + def tearDown(self): + checkpoint_path = get_checkpoint_path() + shutil.rmtree(os.path.join(checkpoint_path, "unittest")) + + 
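The rewritten test relies on `parameterized_class` from the `parameterized` package: a tuple of attribute names plus one value tuple per configuration expands into one generated TestCase subclass per row, with those values bound as class attributes, as the decorator that follows shows. A minimal standalone example of the mechanism, independent of Trinity's fixtures:

import unittest

from parameterized import parameterized_class


@parameterized_class(
    ("sync_method", "sync_interval"),
    [
        ("checkpoint", 2),
        ("memory", 2),
    ],
)
class SyncConfigTest(unittest.TestCase):
    def test_interval_is_positive(self):
        # Each generated subclass sees its own sync_method / sync_interval values.
        self.assertGreater(self.sync_interval, 0)
        self.assertIn(self.sync_method, {"checkpoint", "memory", "nccl"})


if __name__ == "__main__":
    unittest.main()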
+@parameterized_class( + ( + "sync_method", + "sync_style", + "max_steps", + "trainer_intervals", + "explorer1_intervals", + "explorer2_intervals", + ), + [ + ( + SyncMethod.CHECKPOINT, + SyncStyle.FIXED, + 8, + [2, 1, 2, 1, 2, 1, 2, 1], + [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], + [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + ), + ( + SyncMethod.CHECKPOINT, + SyncStyle.DYNAMIC_BY_EXPLORER, + 8, + [2, 1, 2, 1, 2, 1, 2, 1], + [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], + [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + ), + ( + SyncMethod.MEMORY, + SyncStyle.FIXED, + 8, + [2, 1, 2, 1, 2, 1, 2, 1], + [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], + [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + ), + ( + SyncMethod.MEMORY, + SyncStyle.DYNAMIC_BY_EXPLORER, + 8, + [2, 1, 2, 1, 2, 1, 2, 1], + [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], + [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + ), + ], +) +class TestStateDictBasedSynchronizer(BaseTestSynchronizer): + def test_synchronizer(self): config = get_template_config() config.project = "unittest" config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" @@ -136,8 +162,8 @@ def test_state_dict_based_sync(self, name, sync_method, sync_style): storage_type=StorageType.QUEUE, wrap_in_ray=True, ) - config.synchronizer.sync_method = sync_method - config.synchronizer.sync_style = sync_style + config.synchronizer.sync_method = self.sync_method + config.synchronizer.sync_style = self.sync_style config.synchronizer.sync_interval = 2 config.trainer.save_interval = 100 config.monitor.monitor_type = "tensorboard" @@ -163,16 +189,16 @@ def test_state_dict_based_sync(self, name, sync_method, sync_style): explorer2_config.check_and_update() trainer_process = multiprocessing.Process( - target=run_trainer, args=(trainer_config, 8, [2, 1, 2, 1, 2, 1, 2, 1]) + target=run_trainer, args=(trainer_config, self.max_steps, self.trainer_intervals) ) trainer_process.start() explorer_process_1 = multiprocessing.Process( target=run_explorer, - args=(explorer1_config, 8, [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5]), + args=(explorer1_config, self.max_steps, self.explorer1_intervals), ) explorer_process_1.start() explorer_process_2 = multiprocessing.Process( - target=run_explorer, args=(explorer2_config, 8, [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) + target=run_explorer, args=(explorer2_config, self.max_steps, self.explorer2_intervals) ) explorer_process_2.start() @@ -197,50 +223,26 @@ def test_state_dict_based_sync(self, name, sync_method, sync_style): rollout_metrics = parser.metric_list("rollout") self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) - def test_nccl_method_fixed_style(self): - config = get_template_config() - config.project = "unittest" - config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" - config.checkpoint_root_dir = get_checkpoint_path() - config.buffer.total_epochs = 1 - config.buffer.batch_size = 4 - config.cluster.gpu_per_node = 4 - config.cluster.node_num = 1 - config.model.model_path = get_model_path() - config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") - config.buffer.trainer_input.experience_buffer = StorageConfig( - name="exp_buffer", - storage_type=StorageType.QUEUE, - wrap_in_ray=True, - ) - config.synchronizer.sync_method = SyncMethod.NCCL - config.synchronizer.sync_style = SyncStyle.FIXED - config.synchronizer.sync_interval = 2 - config.trainer.save_interval = 100 - config.explorer.rollout_model.engine_num = 2 - config.explorer.rollout_model.tensor_parallel_size = 1 - 
config.monitor.monitor_type = "tensorboard" - config.mode = "both" - config.check_and_update() - - # TODO: test more interval cases - both_process = multiprocessing.Process( - target=run_both, args=(config, 8, [2, 1, 2, 1, 2, 1, 2, 1], [0, 2.5, 2.5, 2.5, 2.5, 0]) - ) - both_process.start() - both_process.join(timeout=200) - # check the tensorboard - parser = TensorBoardParser(os.path.join(config.monitor.cache_dir, "tensorboard", "trainer")) - actor_metrics = parser.metric_list("actor") - self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) - parser = TensorBoardParser( - os.path.join(config.monitor.cache_dir, "tensorboard", "explorer") - ) - rollout_metrics = parser.metric_list("rollout") - self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) - - def test_nccl_method_dynamic_by_explorer_style(self): +@parameterized_class( + ("sync_style", "max_steps", "trainer_intervals", "explorer_intervals"), + [ + ( + SyncStyle.FIXED, + 8, + [2, 1, 2, 1, 2, 1, 2, 1], + [0, 2.5, 2.5, 2.5, 2.5, 0], + ), + ( + SyncStyle.DYNAMIC_BY_EXPLORER, + 8, + [2, 1, 2, 1, 2, 1, 2, 1], + [0, 0.5, 0.5, 0.5, 0.5, 0], + ), + ], +) +class TestNCCLBasedSynchronizer(BaseTestSynchronizer): + def test_synchronizer(self): config = get_template_config() config.project = "unittest" config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" @@ -257,7 +259,7 @@ def test_nccl_method_dynamic_by_explorer_style(self): wrap_in_ray=True, ) config.synchronizer.sync_method = SyncMethod.NCCL - config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER + config.synchronizer.sync_style = self.sync_style config.synchronizer.sync_interval = 2 config.trainer.save_interval = 100 config.explorer.rollout_model.engine_num = 2 @@ -268,7 +270,8 @@ def test_nccl_method_dynamic_by_explorer_style(self): # TODO: test more interval cases both_process = multiprocessing.Process( - target=run_both, args=(config, 8, [2, 1, 2, 1, 2, 1, 2, 1], [0, 0.5, 0.5, 0.5, 0.5, 0]) + target=run_both, + args=(config, self.max_steps, self.trainer_intervals, self.explorer_intervals), ) both_process.start() both_process.join(timeout=200) diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py index c8cb827f46..a8c099342c 100644 --- a/trinity/common/synchronizer.py +++ b/trinity/common/synchronizer.py @@ -23,7 +23,7 @@ class Synchronizer: Attributes: trainer_status: Current status of the trainer (e.g., running, waiting). - explorer_status_counter: Dictionary tracking the number of explorers in each status. + explorer_status_counts: Dictionary tracking the number of explorers in each status. _ready_condition: Async condition variable for signaling state changes. model_state_dict: The latest model weights. model_version: Version number of the current model. 
@@ -34,7 +34,7 @@ def __init__(self, config: Config): self.logger = get_logger(__name__) self.config = config self.trainer_status = RunningStatus.STOPPED - self.explorer_status_counter: Dict[RunningStatus, int] = defaultdict(lambda: 0) + self.explorer_status_counts: Dict[RunningStatus, int] = defaultdict(lambda: 0) self._ready_condition = asyncio.Condition() self.model_state_dict = None self.model_version = 0 @@ -64,20 +64,20 @@ def set_explorer_status( """ if old_status is not None: assert ( - old_status in self.explorer_status_counter + old_status in self.explorer_status_counts ), f"Invalid explorer status {old_status}" assert old_status != status, f"Invalid status change from {old_status} to {status}" - self.explorer_status_counter[old_status] -= 1 + self.explorer_status_counts[old_status] -= 1 assert ( - self.explorer_status_counter[old_status] >= 0 + self.explorer_status_counts[old_status] >= 0 ), f"Invalid status count {old_status} (new status {status})" - if status not in self.explorer_status_counter: - self.explorer_status_counter[status] = 0 - self.explorer_status_counter[status] += 1 + if status not in self.explorer_status_counts: + self.explorer_status_counts[status] = 0 + self.explorer_status_counts[status] += 1 - def get_explorer_status_counter(self) -> Dict[RunningStatus, int]: + def get_explorer_status_counts(self) -> Dict[RunningStatus, int]: """Return the current status counts for all explorers.""" - return self.explorer_status_counter + return self.explorer_status_counts async def set_model_state_dict_with_step_num( self, step_num: Optional[int] = None, world_size: Optional[int] = None @@ -108,9 +108,7 @@ async def set_model_state_dict_with_step_num( step_num=step_num, ) if checkpoint_step_num != self.model_version: - model_state_dict = load_state_dict( - os.path.join(checkpoint_dir, "actor") - ) # TODO: to thread + model_state_dict = load_state_dict(os.path.join(checkpoint_dir, "actor")) await self.set_model_state_dict(model_state_dict, checkpoint_step_num) return checkpoint_step_num @@ -206,7 +204,7 @@ async def ready_to_nccl_sync( The model version if both sides are ready; otherwise None. """ assert ( - sum(self.explorer_status_counter.values()) == 1 + sum(self.explorer_status_counts.values()) == 1 ), "NCCL sync is only supported for one explorer." 
         def sync_failed():
@@ -223,7 +221,7 @@ def sync_failed():
             non_stop_cnt = sum(
                 value
-                for key, value in self.explorer_status_counter.items()
+                for key, value in self.explorer_status_counts.items()
                 if key != RunningStatus.STOPPED
             )
             if non_stop_cnt == 0:
@@ -235,10 +233,10 @@ def sync_failed():
                     self.model_version = trainer_step
                     self.trainer_status = RunningStatus.WAITING_SYNC
                     self._ready_condition.notify_all()
-                    if self.explorer_status_counter[RunningStatus.WAITING_SYNC] != 1:
+                    if self.explorer_status_counts[RunningStatus.WAITING_SYNC] != 1:
                         await asyncio.wait_for(
                             self._ready_condition.wait_for(
-                                lambda: self.explorer_status_counter[RunningStatus.WAITING_SYNC]
+                                lambda: self.explorer_status_counts[RunningStatus.WAITING_SYNC]
                                 == 1,
                             ),
                             timeout=self.config.synchronizer.sync_timeout,
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index c4673c2023..9e3f492035 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -103,13 +103,13 @@ def need_sync(self) -> bool:
         delta = self.engine.train_step_num - self.last_trainer_sync_step
         if delta >= self.config.synchronizer.sync_interval:
             ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC))
-            explorer_status_counter = ray.get(
-                self.synchronizer.get_explorer_status_counter.remote()
+            explorer_status_counts = ray.get(
+                self.synchronizer.get_explorer_status_counts.remote()
             )
             if self.config.synchronizer.sync_method == SyncMethod.NCCL:
-                return explorer_status_counter[RunningStatus.WAITING_SYNC] > 0
+                return explorer_status_counts[RunningStatus.WAITING_SYNC] > 0
             else:  # memory & checkpoint
-                return explorer_status_counter[RunningStatus.REQUIRE_SYNC] > 0
+                return explorer_status_counts[RunningStatus.REQUIRE_SYNC] > 0
 
     def sync_weight(self) -> None:
         """Sync the model weight."""
diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py
index 2e9b65312c..1899cb5ad8 100644
--- a/trinity/trainer/verl/fsdp_checkpoint_manager.py
+++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py
@@ -113,7 +113,7 @@ def save_checkpoint(  # noqa: C901
 
         Main improvements over the base class:
         - Uses separate threads for saving model/optimizer/extras.
-        - Implements synchronization with a remote actor.
+        - Implements synchronization with a remote actor. If the model has not been trained yet (`global_step == 0`) or training is resuming from an existing checkpoint, the `Synchronizer` is notified and the model is not saved again.
 
         Args:
             local_path (str): Local directory path to save the checkpoint.
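The trainer-side `need_sync` above encodes a small decision rule that differs by sync method; extracted as a pure function it reads as follows (a sketch with simplified stand-in enums, not code from the repository):

from enum import Enum
from typing import Dict


class SyncMethod(Enum):
    NCCL = "nccl"
    MEMORY = "memory"
    CHECKPOINT = "checkpoint"


class RunningStatus(Enum):
    RUNNING = "running"
    REQUIRE_SYNC = "require_sync"
    WAITING_SYNC = "waiting_sync"
    STOPPED = "stopped"


def trainer_should_sync(
    delta: int,
    sync_interval: int,
    sync_method: SyncMethod,
    explorer_status_counts: Dict[RunningStatus, int],
) -> bool:
    """Return True when the trainer should pause training and publish weights."""
    if delta < sync_interval:
        return False
    if sync_method == SyncMethod.NCCL:
        # NCCL sync is a blocking rendezvous, so the trainer only commits once an
        # explorer is already parked in WAITING_SYNC.
        return explorer_status_counts.get(RunningStatus.WAITING_SYNC, 0) > 0
    # Memory/checkpoint sync is one-sided: a pending REQUIRE_SYNC request from any
    # explorer is enough; the trainer publishes weights and keeps going.
    return explorer_status_counts.get(RunningStatus.REQUIRE_SYNC, 0) > 0


assert trainer_should_sync(2, 2, SyncMethod.NCCL, {RunningStatus.WAITING_SYNC: 1})
assert not trainer_should_sync(1, 2, SyncMethod.MEMORY, {RunningStatus.REQUIRE_SYNC: 1})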
From 041a56fa449545b11018b9af577b570e0bc9b7e3 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 30 Jul 2025 10:32:28 +0800 Subject: [PATCH 15/16] doc fix and bug fix in unittest --- tests/common/synchronizer_test.py | 25 +++++++++++++++---------- trinity/trainer/trainer.py | 4 +--- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/common/synchronizer_test.py b/tests/common/synchronizer_test.py index efcf44ac8e..538f974526 100644 --- a/tests/common/synchronizer_test.py +++ b/tests/common/synchronizer_test.py @@ -74,12 +74,14 @@ def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None: ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) trainer_monkey_patch(config, max_steps, intervals) train(config) + ray.shutdown(_exiting_interpreter=True) def run_explorer(config: Config, max_steps: int, intervals: List[int]) -> None: ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) explorer_monkey_patch(config, max_steps, intervals) explore(config) + ray.shutdown(_exiting_interpreter=True) def run_both( @@ -89,6 +91,7 @@ def run_both( trainer_monkey_patch(config, max_steps, trainer_intervals) explorer_monkey_patch(config, max_steps, explorer_intervals) both(config) + ray.shutdown(_exiting_interpreter=True) class BaseTestSynchronizer(unittest.TestCase): @@ -117,7 +120,7 @@ def tearDown(self): 8, [2, 1, 2, 1, 2, 1, 2, 1], [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], - [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], ), ( SyncMethod.CHECKPOINT, @@ -125,7 +128,7 @@ def tearDown(self): 8, [2, 1, 2, 1, 2, 1, 2, 1], [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], - [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], ), ( SyncMethod.MEMORY, @@ -133,7 +136,7 @@ def tearDown(self): 8, [2, 1, 2, 1, 2, 1, 2, 1], [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], - [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], ), ( SyncMethod.MEMORY, @@ -141,7 +144,7 @@ def tearDown(self): 8, [2, 1, 2, 1, 2, 1, 2, 1], [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], - [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], ), ], ) @@ -174,8 +177,6 @@ def test_synchronizer(self): explorer1_config = deepcopy(config) explorer1_config.mode = "explore" explorer1_config.explorer.name = "explorer1" - config.cluster.gpu_per_node = 1 - config.cluster.node_num = 1 explorer1_config.explorer.rollout_model.engine_num = 1 explorer1_config.explorer.rollout_model.tensor_parallel_size = 1 explorer1_config.buffer.explorer_output = StorageConfig( @@ -192,6 +193,14 @@ def test_synchronizer(self): target=run_trainer, args=(trainer_config, self.max_steps, self.trainer_intervals) ) trainer_process.start() + ray.init(ignore_reinit_error=True) + while True: + try: + ray.get_actor("queue-exp_buffer", namespace=trainer_config.ray_namespace) + break + except ValueError: + print("waiting for trainer to start.") + time.sleep(5) explorer_process_1 = multiprocessing.Process( target=run_explorer, args=(explorer1_config, self.max_steps, self.explorer1_intervals), @@ -249,8 +258,6 @@ def test_synchronizer(self): config.checkpoint_root_dir = get_checkpoint_path() config.buffer.total_epochs = 1 config.buffer.batch_size = 4 - config.cluster.gpu_per_node = 4 - config.cluster.node_num = 1 config.model.model_path = get_model_path() config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") 
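The `while True` loop added to the multi-process test polls `ray.get_actor` until the trainer has created the `queue-exp_buffer` actor. If a bounded wait is preferred, so the test fails fast when the trainer process dies before creating the buffer instead of hanging, a small helper along these lines could be used; `wait_for_named_actor` is a hypothetical name, not part of the codebase:

import time
from typing import Optional

import ray


def wait_for_named_actor(name: str, namespace: Optional[str] = None,
                         timeout: float = 120.0, poll_interval: float = 5.0):
    """Poll until a named Ray actor exists, or raise TimeoutError."""
    deadline = time.monotonic() + timeout
    while True:
        try:
            # ray.get_actor raises ValueError while the named actor does not exist yet.
            return ray.get_actor(name, namespace=namespace)
        except ValueError:
            if time.monotonic() >= deadline:
                raise TimeoutError(f"actor {name!r} not found within {timeout} seconds")
            time.sleep(poll_interval)

For example, `wait_for_named_actor("queue-exp_buffer", namespace=trainer_config.ray_namespace)` would replace the loop in the test above.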
config.buffer.trainer_input.experience_buffer = StorageConfig( @@ -262,8 +269,6 @@ def test_synchronizer(self): config.synchronizer.sync_style = self.sync_style config.synchronizer.sync_interval = 2 config.trainer.save_interval = 100 - config.explorer.rollout_model.engine_num = 2 - config.explorer.rollout_model.tensor_parallel_size = 1 config.monitor.monitor_type = "tensorboard" config.mode = "both" config.check_and_update() diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 9e3f492035..00de93c10d 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -103,9 +103,7 @@ def need_sync(self) -> bool: delta = self.engine.train_step_num - self.last_trainer_sync_step if delta >= self.config.synchronizer.sync_interval: ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC)) - explorer_status_counts = ray.get( - self.synchronizer.get_explorer_status_counts.remote() - ) + explorer_status_counts = ray.get(self.synchronizer.get_explorer_status_counts.remote()) if self.config.synchronizer.sync_method == SyncMethod.NCCL: return explorer_status_counts[RunningStatus.WAITING_SYNC] > 0 else: # memory & checkpoint From e3bf3954bc34b9b62675aecb2a66f2f3ce209b30 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Wed, 30 Jul 2025 16:37:40 +0800 Subject: [PATCH 16/16] Bug fix in EID and `ray.get` in `explorer` --- tests/explorer/workflow_test.py | 4 ++- trinity/common/config.py | 2 +- trinity/common/synchronizer.py | 7 +++-- trinity/explorer/explorer.py | 49 ++++++++++++++------------------- 4 files changed, 29 insertions(+), 33 deletions(-) diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index cc85eaf8c7..4f965c8a2e 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- """Test for the workflow module""" import unittest -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Dict, Optional from unittest.mock import MagicMock from torch import Tensor from tests.tools import get_unittest_dataset_config +from trinity.common.experience import EID from trinity.common.rewards import RMGalleryFn from trinity.common.workflows import ( MathBoxedWorkflow, @@ -27,6 +28,7 @@ class MockResponse: unique_id: Optional[str] = "0" tokens: Optional[Tensor] = Tensor([0, 0]) prompt_length: int = 1 + eid: EID = field(default_factory=EID) class DummyWorkflow(Workflow): diff --git a/trinity/common/config.py b/trinity/common/config.py index c752e5bdc9..9714b521ab 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -749,7 +749,7 @@ def check_and_update(self) -> None: # noqa: C901 ): self.synchronizer.sync_method = SyncMethod.CHECKPOINT logger.warning( - f"`{self.mode}` mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`." + f"`{self.mode}` mode does not support NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`." 
) self._check_interval() diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py index a8c099342c..3e2070effa 100644 --- a/trinity/common/synchronizer.py +++ b/trinity/common/synchronizer.py @@ -48,7 +48,7 @@ async def set_trainer_status(self, status: RunningStatus): if status == RunningStatus.STOPPED: self._ready_condition.notify_all() - async def get_trainer_status(self) -> RunningStatus: + def get_trainer_status(self) -> RunningStatus: """Get the current status of the trainer.""" return self.trainer_status @@ -173,7 +173,7 @@ async def wait_new_model_state_dict(self, current_version: int, no_wait: bool = async with self._ready_condition: assert ( self.model_version >= current_version - ), f"The model version in Synchronizer ({self.model_version}) should be greater than that in Explorer ({current_version})!" + ), f"The model version in Synchronizer ({self.model_version}) should be no smaller than that in Explorer ({current_version})!" if self.model_version == current_version: if not no_wait and self.trainer_status != RunningStatus.STOPPED: # TODO: explorer need support no wait @@ -226,7 +226,7 @@ def sync_failed(): ) if non_stop_cnt == 0: return sync_failed() - # for status in RunningStatus: + async with self._ready_condition: try: if module == "trainer": @@ -256,6 +256,7 @@ def sync_failed(): ) if self.trainer_status == RunningStatus.STOPPED: return sync_failed() + self.trainer_status = RunningStatus.RUNNING return self.model_version except asyncio.TimeoutError: return sync_failed() diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index e8b348bdb4..4f6428f1ba 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -41,7 +41,6 @@ def __init__(self, config: Config): self.explore_step_num = explorer_meta.get("latest_iteration", 0) self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1 self.synchronizer = Synchronizer.get_actor(config) - ray.get(self.synchronizer.acquire.remote()) self.config = config self.algorithm_manager = AlgorithmManager(config) self.models, self.auxiliary_models = create_inference_models(config) @@ -126,15 +125,13 @@ def _init_scheduler(self) -> Scheduler: return Scheduler(self.config, self.models, self.auxiliary_models) async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int: - step_num = ray.get(self.synchronizer.set_model_state_dict_with_step_num.remote(step_num)) + step_num = await self.synchronizer.set_model_state_dict_with_step_num.remote(step_num) await asyncio.gather(*[model.sync_model.remote(step_num) for model in self.models]) return step_num # type: ignore async def _pull_latest_weights(self): self.logger.info("Start to pull latest model weights.") - new_version = ray.get( - self.synchronizer.wait_new_model_state_dict.remote(self.model_version) - ) + new_version = await self.synchronizer.wait_new_model_state_dict.remote(self.model_version) if new_version > self.model_version: if self.model_version != -1: self.logger.info(f"New model weights version: {new_version}") @@ -143,10 +140,8 @@ async def _pull_latest_weights(self): ) self.model_version = new_version self.last_sync_step = self.explore_step_num - ray.get( - self.synchronizer.set_explorer_status.remote( - RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC - ) + await self.synchronizer.set_explorer_status.remote( + RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC ) self.last_sync_successful = True else: @@ -156,8 +151,8 @@ async def _pull_latest_weights(self): 
             self.last_sync_successful = False
 
     async def _nccl_weights_update(self):
-        new_version = ray.get(
-            self.synchronizer.ready_to_nccl_sync.remote("explorer", self.model_version)
+        new_version = await self.synchronizer.ready_to_nccl_sync.remote(
+            "explorer", self.model_version
         )
         if new_version is None:
             self.logger.info("Trainer is not ready to sync weight. Skipping sync weight.")
@@ -168,24 +163,25 @@ async def _nccl_weights_update(self):
             *[model.sync_model.remote(self.model_version) for model in self.models]
         )
         self.last_sync_step = self.explore_step_num
-        ray.get(
-            self.synchronizer.set_explorer_status.remote(
-                RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC
-            )
+        await self.synchronizer.set_explorer_status.remote(
+            RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC
         )
         self.last_sync_successful = True
 
     async def prepare(self) -> None:
         """Preparation before running."""
+        futures = [
+            asyncio.create_task(self.scheduler.start()),
+            self.synchronizer.acquire.remote(),
+        ]
         if self.experience_buffer:
-            await self.experience_buffer.acquire()
-        futures = [asyncio.create_task(self.scheduler.start())]
+            futures.append(asyncio.create_task(self.experience_buffer.acquire()))
         if not self.use_nccl_sync:
             master_address, master_port = await self.models[0].get_available_address.remote()
             futures.append(
                 asyncio.create_task(self.setup_weight_sync_group(master_address, master_port))
             )
-        asyncio.gather(*futures, return_exceptions=True)
+        await asyncio.gather(*futures, return_exceptions=True)
         if self.config.explorer.eval_on_startup and self.explore_step_num == 0:
             self.eval()
         await self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC)
@@ -215,7 +211,7 @@ async def explore(self) -> str:
                     break
                 if self.need_eval():
                     self.eval()
-                if self.need_sync():
+                if await self.need_sync():
                     await self.sync_weight()
         except Exception:
             self.logger.error(f"Error in Explorer: {traceback.format_exc()}")
@@ -248,7 +244,7 @@ async def explore_step(self) -> bool:
         self.explore_step_num += 1
         return True
 
-    def need_sync(self) -> bool:
+    async def need_sync(self) -> bool:
         if self.config.synchronizer.sync_style == SyncStyle.FIXED:
             if self.explore_step_num <= self.config.synchronizer.sync_offset:
                 return False
@@ -262,15 +258,12 @@ def need_sync(self) -> bool:
             if delta >= self.config.synchronizer.sync_interval:
                 require_sync = True
         else:
-            require_sync = (
-                ray.get(self.synchronizer.get_trainer_status.remote())
-                == RunningStatus.REQUIRE_SYNC
+            require_sync = (
+                await self.synchronizer.get_trainer_status.remote() == RunningStatus.REQUIRE_SYNC
             )
         if require_sync and self.last_sync_successful:
-            ray.get(
-                self.synchronizer.set_explorer_status.remote(
-                    RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING
-                )
+            await self.synchronizer.set_explorer_status.remote(
+                RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING
             )
         return require_sync
 
@@ -399,7 +392,7 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
 
     async def shutdown(self) -> None:
         self.monitor.close()
-        if ray.get(self.synchronizer.release.remote()) == 0:
+        if await self.synchronizer.release.remote() == 0:
             ray.kill(self.synchronizer)
             self.logger.info("Synchronizer stopped.")
         await self.scheduler.stop()
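The bulk of this last patch replaces blocking `ray.get(...)` calls inside the Explorer's async methods with `await` on the returned ObjectRef, so a pending Synchronizer call no longer stalls the actor's event loop. Note that `await` binds tighter than `==`, so `await actor.get_trainer_status.remote() == RunningStatus.REQUIRE_SYNC` compares the resolved status, not the ObjectRef. A minimal sketch of the pattern on a toy actor pair; the class and method names are illustrative, not Trinity's API:

import ray


@ray.remote
class Counter:
    def __init__(self) -> None:
        self.value = 0

    def bump(self) -> int:
        self.value += 1
        return self.value


@ray.remote
class Caller:
    def __init__(self, counter) -> None:
        self.counter = counter

    async def awaited_call(self) -> int:
        # Inside an async actor method, awaiting the ObjectRef keeps the actor's
        # event loop free while the remote call is in flight; a blocking
        # ray.get(...) here (the pattern removed by this patch) would stall every
        # other coroutine running on this actor.
        return await self.counter.bump.remote()


if __name__ == "__main__":
    ray.init()
    counter = Counter.remote()
    caller = Caller.remote(counter)
    print(ray.get(caller.awaited_call.remote()))  # prints 1
    ray.shutdown()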