diff --git a/tests/common/config_test.py b/tests/common/config_test.py index db7190856f..f3eea28740 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -35,10 +35,6 @@ def test_load_default_config(self): ) self.assertEqual(config.model.model_path, config.model.critic_model_path) self.assertEqual(config.model.model_path, config.explorer.rollout_model.model_path) - self.assertEqual( - config.trainer.trainer_config.trainer.save_freq, - config.synchronizer.sync_interval, - ) def test_all_examples_are_valid(self): example_dir = os.path.join(os.path.dirname(__file__), "..", "..", "examples") diff --git a/tests/common/synchronizer_test.py b/tests/common/synchronizer_test.py new file mode 100644 index 0000000000..538f974526 --- /dev/null +++ b/tests/common/synchronizer_test.py @@ -0,0 +1,296 @@ +# -*- coding: utf-8 -*- +"""Test cases for Synchronizer modules.""" + +import asyncio +import multiprocessing +import os +import shutil +import time +import unittest +from copy import deepcopy +from datetime import datetime +from typing import List + +import ray +from parameterized import parameterized_class + +from tests.tools import ( + TensorBoardParser, + get_checkpoint_path, + get_model_path, + get_template_config, + get_unittest_dataset_config, +) +from trinity.algorithm.algorithm import ALGORITHM_TYPE +from trinity.cli.launcher import both, explore, train +from trinity.common.config import Config, StorageConfig +from trinity.common.constants import StorageType, SyncMethod, SyncStyle +from trinity.explorer.explorer import Explorer +from trinity.trainer.trainer import Trainer +from trinity.utils.log import get_logger + +logger = get_logger(__name__) +CHECKPOINT_ROOT_DIR = os.path.join(os.path.dirname(__file__), "temp_checkpoint_dir") + + +def trainer_monkey_patch(config: Config, max_steps: int, intervals: List[int]): + def new_train_step(self): + self.engine.algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type) + self.engine.global_steps += 1 + self.logger.info(f"Training at step {self.engine.global_steps} started.") + time.sleep(intervals[self.engine.global_steps - 1]) + metrics = {"actor/step": self.engine.global_steps} + self.monitor.log(data=metrics, step=self.engine.global_steps) + self.logger.info(f"Training at step {self.engine.global_steps} finished.") + return self.engine.global_steps < max_steps + + Trainer.train_step = new_train_step + + +def explorer_monkey_patch(config: Config, max_steps: int, intervals: List[int]): + async def new_explore_step(self): + if self.explore_step_num == max_steps: + await self.save_checkpoint(sync_weight=False) + self.explore_step_num += 1 + return self.explore_step_num <= max_steps + + def wrapper(old_save_checkpoint): + async def new_save_checkpoint(self, sync_weight: bool = False): + await asyncio.sleep(intervals.pop(0)) + await old_save_checkpoint(self, sync_weight) + + return new_save_checkpoint + + async def new_finish_explore_step(self, step: int, model_version: int) -> None: + metric = {"rollout/model_version": model_version} + self.monitor.log(metric, step=step) + + Explorer.explore_step = new_explore_step + Explorer.save_checkpoint = wrapper(Explorer.save_checkpoint) + Explorer._finish_explore_step = new_finish_explore_step + + +def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None: + ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) + trainer_monkey_patch(config, max_steps, intervals) + train(config) + ray.shutdown(_exiting_interpreter=True) + + +def run_explorer(config: 
Config, max_steps: int, intervals: List[int]) -> None: + ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) + explorer_monkey_patch(config, max_steps, intervals) + explore(config) + ray.shutdown(_exiting_interpreter=True) + + +def run_both( + config: Config, max_steps: int, trainer_intervals: List[int], explorer_intervals: List[int] +) -> None: + ray.init(ignore_reinit_error=True, namespace=config.ray_namespace) + trainer_monkey_patch(config, max_steps, trainer_intervals) + explorer_monkey_patch(config, max_steps, explorer_intervals) + both(config) + ray.shutdown(_exiting_interpreter=True) + + +class BaseTestSynchronizer(unittest.TestCase): + def setUp(self): + if multiprocessing.get_start_method(allow_none=True) != "spawn": + multiprocessing.set_start_method("spawn", force=True) + + def tearDown(self): + checkpoint_path = get_checkpoint_path() + shutil.rmtree(os.path.join(checkpoint_path, "unittest")) + + +@parameterized_class( + ( + "sync_method", + "sync_style", + "max_steps", + "trainer_intervals", + "explorer1_intervals", + "explorer2_intervals", + ), + [ + ( + SyncMethod.CHECKPOINT, + SyncStyle.FIXED, + 8, + [2, 1, 2, 1, 2, 1, 2, 1], + [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], + [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + ), + ( + SyncMethod.CHECKPOINT, + SyncStyle.DYNAMIC_BY_EXPLORER, + 8, + [2, 1, 2, 1, 2, 1, 2, 1], + [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], + [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + ), + ( + SyncMethod.MEMORY, + SyncStyle.FIXED, + 8, + [2, 1, 2, 1, 2, 1, 2, 1], + [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], + [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + ), + ( + SyncMethod.MEMORY, + SyncStyle.DYNAMIC_BY_EXPLORER, + 8, + [2, 1, 2, 1, 2, 1, 2, 1], + [0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5], + [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + ), + ], +) +class TestStateDictBasedSynchronizer(BaseTestSynchronizer): + def test_synchronizer(self): + config = get_template_config() + config.project = "unittest" + config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" + config.checkpoint_root_dir = get_checkpoint_path() + config.buffer.total_epochs = 1 + config.buffer.batch_size = 4 + config.cluster.gpu_per_node = 2 + config.cluster.node_num = 1 + config.model.model_path = get_model_path() + config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + config.buffer.trainer_input.experience_buffer = StorageConfig( + name="exp_buffer", + storage_type=StorageType.QUEUE, + wrap_in_ray=True, + ) + config.synchronizer.sync_method = self.sync_method + config.synchronizer.sync_style = self.sync_style + config.synchronizer.sync_interval = 2 + config.trainer.save_interval = 100 + config.monitor.monitor_type = "tensorboard" + trainer_config = deepcopy(config) + trainer_config.mode = "train" + trainer_config.check_and_update() + + explorer1_config = deepcopy(config) + explorer1_config.mode = "explore" + explorer1_config.explorer.name = "explorer1" + explorer1_config.explorer.rollout_model.engine_num = 1 + explorer1_config.explorer.rollout_model.tensor_parallel_size = 1 + explorer1_config.buffer.explorer_output = StorageConfig( + name="exp_buffer", + storage_type=StorageType.QUEUE, + wrap_in_ray=True, + ) + explorer2_config = deepcopy(explorer1_config) + explorer2_config.explorer.name = "explorer2" + explorer1_config.check_and_update() + explorer2_config.check_and_update() + + trainer_process = multiprocessing.Process( + target=run_trainer, args=(trainer_config, self.max_steps, self.trainer_intervals) + ) + 
trainer_process.start() + ray.init(ignore_reinit_error=True) + while True: + try: + ray.get_actor("queue-exp_buffer", namespace=trainer_config.ray_namespace) + break + except ValueError: + print("waiting for trainer to start.") + time.sleep(5) + explorer_process_1 = multiprocessing.Process( + target=run_explorer, + args=(explorer1_config, self.max_steps, self.explorer1_intervals), + ) + explorer_process_1.start() + explorer_process_2 = multiprocessing.Process( + target=run_explorer, args=(explorer2_config, self.max_steps, self.explorer2_intervals) + ) + explorer_process_2.start() + + explorer_process_1.join(timeout=200) + explorer_process_2.join(timeout=200) + trainer_process.join(timeout=200) + + # check the tensorboard + parser = TensorBoardParser( + os.path.join(trainer_config.monitor.cache_dir, "tensorboard", "trainer") + ) + actor_metrics = parser.metric_list("actor") + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) + parser = TensorBoardParser( + os.path.join(explorer1_config.monitor.cache_dir, "tensorboard", "explorer1") + ) + rollout_metrics = parser.metric_list("rollout") + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + parser = TensorBoardParser( + os.path.join(explorer2_config.monitor.cache_dir, "tensorboard", "explorer2") + ) + rollout_metrics = parser.metric_list("rollout") + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + + +@parameterized_class( + ("sync_style", "max_steps", "trainer_intervals", "explorer_intervals"), + [ + ( + SyncStyle.FIXED, + 8, + [2, 1, 2, 1, 2, 1, 2, 1], + [0, 2.5, 2.5, 2.5, 2.5, 0], + ), + ( + SyncStyle.DYNAMIC_BY_EXPLORER, + 8, + [2, 1, 2, 1, 2, 1, 2, 1], + [0, 0.5, 0.5, 0.5, 0.5, 0], + ), + ], +) +class TestNCCLBasedSynchronizer(BaseTestSynchronizer): + def test_synchronizer(self): + config = get_template_config() + config.project = "unittest" + config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}" + config.checkpoint_root_dir = get_checkpoint_path() + config.buffer.total_epochs = 1 + config.buffer.batch_size = 4 + config.model.model_path = get_model_path() + config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown") + config.buffer.trainer_input.experience_buffer = StorageConfig( + name="exp_buffer", + storage_type=StorageType.QUEUE, + wrap_in_ray=True, + ) + config.synchronizer.sync_method = SyncMethod.NCCL + config.synchronizer.sync_style = self.sync_style + config.synchronizer.sync_interval = 2 + config.trainer.save_interval = 100 + config.monitor.monitor_type = "tensorboard" + config.mode = "both" + config.check_and_update() + + # TODO: test more interval cases + both_process = multiprocessing.Process( + target=run_both, + args=(config, self.max_steps, self.trainer_intervals, self.explorer_intervals), + ) + both_process.start() + both_process.join(timeout=200) + + # check the tensorboard + parser = TensorBoardParser(os.path.join(config.monitor.cache_dir, "tensorboard", "trainer")) + actor_metrics = parser.metric_list("actor") + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8) + parser = TensorBoardParser( + os.path.join(config.monitor.cache_dir, "tensorboard", "explorer") + ) + rollout_metrics = parser.metric_list("rollout") + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + + def tearDown(self): + if os.path.exists(CHECKPOINT_ROOT_DIR): + shutil.rmtree(CHECKPOINT_ROOT_DIR) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 6a88e2a125..678be80a28 100644 --- 
a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -69,7 +69,6 @@ def init_process_group( group_name: str, backend: str = "nccl", timeout: int = 1200, - update_with_checkpoint: bool = True, ) -> None: pass @@ -91,7 +90,6 @@ def init_process_group( group_name: str, backend: str = "nccl", timeout: int = 1200, - update_with_checkpoint: bool = True, ) -> None: pass diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index cc85eaf8c7..4f965c8a2e 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- """Test for the workflow module""" import unittest -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Dict, Optional from unittest.mock import MagicMock from torch import Tensor from tests.tools import get_unittest_dataset_config +from trinity.common.experience import EID from trinity.common.rewards import RMGalleryFn from trinity.common.workflows import ( MathBoxedWorkflow, @@ -27,6 +28,7 @@ class MockResponse: unique_id: Optional[str] = "0" tokens: Optional[Tensor] = Tensor([0, 0]) prompt_length: int = 1 + eid: EID = field(default_factory=EID) class DummyWorkflow(Workflow): diff --git a/tests/tools.py b/tests/tools.py index 60df1122f2..b89607138b 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -18,7 +18,9 @@ def get_template_config() -> Config: config_path = os.path.join(os.path.dirname(__file__), "template", "config.yaml") - return load_config(config_path) + config = load_config(config_path) + config.ray_namespace = ray.get_runtime_context().namespace + return config def get_model_path() -> str: diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index ff56f3ffc5..7c11a67b4f 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -22,7 +22,7 @@ ) from trinity.cli.launcher import bench, both, explore, train from trinity.common.config import Config, StorageConfig -from trinity.common.constants import StorageType, SyncMethod +from trinity.common.constants import StorageType, SyncMethod, SyncStyle from trinity.common.models.utils import get_checkpoint_dir_with_step_num from trinity.manager.manager import CacheManager @@ -99,7 +99,7 @@ def test_trainer(self): self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0) self.assertEqual(step_num, 8) # TODO: Reinit will fail when using v1 engine, find a way to fix it - ray.init(ignore_reinit_error=True) + ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace) # test bench mode self.config.mode = "bench" self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT @@ -362,6 +362,7 @@ def test_fully_async_mode(self, name, use_priority_queue): use_priority_queue=use_priority_queue, ) config.synchronizer.sync_method = SyncMethod.CHECKPOINT + config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER config.synchronizer.sync_interval = 8 config.monitor.monitor_type = "tensorboard" trainer_config = deepcopy(config) diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index e5d76e57d4..e28af726a8 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -71,6 +71,10 @@ def __init__(self, capacity: int): async def close(self) -> None: """Close the queue.""" self._closed = True + for getter in self._getters: + if not getter.done(): + getter.set_exception(StopAsyncIteration()) + self._getters.clear() def stopped(self) -> bool: """Check if there is no more data to read.""" diff --git 
a/trinity/cli/launcher.py b/trinity/cli/launcher.py index f207a5353d..50bcf5b8c0 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -151,8 +151,11 @@ def both(config: Config) -> None: "===============================================================" ) ray.wait(wait_ref, timeout=config.synchronizer.sync_timeout) - explorer.shutdown.remote() - trainer.shutdown.remote() + ray.wait( + [explorer.shutdown.remote(), trainer.shutdown.remote()], + timeout=config.synchronizer.sync_timeout, + num_returns=2, + ) def run(config_path: str, dlc: bool = False, plugin_dir: str = None): diff --git a/trinity/common/config.py b/trinity/common/config.py index ab7d8a08cd..9714b521ab 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -15,6 +15,7 @@ ReadStrategy, StorageType, SyncMethod, + SyncStyle, TaskType, ) from trinity.utils.log import get_logger @@ -387,6 +388,7 @@ class SynchronizerConfig: """Configs for model weight synchronization.""" sync_method: SyncMethod = SyncMethod.NCCL + sync_style: SyncStyle = SyncStyle.FIXED # sync weights every `sync_interval` steps sync_interval: int = 1 # allow explorer to run `sync_offset` steps before sync @@ -398,6 +400,7 @@ class SynchronizerConfig: # ! DO NOT SET, automatically calculated explorer_world_size: Optional[int] = None + ray_namespace: str = "" @dataclass @@ -406,6 +409,7 @@ class Config: mode: str = "both" # `explore`, `train`, `both` or `bench` project: str = "Trinity-RFT" + group: str = "" name: str = "rft" # the root dir for checkpoints checkpoint_root_dir: str = "" @@ -447,16 +451,6 @@ def _check_interval(self) -> None: f"`eval_interval` is not a multiple of `sync_interval`; adjusted to the nearest integer={self.explorer.eval_interval}." ) - # check save_interval - if self.synchronizer.sync_method == SyncMethod.CHECKPOINT: - if self.trainer.save_interval != self.synchronizer.sync_interval: - logger.warning( - f"When `algorithm.algorithm_type` != `dpo` and `synchronizer.sync_method` == `checkpoint`, " - f"`trainer.save_interval` will be set to " - f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`." 
- ) - self.trainer.save_interval = self.synchronizer.sync_interval - def _check_buffer(self) -> None: # noqa: C901 # TODO: split this function into different buffer read/writer # check explorer_input @@ -717,7 +711,9 @@ def check_and_update(self) -> None: # noqa: C901 if not os.path.isabs(self.checkpoint_root_dir): self.checkpoint_root_dir = os.path.join(os.getcwd(), self.checkpoint_root_dir) # create a job dir at checkpoint_root_dir/project/name - self.checkpoint_job_dir = os.path.join(self.checkpoint_root_dir, self.project, self.name) + self.checkpoint_job_dir = os.path.join( + self.checkpoint_root_dir, self.project, self.group, self.name + ) # rename the experiment when necessary if not self.continue_from_checkpoint and ( os.path.exists(self.checkpoint_job_dir) and os.listdir(self.checkpoint_job_dir) @@ -742,17 +738,18 @@ def check_and_update(self) -> None: # noqa: C901 self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens # check synchronizer + self.synchronizer.ray_namespace = self.ray_namespace self.synchronizer.explorer_world_size = ( self.explorer.rollout_model.engine_num * self.explorer.rollout_model.tensor_parallel_size ) if ( self.mode in ["train", "explore", "bench"] - and self.synchronizer.sync_method != SyncMethod.CHECKPOINT + and self.synchronizer.sync_method == SyncMethod.NCCL ): self.synchronizer.sync_method = SyncMethod.CHECKPOINT logger.warning( - f"`{self.mode}` mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`." + f"`{self.mode}` mode does not support NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`." ) self._check_interval() diff --git a/trinity/common/constants.py b/trinity/common/constants.py index bac4941453..392f2dc553 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -95,12 +95,14 @@ class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta): NCCL = "nccl" CHECKPOINT = "checkpoint" + MEMORY = "memory" class RunningStatus(Enum): """Running status of explorer and trainer.""" RUNNING = "running" + REQUIRE_SYNC = "require_sync" WAITING_SYNC = "waiting_sync" STOPPED = "stopped" @@ -119,3 +121,9 @@ class OpType(Enum): SUB = "sub" MUL = "mul" DIV = "div" + + +class SyncStyle(CaseInsensitiveEnum): + FIXED = "fixed" + DYNAMIC_BY_TRAINER = "dynamic_by_trainer" + DYNAMIC_BY_EXPLORER = "dynamic_by_explorer" diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index a1e6070b92..7b99dce17b 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -3,7 +3,7 @@ import os import re -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Union import aiohttp import ray @@ -260,12 +260,8 @@ async def _collective_rpc( method, timeout, args, kwargs ) - async def sync_model( - self, model_version: int, update_weight_args_list: Optional[List[Tuple]] = None - ) -> bool: + async def sync_model(self, model_version: int) -> bool: """Sync model weights to vLLM.""" - if update_weight_args_list is not None: - await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,)) await self._collective_rpc("update_weight") self.logger.info("Sync model weights to vLLM successfully.") self.model_version = model_version @@ -281,7 +277,6 @@ async def init_process_group( explorer_name: str, backend: str = "nccl", timeout: int = 1200, - update_with_checkpoint: bool = True, state_dict_meta: dict = None, ): return await 
self._collective_rpc( @@ -294,7 +289,6 @@ async def init_process_group( group_name, backend, timeout, - update_with_checkpoint, state_dict_meta, explorer_name, ray.get_runtime_context().namespace, diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 7509942176..9835cd6d15 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -4,6 +4,7 @@ import torch import torch.distributed +from trinity.common.synchronizer import Synchronizer from trinity.utils.distributed import init_process_group from trinity.utils.log import get_logger @@ -20,7 +21,6 @@ def init_process_group( group_name: str, backend: str = "nccl", timeout: int = 1200, - update_with_checkpoint: bool = True, state_dict_meta: list = None, explorer_name: str = None, namespace: str = None, @@ -28,11 +28,10 @@ def init_process_group( """Init torch process group for model weights update""" assert torch.distributed.is_initialized(), "default torch process group must be initialized" assert group_name != "", "group name must not be empty" - self.set_state_dict_meta(state_dict_meta) - self._update_with_checkpoint = update_with_checkpoint + self._state_dict_meta = state_dict_meta self._weight_update_rank = torch.distributed.get_rank() + rank_offset logger.info( - f"vLLM starting init_process_group ({'checkpoint' if self._update_with_checkpoint else 'nccl'}):\n" + f"vLLM starting init_process_group:\n" f" > address={master_address}:{master_port}\n" f" > rank={torch.distributed.get_rank()}\n" f" > rank_offset={rank_offset}\n" @@ -51,21 +50,17 @@ def init_process_group( logger.info("vLLM init_process_group finished.") self._explorer_name = explorer_name self._namespace = namespace - self._explorer_actor = None - - def set_state_dict_meta(self, state_dict_meta): - self._state_dict_meta = state_dict_meta + self.synchronizer = Synchronizer.get_actor(namespace=self._namespace) def update_weight(self): """Broadcast weight to all vllm workers from source rank 0 (actor model)""" - assert self._state_dict_meta is not None - if self._explorer_actor is None: - self._explorer_actor = ray.get_actor( - name=self._explorer_name, namespace=self._namespace - ) + if self._state_dict_meta is None: + self._state_dict_meta = ray.get(self.synchronizer.get_state_dict_meta.remote()) + if self._weight_update_rank == 0: + state_dict, _ = ray.get(self.synchronizer.get_model_state_dict.remote()) for name, dtype_str, shape in self._state_dict_meta: if self._weight_update_rank == 0: - weight = ray.get(self._explorer_actor.get_weight.remote(name)) + weight = state_dict[name] weight = weight.to(self.device) else: dtype = getattr(torch, dtype_str.split(".")[-1]) diff --git a/trinity/common/synchronizer.py b/trinity/common/synchronizer.py new file mode 100644 index 0000000000..3e2070effa --- /dev/null +++ b/trinity/common/synchronizer.py @@ -0,0 +1,294 @@ +"""A centralized synchronizer for coordinating explorer and trainer.""" + +import asyncio +import os +from collections import defaultdict +from typing import Dict, List, Optional, Union + +import ray + +from trinity.common.config import Config +from trinity.common.constants import RunningStatus +from trinity.common.models.utils import ( + get_checkpoint_dir_with_step_num, + load_state_dict, +) +from trinity.utils.log import get_logger + + +class Synchronizer: + """ + A central component to manage synchronization of models and states between + the trainer and one or more explorers in a distributed training setup. 
+ + Attributes: + trainer_status: Current status of the trainer (e.g., running, waiting). + explorer_status_counts: Dictionary tracking the number of explorers in each status. + _ready_condition: Async condition variable for signaling state changes. + model_state_dict: The latest model weights. + model_version: Version number of the current model. + checkpoint_shard_counter: Tracks how many shards are received from trainer for a specific train step. + """ + + def __init__(self, config: Config): + self.logger = get_logger(__name__) + self.config = config + self.trainer_status = RunningStatus.STOPPED + self.explorer_status_counts: Dict[RunningStatus, int] = defaultdict(lambda: 0) + self._ready_condition = asyncio.Condition() + self.model_state_dict = None + self.model_version = 0 + self.checkpoint_shard_counter = defaultdict(lambda: 0) + self.ref_count = 0 + + async def set_trainer_status(self, status: RunningStatus): + """Update the status of the trainer.""" + async with self._ready_condition: + self.trainer_status = status + if status == RunningStatus.STOPPED: + self._ready_condition.notify_all() + + def get_trainer_status(self) -> RunningStatus: + """Get the current status of the trainer.""" + return self.trainer_status + + def set_explorer_status( + self, status: RunningStatus, old_status: Optional[RunningStatus] = None + ): + """ + Update the status count for an explorer. + + Args: + status: New status of the explorer. + old_status: Previous status if changing from one to another. + """ + if old_status is not None: + assert ( + old_status in self.explorer_status_counts + ), f"Invalid explorer status {old_status}" + assert old_status != status, f"Invalid status change from {old_status} to {status}" + self.explorer_status_counts[old_status] -= 1 + assert ( + self.explorer_status_counts[old_status] >= 0 + ), f"Invalid status count {old_status} (new status {status})" + if status not in self.explorer_status_counts: + self.explorer_status_counts[status] = 0 + self.explorer_status_counts[status] += 1 + + def get_explorer_status_counts(self) -> Dict[RunningStatus, int]: + """Return the current status counts for all explorers.""" + return self.explorer_status_counts + + async def set_model_state_dict_with_step_num( + self, step_num: Optional[int] = None, world_size: Optional[int] = None + ) -> int: + """ + Load and set the model state dictionary from a checkpoint at a specific step. + + Args: + step_num: Training step number corresponding to the checkpoint. + world_size: Number of shards expected for this checkpoint. + + Returns: + The updated model version (step number). + """ + if world_size is not None: # Used when trainer updates the model + assert step_num is not None + assert self.checkpoint_shard_counter[step_num] < world_size, "World size mismatch!" + self.checkpoint_shard_counter[step_num] += 1 + self.logger.info( + f"Synchronizer has received {self.checkpoint_shard_counter[step_num]} out of {world_size} shards from the checkpoint {step_num}." 
+ ) + if self.checkpoint_shard_counter[step_num] < world_size: + return step_num + + checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num( + checkpoint_root_path=self.config.checkpoint_job_dir, + trainer_type=self.config.trainer.trainer_type, + step_num=step_num, + ) + if checkpoint_step_num != self.model_version: + model_state_dict = load_state_dict(os.path.join(checkpoint_dir, "actor")) + await self.set_model_state_dict(model_state_dict, checkpoint_step_num) + return checkpoint_step_num + + async def set_model_state_dict(self, model_state_dict: Union[dict, None], trainer_step: int): + """ + Set the new model state and update the version. + + Args: + model_state_dict: The PyTorch model state dictionary. + trainer_step: Step number associated with this model version. + """ + self.model_state_dict = model_state_dict + async with self._ready_condition: + self.model_version = trainer_step + self.logger.info(f"Set model state dict version to {trainer_step}.") + self._ready_condition.notify_all() + + def get_model_state_dict(self): + """Return the current model state and its version.""" + return self.model_state_dict, self.model_version + + def get_state_dict_meta(self): + """ + Return metadata about the model state (names, data types, shapes). + + Returns: + List of tuples: (name, dtype, shape). + """ + if self.model_state_dict is None: + return None + update_weight_args_list = [] + for name, param in self.model_state_dict.items(): + update_weight_args_list.append((name, str(param.dtype), tuple(param.shape))) + return update_weight_args_list + + async def setup_weight_sync_group( + self, master_address: str, master_port: int, state_dict_meta: List = None + ): + """ + Notify the explorer actor to setup weight sync group. + + This is used to initialize NCCL-based synchronization for distributed training. + + Args: + master_address: IP address of the master node. + master_port: Port used for synchronization. + state_dict_meta: Metadata of the model parameters. + """ + explorer = ray.get_actor(self.config.explorer.name, namespace=self.config.ray_namespace) + await explorer.setup_weight_sync_group.remote(master_address, master_port, state_dict_meta) + + async def wait_new_model_state_dict(self, current_version: int, no_wait: bool = False) -> int: + """ + Wait until a new model state is available. + + Args: + current_version: Current model version known to one explorer. + + Returns: + The new model version after it has been updated. + """ + async with self._ready_condition: + assert ( + self.model_version >= current_version + ), f"The model version in Synchronizer ({self.model_version}) should be no smaller than that in Explorer ({current_version})!" + if self.model_version == current_version: + if not no_wait and self.trainer_status != RunningStatus.STOPPED: + # TODO: explorer need support no wait + # TODO: handle timeout + await asyncio.wait_for( + self._ready_condition.wait(), + timeout=self.config.synchronizer.sync_timeout, + ) + if self.model_version > current_version: + self.set_explorer_status( + RunningStatus.WAITING_SYNC, old_status=RunningStatus.REQUIRE_SYNC + ) + return self.model_version + + async def ready_to_nccl_sync( + self, module: str, trainer_step: Optional[int] = None + ) -> Union[int, None]: + """ + Prepare for NCCL-based synchronization between modules. + + Only supports one explorer currently. + + Args: + module: Either 'trainer' or 'explorer'. + trainer_step: Optional step number from the trainer. 
+ + Returns: + The model version if both sides are ready; otherwise None. + """ + assert ( + sum(self.explorer_status_counts.values()) == 1 + ), "NCCL sync is only supported for one explorer." + + def sync_failed(): + if module == "explorer": + another_module = "Trainer" + self.set_explorer_status( + RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.WAITING_SYNC + ) + else: + another_module = "Explorer" + self.trainer_status = RunningStatus.REQUIRE_SYNC + self.logger.error(f"{another_module} is not ready for model weight sync.") + return None + + non_stop_cnt = sum( + value + for key, value in self.explorer_status_counts.items() + if key != RunningStatus.STOPPED + ) + if non_stop_cnt == 0: + return sync_failed() + + async with self._ready_condition: + try: + if module == "trainer": + self.model_version = trainer_step + self.trainer_status = RunningStatus.WAITING_SYNC + self._ready_condition.notify_all() + if self.explorer_status_counts[RunningStatus.WAITING_SYNC] != 1: + await asyncio.wait_for( + self._ready_condition.wait_for( + lambda: self.explorer_status_counts[RunningStatus.WAITING_SYNC] + == 1, + ), + timeout=self.config.synchronizer.sync_timeout, + ) + elif module == "explorer": + self.set_explorer_status( + RunningStatus.WAITING_SYNC, old_status=RunningStatus.REQUIRE_SYNC + ) + self._ready_condition.notify_all() + if self.trainer_status != RunningStatus.WAITING_SYNC: + await asyncio.wait_for( + self._ready_condition.wait_for( + lambda: self.trainer_status + in {RunningStatus.WAITING_SYNC, RunningStatus.STOPPED}, + ), + timeout=self.config.synchronizer.sync_timeout, + ) + if self.trainer_status == RunningStatus.STOPPED: + return sync_failed() + self.trainer_status = RunningStatus.RUNNING + return self.model_version + except asyncio.TimeoutError: + return sync_failed() + + @classmethod + def get_actor(cls, config: Optional[Config] = None, namespace: Optional[str] = None): + """ + Get or create a remote Ray actor for the Synchronizer. + + Args: + config: Optional configuration to use for creating the actor. + namespace: Optional Ray namespace for the actor. + + Returns: + A reference to the Synchronizer actor. 
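+
+        Example (illustrative sketch only, not part of the original patch; assumes a
+        running Ray cluster and a ``Config`` whose ``ray_namespace`` has been filled
+        in by ``check_and_update``):
+
+            # driver side (trainer / explorer): create or reuse the detached actor
+            synchronizer = Synchronizer.get_actor(config=config)
+            ray.get(synchronizer.acquire.remote())
+
+            # worker side (e.g. a vLLM worker): connect by namespace only
+            synchronizer = Synchronizer.get_actor(namespace=config.ray_namespace)
+            state_dict_meta = ray.get(synchronizer.get_state_dict_meta.remote())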
+ """ + if config is not None: + return ( + ray.remote(cls) + .options( + name="synchronizer", + namespace=config.ray_namespace, + get_if_exists=True, + lifetime="detached", + ) + .remote(config) + ) + return ray.get_actor("synchronizer", namespace=namespace) + + def acquire(self): + self.ref_count += 1 + + def release(self): + self.ref_count -= 1 + return self.ref_count diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 6cffa185be..d03bbf3b74 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -230,6 +230,7 @@ class Trainer: total_epochs: int = 30 total_training_steps: Optional[int] = None project_name: str = "" + group_name: str = "" experiment_name: str = "" logger: List[str] = field(default_factory=list) val_generations_to_log_to_wandb: int = 0 @@ -306,6 +307,7 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 self.trainer.sync_freq = config.synchronizer.sync_interval self.trainer.save_freq = config.trainer.save_interval self.trainer.project_name = config.project + self.trainer.group_name = config.group self.trainer.experiment_name = config.name self.trainer.default_local_dir = config.checkpoint_job_dir self.trainer.sft_warmup_steps = config.buffer.trainer_input.sft_warmup_steps diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 371337d144..4f6428f1ba 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -9,6 +9,7 @@ from collections import deque from typing import List, Optional +import ray import torch from trinity.algorithm import ADD_STRATEGY @@ -20,12 +21,10 @@ ROLLOUT_WEIGHT_SYNC_GROUP_NAME, RunningStatus, SyncMethod, + SyncStyle, ) from trinity.common.models import create_inference_models -from trinity.common.models.utils import ( - get_checkpoint_dir_with_step_num, - load_state_dict, -) +from trinity.common.synchronizer import Synchronizer from trinity.explorer.scheduler import Scheduler from trinity.manager.manager import CacheManager from trinity.utils.log import get_logger @@ -41,6 +40,7 @@ def __init__(self, config: Config): explorer_meta = self.cache.load_explorer() self.explore_step_num = explorer_meta.get("latest_iteration", 0) self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1 + self.synchronizer = Synchronizer.get_actor(config) self.config = config self.algorithm_manager = AlgorithmManager(config) self.models, self.auxiliary_models = create_inference_models(config) @@ -57,6 +57,7 @@ def __init__(self, config: Config): self.scheduler = self._init_scheduler() self.monitor = MONITOR.get(self.config.monitor.monitor_type)( project=self.config.project, + group=self.config.group, name=self.config.name, role=self.config.explorer.name, config=config, @@ -65,22 +66,15 @@ def __init__(self, config: Config): self.update_interval = ( self.config.synchronizer.sync_interval * self.config.buffer.batch_size ) - self.use_checkpoint_weights_update = ( - self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT - ) + self.use_nccl_sync = self.config.synchronizer.sync_method == SyncMethod.NCCL self.pending_eval_tasks = deque() # For checkpoint weights update # Use explorer to periodically load the latest model weights and # boradcast to all rollout models - if self.use_checkpoint_weights_update: - self.old_checkpoint = None - self.state_dict = {} - else: # nccl mode - self.state_dict_meta = [] - self.status = RunningStatus.RUNNING + self.model_version = -1 + self.last_sync_successful = True self.logger.info("Finished 
initializing Explorer.") - self._ready_to_sync_condition = asyncio.Condition() self.collect_experiences = self.config.explorer.collect_experiences self.generated_experience_cnt = 0 if self.collect_experiences: @@ -95,7 +89,7 @@ async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None ): # In checkpoint mode, we use explorer to store the model weights which has no rank - base_offset = 0 if self.use_checkpoint_weights_update else 1 + base_offset = 1 if self.use_nccl_sync else 0 world_size = ( len(self.models) * self.config.explorer.rollout_model.tensor_parallel_size + base_offset ) @@ -104,7 +98,6 @@ async def setup_weight_sync_group( f"master_address={master_address}, master_port={master_port}, " f"world_size={world_size}, rank_offset={base_offset}" ) - self.state_dict_meta = state_dict_meta # TODO: save state_dict in models refs = [ model.init_process_group.remote( @@ -116,7 +109,6 @@ async def setup_weight_sync_group( group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, explorer_name=self.config.explorer.name, timeout=self.config.synchronizer.sync_timeout, - update_with_checkpoint=self.use_checkpoint_weights_update, state_dict_meta=state_dict_meta, ) for i, model in enumerate(self.models) @@ -132,77 +124,67 @@ def _init_scheduler(self) -> Scheduler: ) return Scheduler(self.config, self.models, self.auxiliary_models) - async def _update_model_weight(self, step_num: int, state_dict: dict) -> None: - # TODO: update model weight - self.state_dict = state_dict - if self.state_dict_meta is None: - update_weight_args_list = [] - for name, param in state_dict.items(): - update_weight_args_list.append((name, str(param.dtype), tuple(param.shape))) - self.state_dict_meta = update_weight_args_list - else: - update_weight_args_list = None - await asyncio.gather( - *[model.sync_model.remote(step_num, update_weight_args_list) for model in self.models] - ) - self.state_dict.clear() - async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int: - # TODO: support more checkpoint types - try: - checkpoint_dir, checkpoint_step_num = get_checkpoint_dir_with_step_num( - checkpoint_root_path=self.config.checkpoint_job_dir, - trainer_type=self.config.trainer.trainer_type, - step_num=step_num, + step_num = await self.synchronizer.set_model_state_dict_with_step_num.remote(step_num) + await asyncio.gather(*[model.sync_model.remote(step_num) for model in self.models]) + return step_num # type: ignore + + async def _pull_latest_weights(self): + self.logger.info("Start to pull latest model weights.") + new_version = await self.synchronizer.wait_new_model_state_dict.remote(self.model_version) + if new_version > self.model_version: + if self.model_version != -1: + self.logger.info(f"New model weights version: {new_version}") + await asyncio.gather( + *[model.sync_model.remote(new_version) for model in self.models] + ) + self.model_version = new_version + self.last_sync_step = self.explore_step_num + await self.synchronizer.set_explorer_status.remote( + RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC ) - if checkpoint_dir == self.old_checkpoint: - return checkpoint_step_num - model_weights = load_state_dict(os.path.join(checkpoint_dir, "actor")) - await self._update_model_weight(checkpoint_step_num, model_weights) - self.old_checkpoint = checkpoint_dir - return checkpoint_step_num - except Exception as e: - self.logger.warning(f"Fail to load checkpoint: {e}") - return 0 + self.last_sync_successful = True + else: + self.logger.warning( + f"No new 
model weights found, current version: {self.model_version}" + ) + self.last_sync_successful = False async def _nccl_weights_update(self): - assert self.state_dict_meta is not None - async with self._ready_to_sync_condition: - try: - await asyncio.wait_for( - self._ready_to_sync_condition.wait_for( - lambda: self.status == RunningStatus.WAITING_SYNC, - ), - timeout=self.config.synchronizer.sync_timeout, - ) - except asyncio.TimeoutError as e: - self.logger.error( - f"Trainer is not ready for model weight sync in {self.config.synchronizer.sync_timeout} seconds." - ) - raise e + new_version = await self.synchronizer.ready_to_nccl_sync.remote( + "explorer", self.model_version + ) + if new_version is None: + self.logger.info("Trainer is not ready to sync weight. Skipping sync weight.") + self.last_sync_successful = False + return + self.model_version = new_version await asyncio.gather( - *[model.sync_model.remote(self.explore_step_num) for model in self.models] + *[model.sync_model.remote(self.model_version) for model in self.models] ) - self.status = RunningStatus.RUNNING - - async def ready_to_sync(self): - async with self._ready_to_sync_condition: - self.status = RunningStatus.WAITING_SYNC - self._ready_to_sync_condition.notify_all() + self.last_sync_step = self.explore_step_num + await self.synchronizer.set_explorer_status.remote( + RunningStatus.RUNNING, old_status=RunningStatus.WAITING_SYNC + ) + self.last_sync_successful = True async def prepare(self) -> None: """Preparation before running.""" - futures = [asyncio.create_task(self.scheduler.start())] - if self.use_checkpoint_weights_update: + futures = [ + asyncio.create_task(self.scheduler.start()), + self.synchronizer.acquire.remote(), + ] + if self.experience_buffer: + futures.append(asyncio.create_task(self.experience_buffer.acquire())) + if not self.use_nccl_sync: master_address, master_port = await self.models[0].get_available_address.remote() futures.append( asyncio.create_task(self.setup_weight_sync_group(master_address, master_port)) ) - asyncio.gather(*futures, return_exceptions=True) - if self.experience_buffer: - await self.experience_buffer.acquire() + await asyncio.gather(*futures, return_exceptions=True) if self.config.explorer.eval_on_startup and self.explore_step_num == 0: self.eval() + await self.synchronizer.set_explorer_status.remote(RunningStatus.REQUIRE_SYNC) async def get_weight(self, name: str) -> torch.Tensor: """Get the weight of the loaded model (For checkpoint weights update).""" @@ -229,12 +211,14 @@ async def explore(self) -> str: break if self.need_eval(): self.eval() - if self.need_sync(): + if await self.need_sync(): await self.sync_weight() except Exception: self.logger.error(f"Error in Explorer: {traceback.format_exc()}") break - self.logger.info("--------------------\n> Explorer finished.\n--------------------") + self.logger.info( + f"--------------------\n> Explorer ({self.config.explorer.name}) finished.\n--------------------" + ) return self.config.explorer.name async def explore_step(self) -> bool: @@ -248,19 +232,40 @@ async def explore_step(self) -> bool: except StopIteration: self.logger.warning("No more tasks to explore. 
Stop exploring.") + await self.save_checkpoint(sync_weight=False) - self.status = RunningStatus.STOPPED + await self.synchronizer.set_explorer_status.remote( + RunningStatus.STOPPED, + old_status=RunningStatus.RUNNING + if self.last_sync_successful + else RunningStatus.REQUIRE_SYNC, + ) + await self.experience_buffer.release() + return False + self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1) + self.explore_step_num += 1 + return True + - def need_sync(self) -> bool: - if self.explore_step_num <= self.config.synchronizer.sync_offset: - return False - return ( - self.explore_step_num - self.config.synchronizer.sync_offset - ) % self.config.synchronizer.sync_interval == 0 + async def need_sync(self) -> bool: + if self.config.synchronizer.sync_style == SyncStyle.FIXED: + if self.explore_step_num <= self.config.synchronizer.sync_offset: + return False + require_sync = ( + self.explore_step_num - self.config.synchronizer.sync_offset + ) % self.config.synchronizer.sync_interval == 0 + else: + require_sync = False + if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_EXPLORER: + delta = self.explore_step_num - self.last_sync_step + if delta >= self.config.synchronizer.sync_interval: + require_sync = True + else: + require_sync = ( + await self.synchronizer.get_trainer_status.remote() + == RunningStatus.REQUIRE_SYNC + ) + if require_sync and self.last_sync_successful: + await self.synchronizer.set_explorer_status.remote( + RunningStatus.REQUIRE_SYNC, old_status=RunningStatus.RUNNING + ) + return require_sync def need_eval(self) -> bool: return self.explore_step_num % self.config.explorer.eval_interval == 0 @@ -319,18 +324,19 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None: await self.scheduler.wait_all() self.logger.info(f"All tasks before step {self.explore_step_num} have completed.") log_task = asyncio.create_task( - self._finish_steps(self.last_sync_step + 1, self.explore_step_num) + self._finish_steps(self.last_sync_step + 1, self.explore_step_num, self.model_version) ) if sync_weight: # sync weights self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.") - if self.use_checkpoint_weights_update: - await self._checkpoint_weights_update() - else: # nccl weights update + if self.use_nccl_sync: await self._nccl_weights_update() - self.last_sync_step = self.explore_step_num - self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} finished") + else: # pull weights from Synchronizer + await self._pull_latest_weights() + self.logger.info( + f"Explorer sync_weights at step {self.explore_step_num} finished, model version = {self.model_version}."
+ ) # overlay log and weight sync await log_task @@ -346,15 +352,15 @@ async def sync_weight(self) -> None: # call this method before training start to load the latest model weights await self.save_checkpoint(sync_weight=True) - async def _finish_steps(self, start_step: int, end_step: int) -> None: + async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None: for step in range(start_step, end_step + 1): self.logger.info(f"Log metrics of step {step}") - await self._finish_explore_step(step=step) + await self._finish_explore_step(step=step, model_version=model_version) await self._finish_eval_step(step=step) - async def _finish_explore_step(self, step: int) -> None: + async def _finish_explore_step(self, step: int, model_version: int) -> None: statuses, exps = await self.scheduler.get_results(batch_id=step) - metric = {} + metric = {"rollout/model_version": model_version} if self.config.explorer.collect_experiences: exp_cnt, add_strategy_metric = await self.add_strategy.add(exps, step) self.generated_experience_cnt += exp_cnt @@ -384,9 +390,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva metric[f"{prefix}/total_time"] = time.time() - st self.monitor.log(metric, step) - async def running_status(self) -> RunningStatus: - return self.status - async def shutdown(self) -> None: self.monitor.close() + if await self.synchronizer.release.remote() == 0: + ray.kill(self.synchronizer) + self.logger.info("Synchronizer stopped.") await self.scheduler.stop() diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 5580d20fa0..2c3cb0c5fb 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -227,7 +227,7 @@ def task_done_callback(self, async_task: asyncio.Task): if async_task.cancelled(): return elif async_task.exception(): - self.logger.error(f"Task {task.task_id} failed: {async_task.exception()}") + self.logger.error(f"Task {task.task.task_id} failed: {async_task.exception()}") return else: status, exps, runner_id = async_task.result() diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index e16837cd2a..00de93c10d 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -4,7 +4,6 @@ """ from __future__ import annotations -import os import traceback from abc import ABC, abstractmethod from typing import Dict, List, Tuple @@ -15,8 +14,9 @@ from trinity.algorithm import SAMPLE_STRATEGY from trinity.algorithm.utils import prefix_metrics from trinity.common.config import Config -from trinity.common.constants import RunningStatus, SyncMethod +from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle from trinity.common.experience import Experiences +from trinity.common.synchronizer import Synchronizer from trinity.utils.log import get_logger from trinity.utils.monitor import MONITOR @@ -27,10 +27,13 @@ class Trainer: def __init__(self, config: Config) -> None: self.config = config self.logger = get_logger(__name__) + self.synchronizer = Synchronizer.get_actor(config) + ray.get(self.synchronizer.acquire.remote()) self.engine = get_trainer_wrapper(config) - self.explorer_ref = None + self.last_trainer_sync_step = 0 self.monitor = MONITOR.get(config.monitor.monitor_type)( project=config.project, + group=self.config.group, name=config.name, role=config.trainer.name, config=config, @@ -44,6 +47,8 @@ def __init__(self, config: Config) -> None: def prepare(self) -> None: """Prepare the trainer.""" self.engine.prepare() + 
self.last_trainer_sync_step = self.engine.train_step_num + ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) def train(self) -> str: """Train the model.""" @@ -57,6 +62,8 @@ def train(self) -> str: except Exception: self.logger.error(f"Error in Trainer:\n{traceback.format_exc()}") break + ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.STOPPED)) + self.engine.save_checkpoint(block_until_saved=True) self.logger.info("--------------------\n> Trainer finished.\n--------------------") return self.config.trainer.name @@ -89,25 +96,39 @@ def train_step(self) -> bool: def need_sync(self) -> bool: """Whether to sync the model weight.""" - return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0 + if self.config.synchronizer.sync_style == SyncStyle.FIXED: + return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0 + else: + if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_TRAINER: + delta = self.engine.train_step_num - self.last_trainer_sync_step + if delta >= self.config.synchronizer.sync_interval: + ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC)) + explorer_status_counts = ray.get(self.synchronizer.get_explorer_status_counts.remote()) + if self.config.synchronizer.sync_method == SyncMethod.NCCL: + return explorer_status_counts[RunningStatus.WAITING_SYNC] > 0 + else: # memory & checkpoint + return explorer_status_counts[RunningStatus.REQUIRE_SYNC] > 0 def sync_weight(self) -> None: """Sync the model weight.""" + self.logger.info( + f"Trainer synchronizing weights at step {self.engine.train_step_num} starting.." + ) if self.config.synchronizer.sync_method == SyncMethod.NCCL: - self.logger.info( - f"Trainer synchronizing weights at step {self.engine.train_step_num} starting.." - ) - if self.explorer_ref is None: - self.explorer_ref = ray.get_actor(self.config.explorer.name) - explorer_status = ray.get(self.explorer_ref.running_status.remote()) - if explorer_status == RunningStatus.STOPPED: - self.logger.warning("Explorer has already stopped. Skipping sync weight.") - return - ray.get(self.explorer_ref.ready_to_sync.remote()) - self.engine.sync_weight() - self.logger.info( - f"Trainer synchronizing weights at step {self.engine.train_step_num} end." 
+ result = ray.get( + self.synchronizer.ready_to_nccl_sync.remote("trainer", self.engine.train_step_num) ) + if result is None: + self.logger.error("Trainer synchronizing weights failed.") + else: + self.engine.sync_weight() + self.last_trainer_sync_step = self.engine.train_step_num + elif self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT: + self.engine.save_state_dict() + elif self.config.synchronizer.sync_method == SyncMethod.MEMORY: + self.engine.upload_state_dict() + self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num} end.") + ray.get(self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)) def _log_experiences(self, samples: List[Dict]) -> None: self._sample_exps_to_log.extend(samples) @@ -118,10 +139,10 @@ def _log_experiences(self, samples: List[Dict]) -> None: self._sample_exps_to_log.clear() def shutdown(self) -> None: - # if checkpoint not saved, save the last checkpoint - path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{self.train_step_num}") - if not os.path.isdir(path) or len(os.listdir(path)) == 0: - self.engine.save_checkpoint() + self.monitor.close() + if ray.get(self.synchronizer.release.remote()) == 0: + ray.kill(self.synchronizer) + self.logger.info("Synchronizer stopped.") @property def train_step_num(self) -> int: @@ -154,7 +175,7 @@ def train_step(self, batch: Experiences) -> Tuple[bool, Dict]: """ @abstractmethod - def save_checkpoint(self) -> None: + def save_checkpoint(self, block_until_saved: bool = False) -> None: """Save the checkpoint.""" @abstractmethod @@ -162,8 +183,12 @@ def sync_weight(self) -> None: """Sync the model weight.""" @abstractmethod - def shutdown(self) -> None: - """Shutdown the engine.""" + def upload_state_dict(self) -> None: + """Upload the state dict to Synchronizer.""" + + @abstractmethod + def save_state_dict(self) -> None: + """Only save the model state dict for Synchronizer.""" def get_trainer_wrapper(config: Config) -> TrainEngineWrapper: diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py new file mode 100644 index 0000000000..1899cb5ad8 --- /dev/null +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -0,0 +1,363 @@ +import json +import os +import threading +import warnings +from dataclasses import asdict +from typing import Optional, Union + +import ray +import torch +from accelerate import init_empty_weights +from torch.distributed.fsdp import ( + FullStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + StateDictType, +) +from transformers import GenerationConfig +from verl.utils.checkpoint.fsdp_checkpoint_manager import ( + FSDPCheckpointManager as OldFSDPCheckpointManager, +) +from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPConfig, logger +from verl.utils.device import is_cuda_available +from verl.utils.fs import local_mkdir_safe +from verl.utils.fsdp_utils import ( + fsdp_version, + get_fsdp_full_state_dict, + get_fsdp_state_ctx, +) +from verl.utils.logger import log_with_rank + +from trinity.common.constants import SyncMethod +from trinity.common.synchronizer import Synchronizer + + +class FSDPCheckpointManager(OldFSDPCheckpointManager): + """ + An enhanced version of the original FSDP checkpoint manager that: + + 1. Uploads model state dicts to a remote Synchronizer actor (either directly or via checkpoints). + 2. Offloads saving operations (model, optimizer, extra states) into background threads to avoid blocking the training loop. 
+ + This class is useful in distributed training scenarios where synchronization and non-blocking I/O are important. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + config = kwargs.pop("config", None) + self.synchronizer_config = config + if config is not None: + # Retrieve the remote Synchronizer actor using the provided namespace + self.synchronizer = Synchronizer.get_actor(namespace=config.ray_namespace) + else: + self.synchronizer = None + + # Threads for asynchronous saving of different components + self._model_state_dict_thread = None + self._optimizer_state_dict_thread = None + self._extra_state_dict_thread = None + self._save_model_thread = None + + def _notify_synchronizer_with_step_num(self, global_step): + """ + Notifies the Synchronizer actor about the current training step number, + used when SyncMethod is CHECKPOINT. + + Args: + global_step (int): The current global training step. + """ + if getattr(self.synchronizer_config, "sync_method", None) == SyncMethod.CHECKPOINT: + ray.get( + self.synchronizer.set_model_state_dict_with_step_num.remote( + global_step, self.world_size + ) + ) + + def _upload_state_dict(self, state_dict: Union[dict, None], global_step: int): + """ + Internal method to upload a state dict to the Synchronizer actor. + + Args: + state_dict (dict or None): The model state dictionary to upload. + global_step (int): The current training step number. + """ + if self.rank == 0: + ray.get(self.synchronizer.set_model_state_dict.remote(state_dict, global_step)) + + def upload_state_dict(self, global_step: int): + """ + Uploads the full model state dictionary to the Synchronizer actor for remote access. + + Args: + global_step (int): The current training step number. + """ + assert self.synchronizer is not None + state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with get_fsdp_state_ctx(self.model, StateDictType.FULL_STATE_DICT, state_dict_config, None): + state_dict = self.model.state_dict() + self._upload_state_dict(state_dict, global_step) + + def save_checkpoint( # noqa: C901 + self, + local_path: str, + hdfs_path: str = None, + global_step: int = 0, + max_ckpt_to_keep: Optional[int] = None, + model_state_dict_only: bool = False, + ): + """ + Modified from verl.utils.checkpoint.fsdp_checkpoint_manager.py:save_checkpoint + + Saves the model checkpoint to disk, optionally uploads it to a remote Synchronizer, + and uses background threads to prevent blocking the main training loop. + + Main improvements over the base class: + - Uses separate threads for saving model/optimizer/extras. + - Implements synchronization with a remote actor. If the model has not been trained yet (`global_step == 0`) or training is resuming from an existing checkpoint, the `Synchronizer` will be notified and the model will not be saved again. + + Args: + local_path (str): Local directory path to save the checkpoint. + hdfs_path (str, optional): HDFS path for saving the checkpoint (not implemented here). + global_step (int): Current training step. + max_ckpt_to_keep (int, optional): Maximum number of checkpoints to keep locally. + model_state_dict_only (bool): Whether to only save the model state dict (no optimizer, etc.).
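+
+        Example (illustrative sketch only; ``manager``, ``ckpt_dir`` and ``step``
+        are placeholders, not names from this patch):
+
+            # periodic full checkpoint: model + optimizer + extra state, keep at most 3
+            manager.save_checkpoint(local_path=ckpt_dir, global_step=step, max_ckpt_to_keep=3)
+
+            # synchronization-only save: persist just the model shards so that the
+            # Synchronizer can later rebuild a full state dict for the explorers
+            manager.save_checkpoint(local_path=ckpt_dir, global_step=step, model_state_dict_only=True)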
+ """ + if global_step == 0 and model_state_dict_only: + self._upload_state_dict(None, global_step) + return + if local_path is None: + return + + # record the previous global step + self.previous_global_step = global_step + + # remove previous local_path, only rank 0 should do this + if ( + self.rank == 0 + and max_ckpt_to_keep + and isinstance(max_ckpt_to_keep, int) + and max_ckpt_to_keep > 0 + and len(self.previous_saved_paths) >= max_ckpt_to_keep # type: ignore + ): + keep_start = len(self.previous_saved_paths) - max_ckpt_to_keep + 1 # type: ignore + self.remove_previous_save_local_path(self.previous_saved_paths[:keep_start]) # type: ignore + self.previous_saved_paths = self.previous_saved_paths[keep_start:] # type: ignore + + local_path = local_mkdir_safe(local_path) + torch.distributed.barrier() + + # check if the checkpoint_save_contents is valid + if self.should_save_model: + assert ( + self.model is not None + ), "model must be provided when checkpoint_contents.save includes ['model']" + if self.should_save_optimizer: + assert ( + self.optimizer is not None + ), "optimizer must be provided when checkpoint_contents.save includes ['optimizer']" + + # every rank will save its own model and optim shard + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + with get_fsdp_state_ctx( + self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg + ): + model_path = os.path.join( + local_path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt" + ) + optim_path = os.path.join( + local_path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt" + ) + extra_path = os.path.join( + local_path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt" + ) + + if self.should_save_model or model_state_dict_only: + if os.path.exists(model_path): + if self._model_state_dict_thread is None: + # If resuming from a checkpoint, notify synchronizer immediately + self._notify_synchronizer_with_step_num(global_step) + else: + model_state_dict = self.model.state_dict() + if self._model_state_dict_thread is not None: + self._model_state_dict_thread.join() + + def _save_model_state_dict(): + torch.save(model_state_dict, model_path) + log_with_rank( + f"Saved model to {os.path.abspath(model_path)}", + rank=self.rank, + logger=logger, + ) + self._notify_synchronizer_with_step_num(global_step) + + self._model_state_dict_thread = threading.Thread( + target=_save_model_state_dict, + ) + self._model_state_dict_thread.start() + + if self.should_save_optimizer and not model_state_dict_only: + optimizer_state_dict = self.optimizer.state_dict() + if self._optimizer_state_dict_thread is not None: + self._optimizer_state_dict_thread.join() + + def _save_optimizer_state_dict(): + torch.save(optimizer_state_dict, optim_path) + log_with_rank( + f"Saved optim to {os.path.abspath(optim_path)}", + rank=self.rank, + logger=logger, + ) + + self._optimizer_state_dict_thread = threading.Thread( + target=_save_optimizer_state_dict, + ) + self._optimizer_state_dict_thread.start() + + if self.should_save_extra and not model_state_dict_only: + lr_scheduler_state_dict = ( + self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None + ) + extra_state_dict = { + "lr_scheduler": lr_scheduler_state_dict, + "rng": self.get_rng_state(), + } + if self._extra_state_dict_thread is not None: + 
+                        self._extra_state_dict_thread.join()
+
+                    def _save_extra_state_dict():
+                        torch.save(extra_state_dict, extra_path)
+                        log_with_rank(
+                            f"Saved extra_state to {os.path.abspath(extra_path)}",
+                            rank=self.rank,
+                            logger=logger,
+                        )
+
+                    self._extra_state_dict_thread = threading.Thread(
+                        target=_save_extra_state_dict,
+                    )
+                    self._extra_state_dict_thread.start()
+
+        if self.rank == 0:
+            # Save the HF tokenizer/processor and model config on rank 0 to the huggingface/
+            # directory, whether or not the full huggingface model itself is requested to be saved.
+
+            if fsdp_version(self.model) == 1:
+                unwrap_model = self.model._fsdp_wrapped_module
+            else:
+                unwrap_model = self.model
+
+            hf_config_tokenizer_path = os.path.join(local_path, "huggingface")
+            local_mkdir_safe(hf_config_tokenizer_path)
+            model_config = unwrap_model.config
+            if (
+                unwrap_model.can_generate()
+                and hasattr(model_config, "name_or_path")
+                and model_config.name_or_path
+            ):
+                # Some models' name_or_path is empty if they were not initialized from a
+                # pretrained model; in that case, we don't save the generation config.
+                generation_config = GenerationConfig.from_pretrained(model_config.name_or_path)
+                generation_config.save_pretrained(hf_config_tokenizer_path)
+            else:
+                generation_config = None
+
+            model_config.save_pretrained(hf_config_tokenizer_path)
+            self.processing_class.save_pretrained(hf_config_tokenizer_path)
+            log_with_rank(
+                f"Saved model config and tokenizer class to {os.path.abspath(hf_config_tokenizer_path)}",
+                rank=self.rank,
+                logger=logger,
+                log_only_rank_0=True,
+            )
+
+            # Also save runtime FSDP config
+            fsdp_config_path = os.path.join(local_path, "fsdp_config.json")
+            fsdp_config = FSDPConfig(
+                FSDP_version=fsdp_version(self.model),
+                world_size=self.world_size,
+            )
+            with open(fsdp_config_path, "w") as f:
+                json.dump(asdict(fsdp_config), f, indent=4)
+
+        # wait for everyone to dump to local
+        torch.distributed.barrier()
+
+        if self.should_save_hf_model and not model_state_dict_only:
+            # Only rank 0 saves the hf model; offload to cpu so that LLMs which may be
+            # too large to fit on a single GPU can still be saved.
+            state_dict = get_fsdp_full_state_dict(self.model, offload_to_cpu=True, rank0_only=True)
+
+            if self.rank == 0:
+                hf_local_path = os.path.join(local_path, "huggingface")
+                os.makedirs(hf_local_path, exist_ok=True)
+
+                if "ForTokenClassification" in model_config.architectures[0]:
+                    from transformers import AutoModelForTokenClassification
+
+                    auto_model_cls = AutoModelForTokenClassification
+                elif "ForCausalLM" in model_config.architectures[0]:
+                    from transformers import AutoModelForCausalLM
+
+                    auto_model_cls = AutoModelForCausalLM
+                elif "ForConditionalGeneration" in model_config.architectures[0]:
+                    from transformers import AutoModelForVision2Seq
+
+                    auto_model_cls = AutoModelForVision2Seq
+                else:
+                    raise NotImplementedError(
+                        f"Unknown architecture {model_config['architectures']}"
+                    )
+
+                with init_empty_weights():
+                    save_model = auto_model_cls.from_config(
+                        model_config, torch_dtype=torch.bfloat16
+                    )
+                save_model.to_empty(device="cpu")
+
+                if save_model.can_generate():
+                    if generation_config is not None:
+                        save_model.generation_config = generation_config
+                    else:
+                        print(
+                            f"Warning: {self.__class__.__name__}.save_checkpoint: Generation config file not found, using a generation config created from the model config when saving hf_model."
+                        )
+
+                if self._save_model_thread is not None:
+                    self._save_model_thread.join()
+
+                def _save_model():
+                    save_model.save_pretrained(hf_local_path, state_dict=state_dict)
+                    log_with_rank(
+                        f"Saved hf_model to {os.path.abspath(hf_local_path)}",
+                        rank=self.rank,
+                        logger=logger,
+                        log_only_rank_0=True,
+                    )
+
+                self._save_model_thread = threading.Thread(
+                    target=_save_model,
+                )
+                self._save_model_thread.start()
+                self.processing_class.save_pretrained(hf_local_path)
+
+            # wait for rank0 to dump hf_model to local
+            torch.distributed.barrier()
+
+        if not model_state_dict_only:
+            self.previous_saved_paths.append(local_path)
+
+    def wait_on_save_thread(self) -> None:
+        """
+        Wait for all background saving threads to complete.
+        """
+        if self._model_state_dict_thread is not None:
+            self._model_state_dict_thread.join()
+        if self._optimizer_state_dict_thread is not None:
+            self._optimizer_state_dict_thread.join()
+        if self._extra_state_dict_thread is not None:
+            self._extra_state_dict_thread.join()
+        if self._save_model_thread is not None:
+            self._save_model_thread.join()
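The threaded-save pattern used throughout `FSDPCheckpointManager` (capture the state dict on the main thread, write it in a background thread, and `join()` any previous writer before starting a new one) can be exercised in isolation with the toy sketch below. `AsyncSaver` is a hypothetical helper written only for illustration; it is not part of this patch.

# Minimal, self-contained illustration of the save-in-background-thread pattern above.
import os
import tempfile
import threading
from typing import Optional

import torch


class AsyncSaver:
    def __init__(self) -> None:
        self._thread: Optional[threading.Thread] = None

    def save(self, state_dict: dict, path: str) -> None:
        # Join the previous writer first so at most one save is in flight,
        # mirroring the join()-before-start logic in save_checkpoint().
        if self._thread is not None:
            self._thread.join()

        def _save():
            torch.save(state_dict, path)

        self._thread = threading.Thread(target=_save)
        self._thread.start()

    def wait(self) -> None:
        # Counterpart of FSDPCheckpointManager.wait_on_save_thread().
        if self._thread is not None:
            self._thread.join()


if __name__ == "__main__":
    saver = AsyncSaver()
    ckpt = os.path.join(tempfile.mkdtemp(), "model.pt")
    saver.save({"w": torch.zeros(4)}, ckpt)  # returns immediately
    saver.wait()  # block until the write has finished
    print(os.path.exists(ckpt))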
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 00a67ee002..e85c8ef540 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -42,7 +42,6 @@
 from verl.single_controller.base.decorator import Dispatch, register
 from verl.utils import hf_processor, hf_tokenizer
 from verl.utils.activation_offload import enable_activation_offloading
-from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
 from verl.utils.debug import log_gpu_memory_usage
 from verl.utils.device import (
     get_device_id,
@@ -78,6 +77,8 @@
 from trinity.common.config import AlgorithmConfig
 from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
+from trinity.common.synchronizer import Synchronizer
+from trinity.trainer.verl.fsdp_checkpoint_manager import FSDPCheckpointManager
 from trinity.utils.distributed import init_process_group

 logger = logging.getLogger(__file__)
@@ -560,14 +561,12 @@ def init_model(self):
             lr_scheduler=self.actor_lr_scheduler,
             processing_class=self.processor if self.processor is not None else self.tokenizer,
             checkpoint_config=self.config.actor.checkpoint,
+            config=self.config.synchronizer,
         )

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def setup_weight_sync_group(self):
-        if (
-            hasattr(self.config, "synchronizer")
-            and getattr(self.config.synchronizer, "sync_method", None) == SyncMethod.NCCL
-        ):
+        if self.config.synchronizer.sync_method == SyncMethod.NCCL:
             model = self.actor_module_fsdp
             self.named_modules = []
             self.state_dict_meta = []
@@ -594,8 +593,8 @@ def setup_weight_sync_group(self):
             master_address, master_port = self.get_availale_master_addr_port()
             world_size = self.config.synchronizer.explorer_world_size + 1
             print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).")
-            explorer = ray.get_actor(self.config.explorer_name)
-            setup_ref = explorer.setup_weight_sync_group.remote(
+            synchronizer = Synchronizer.get_actor(self.config.synchronizer)
+            setup_ref = synchronizer.setup_weight_sync_group.remote(
                 master_address, master_port, self.state_dict_meta
             )
             timeout = self.config.synchronizer.sync_timeout
@@ -628,6 +627,10 @@ def sync_weight(self):
             torch.distributed.barrier()
         torch.cuda.empty_cache()

+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def upload_state_dict(self, trainer_step: int):
+        self.checkpoint_manager.upload_state_dict(trainer_step)
+
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def set_algorithm(self, algo_config: AlgorithmConfig):
         self.actor.set_algorithm(algo_config)
@@ -762,7 +765,14 @@ def compute_ref_log_prob(self, data: DataProto):
         return output

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
+    def save_checkpoint(
+        self,
+        local_path,
+        hdfs_path=None,
+        global_step=0,
+        max_ckpt_to_keep=None,
+        model_state_dict_only=False,
+    ):
         from verl.utils.logger import log_with_rank

         # only support save and load ckpt for actor
@@ -776,6 +786,7 @@ def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to
             hdfs_path=hdfs_path,
             global_step=global_step,
             max_ckpt_to_keep=max_ckpt_to_keep,
+            model_state_dict_only=model_state_dict_only,
         )

         dist.barrier()
@@ -851,6 +862,10 @@ def clear_optimizer_state(self):
         if self._is_offload_optimizer:
             offload_fsdp_optimizer(self.actor_optimizer)

+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def wait_on_save_thread(self) -> None:
+        self.checkpoint_manager.wait_on_save_thread()
+

 class CriticWorker(Worker):
     def __init__(self, config):
@@ -1276,3 +1291,7 @@ def clear_optimizer_state(self):
         self.critic_optimizer.zero_grad()
         if self._is_offload_optimizer:
             offload_fsdp_optimizer(self.critic_optimizer)
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def wait_on_save_thread(self) -> None:
+        self.checkpoint_manager.wait_on_save_thread()
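The `model_state_dict_only` flag threaded through the worker's `save_checkpoint` above narrows what a checkpoint directory ends up containing, following the `*_world_size_{W}_rank_{R}.pt` naming used by the checkpoint manager. The sketch below isolates that decision; `planned_checkpoint_files` is a hypothetical helper for illustration only, not part of this patch.

# Hypothetical helper illustrating which shard files save_checkpoint() is expected
# to produce for one rank, depending on the model_state_dict_only flag.
from typing import List


def planned_checkpoint_files(
    world_size: int,
    rank: int,
    model_state_dict_only: bool,
    should_save_model: bool = True,
    should_save_optimizer: bool = True,
    should_save_extra: bool = True,
) -> List[str]:
    files = []
    if should_save_model or model_state_dict_only:
        files.append(f"model_world_size_{world_size}_rank_{rank}.pt")
    if should_save_optimizer and not model_state_dict_only:
        files.append(f"optim_world_size_{world_size}_rank_{rank}.pt")
    if should_save_extra and not model_state_dict_only:
        files.append(f"extra_state_world_size_{world_size}_rank_{rank}.pt")
    return files


if __name__ == "__main__":
    # Full save vs. state-dict-only save (as used for CHECKPOINT-style weight sync).
    print(planned_checkpoint_files(world_size=2, rank=0, model_state_dict_only=False))
    print(planned_checkpoint_files(world_size=2, rank=0, model_state_dict_only=True))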
self.logger.info(f"Saved at step {self.global_steps}.") - # collect metrics metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) @@ -350,20 +357,29 @@ def train_step(self, batch: Experiences) -> Tuple[bool, Dict]: # noqa C901 ) train_status = self.global_steps < self.total_training_steps - if not train_status or self.algorithm_manager.need_save(self.global_steps): - if ( - self.config.trainer.save_freq == 0 - or self.global_steps % self.config.trainer.save_freq != 0 - ): - self.logger.info(f"Saving at step {self.global_steps}.") - with marked_timer("save_checkpoint", timing_raw): - self._save_checkpoint() - self.logger.info(f"Saved at step {self.global_steps}.") + if ( + not train_status + or self.algorithm_manager.need_save(self.global_steps) + or ( + self.config.trainer.save_freq > 0 + and self.global_steps % self.config.trainer.save_freq == 0 + ) + ): + self.logger.info(f"Saving at step {self.global_steps}.") + with marked_timer("save_checkpoint", timing_raw): + self.save_checkpoint() + self.logger.info(f"Saved at step {self.global_steps}.") self.logger.info(f"Training at step {self.global_steps} finished.") return train_status, metrics - def save_checkpoint(self) -> None: - self._save_checkpoint() + def save_checkpoint(self, block_until_saved: bool = False) -> None: + if self.last_full_save_step != self.global_steps: + self.last_full_save_step = self.global_steps + self._save_checkpoint() + if block_until_saved: + self.actor_rollout_wg.wait_on_save_thread() + if self.algorithm.use_critic: + self.critic_wg.wait_on_save_thread() def sync_weight(self) -> None: self.actor_rollout_wg.sync_weight() @@ -380,9 +396,10 @@ def sft_to_rft(self) -> None: global_step_folder = find_latest_ckpt_path(checkpoint_folder) # None if no latest # find global_step_folder + self.actor_rollout_wg.wait_on_save_thread() if self.config.trainer.resume_mode == "auto": if global_step_folder is None: - print("Training from scratch") + self.logger.info("Training from scratch") return else: if not (self.config.trainer.resume_from_path and global_step_folder is not None): @@ -396,20 +413,17 @@ def sft_to_rft(self) -> None: if not os.path.isabs(global_step_folder): working_dir = os.getcwd() global_step_folder = os.path.join(working_dir, global_step_folder) - print(f"Load from checkpoint folder: {global_step_folder}") + self.logger.info(f"Load from checkpoint folder: {global_step_folder}") # set global step global_steps = int(global_step_folder.split("global_step_")[-1]) assert self.global_steps == global_steps + 1 - print(f"Resuming from {global_step_folder}") + self.logger.info(f"Resuming from {global_step_folder}") actor_path = os.path.join(global_step_folder, "actor") - print(f"Loading actor from {actor_path} to ref_policy_wg") + self.logger.info(f"Loading actor from {actor_path} to ref_policy_wg") self.ref_policy_wg.load_checkpoint(actor_path, del_local_after_load=False) self.actor_rollout_wg.clear_optimizer_state() if self.use_critic: self.critic_wg.clear_optimizer_state() - print("sft to rft finished") - - def shutdown(self) -> None: - pass + self.logger.info("sft to rft finished") diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 5896fc110d..6ba4d7482d 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -80,7 +80,9 @@ def calculate_metrics( @MONITOR.register_module("tensorboard") class TensorboardMonitor(Monitor): - def __init__(self, project: str, name: str, role: str, 
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index 5896fc110d..6ba4d7482d 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -80,7 +80,9 @@ def calculate_metrics(

 @MONITOR.register_module("tensorboard")
 class TensorboardMonitor(Monitor):
-    def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
+    def __init__(
+        self, project: str, group: str, name: str, role: str, config: Config = None
+    ) -> None:
+        self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard", role)
         os.makedirs(self.tensorboard_dir, exist_ok=True)
         self.logger = SummaryWriter(self.tensorboard_dir)
@@ -101,10 +103,14 @@ def close(self) -> None:

 @MONITOR.register_module("wandb")
 class WandbMonitor(Monitor):
-    def __init__(self, project: str, name: str, role: str, config: Config = None) -> None:
+    def __init__(
+        self, project: str, group: str, name: str, role: str, config: Config = None
+    ) -> None:
+        if not group:
+            group = name
         self.logger = wandb.init(
             project=project,
-            group=name,
+            group=group,
             name=f"{name}_{role}",
             tags=[role],
             config=config,
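Both monitor constructors now take a `group` argument ahead of `name`, and `WandbMonitor` falls back to the run name when no explicit group is given, so runs from the same experiment (e.g. explorer1, explorer2, and the trainer) can be grouped together. A small sketch of the fallback and naming behavior; `make_wandb_init_kwargs` is a hypothetical helper for illustration only.

# Hypothetical helper mirroring the arguments WandbMonitor forwards to wandb.init
# under the new (project, group, name, role, config) constructor order.
from typing import Optional


def make_wandb_init_kwargs(
    project: str, name: str, role: str, group: Optional[str] = None
) -> dict:
    # Fall back to the run name when no explicit group is given.
    return {
        "project": project,
        "group": group or name,
        "name": f"{name}_{role}",
        "tags": [role],
    }


if __name__ == "__main__":
    print(make_wandb_init_kwargs("unittest", "test_synchronizer", "explorer"))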