diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 726e22290e..0ec438c2db 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -115,6 +115,56 @@ def tearDown(self):
         shutil.rmtree(self.config.checkpoint_job_dir)
 
 
+class TestStepAheadAsyncRL(BaseTrainerCase):
+    def test_trainer(self):
+        """Test the explore step ahead trainer"""
+        # train 4 steps, sync_offset=1, sync_interval=2
+        # Explorer:
+        # | 1 | 2 | 3 |sync| 4 |
+        # |---|---|---|sync|---|
+        # Trainer:
+        # | 1 | 2 |sync| 3 | 4 |
+        # |---|---|sync|---|---|
+        self.config.buffer.total_epochs = 1
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
+        self.config.trainer.save_interval = 4
+        self.config.synchronizer.sync_interval = 2
+        self.config.synchronizer.sync_offset = 1
+        self.config.check_and_update()
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 1
+        self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 1
+
+        both(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4)
+        actor_kl_metrics = parser.metric_list("actor/kl")
+        self.assertTrue(len(actor_kl_metrics) > 0)
+        critic_kl_metrics = parser.metric_list("critic/kl")
+        self.assertTrue(len(critic_kl_metrics) > 0)
+        response_metrics = parser.metric_list("response_length")
+        self.assertTrue(len(response_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+        ray.shutdown(_exiting_interpreter=True)
+        # check checkpoint
+        from trinity.common.models.utils import get_checkpoint_dir_with_step_num
+
+        checkpoint_step_4 = get_checkpoint_dir_with_step_num(
+            checkpoint_root_path=self.config.checkpoint_job_dir,
+            trainer_type=self.config.trainer.trainer_type,
+            step_num=4,
+        )
+        self.assertTrue(os.path.exists(checkpoint_step_4))
+
+    def tearDown(self):
+        # remove dir only when the test passed
+        shutil.rmtree(self.config.checkpoint_job_dir)
+
+
 class TestTrainerGSM8K(BaseTrainerCase):
     def test_trainer(self):
         """Test GSM8K."""
@@ -153,7 +203,7 @@ def tearDown(self):
         shutil.rmtree(self.config.checkpoint_job_dir)
 
 
-class TestTrainerGSM8KWithSFT(BaseTrainerCase):
+class TestTrainerSFTWarmupGSM8K(BaseTrainerCase):
     def test_trainer(self):
         """Test GSM8K With SFT."""
         # test both mode
diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 348c9262a0..a63b06a36d 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -9,6 +9,7 @@
 import ray
 
 from trinity.common.config import Config, load_config
+from trinity.common.constants import EXPLORER_NAME, TRAINER_NAME
 from trinity.explorer.explorer import Explorer
 from trinity.trainer.trainer import Trainer
 from trinity.utils.log import get_logger
@@ -19,7 +20,7 @@
 
 def bench(config: Config) -> None:
     """Evaluate model."""
-    explorer = ray.remote(Explorer).options(name="explorer").remote(config)
+    explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
     try:
         ray.get(explorer.prepare.remote())
         ray.get(explorer.benchmark.remote())
@@ -33,7 +34,7 @@ def bench(config: Config) -> None:
 def explore(config: Config) -> None:
     """Run explorer."""
     try:
-        explorer = ray.remote(Explorer).options(name="explorer").remote(config)
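+        # register the actor under the shared EXPLORER_NAME constant so other
+        # components can locate it later via ray.get_actor(EXPLORER_NAME)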
+        explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
         ray.get(explorer.prepare.remote())
         ray.get(explorer.sync_weight.remote())
         ray.get(explorer.explore.remote())
@@ -46,7 +47,7 @@ def explore(config: Config) -> None:
 def train(config: Config) -> None:
     """Run trainer."""
     try:
-        trainer = ray.remote(Trainer).options(name="trainer").remote(config)
+        trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
         ray.get(trainer.prepare.remote())
         ray.get(trainer.sync_weight.remote())
         ray.get(trainer.train.remote())
@@ -66,8 +67,8 @@ def both(config: Config) -> None:
     the latest step. The specific number of experiences may vary for different
     algorithms and tasks.
     """
-    explorer = ray.remote(Explorer).options(name="explorer").remote(config)
-    trainer = ray.remote(Trainer).options(name="trainer").remote(config)
+    explorer = ray.remote(Explorer).options(name=EXPLORER_NAME).remote(config)
+    trainer = ray.remote(Trainer).options(name=TRAINER_NAME).remote(config)
     ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()])
     ray.get(
         [
@@ -81,15 +82,34 @@ def both(config: Config) -> None:
             trainer.sync_weight.remote(),
         ]
     )
-    _, _ = ray.wait(
+    ready_ref, wait_ref = ray.wait(
         [
             explorer.explore.remote(),
             trainer.train.remote(),
         ],
         num_returns=1,
     )
-    explorer.shutdown.remote(),
-    trainer.shutdown.remote(),
+
+    ready = ray.get(ready_ref[0])
+    if ready == TRAINER_NAME:
+        logger.info(
+            "===========================================================\n"
+            "> Launcher detected that the `Trainer` process has finished.\n"
+            "> Stopping the explorer process immediately.\n"
+            "==========================================================="
+        )
+        ray.wait(wait_ref, timeout=5)
+    elif ready == EXPLORER_NAME:
+        logger.info(
+            "============================================================\n"
+            "> Launcher detected that the `Explorer` process has finished.\n"
+            f"> Waiting {config.synchronizer.sync_timeout} s for the trainer process...\n"
+            "> You can force stop the Trainer process by pressing Ctrl+C.\n"
+            "============================================================"
+        )
+        ray.wait(wait_ref, timeout=config.synchronizer.sync_timeout)
+    explorer.shutdown.remote()
+    trainer.shutdown.remote()
 
 
 def activate_data_module(data_workflow_url: str, config_path: str):
diff --git a/trinity/common/constants.py b/trinity/common/constants.py
index 3c49d65c21..9a428131fe 100644
--- a/trinity/common/constants.py
+++ b/trinity/common/constants.py
@@ -8,6 +8,9 @@
 
 # names
+EXPLORER_NAME = "explorer"
+TRAINER_NAME = "trainer"
+
 ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync"
 
 
@@ -92,3 +95,11 @@ class SyncMethod(CaseInsensitiveEnum, metaclass=SyncMethodEnumMeta):
 
     NCCL = "nccl"
     CHECKPOINT = "checkpoint"
+
+
+class RunningStatus(Enum):
+    """Running status of explorer and trainer."""
+
+    RUNNING = "running"
+    WAITING_SYNC = "waiting_sync"
+    STOPPED = "stopped"
diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py
index 4d5d3cf376..883e470381 100644
--- a/trinity/common/models/vllm_worker.py
+++ b/trinity/common/models/vllm_worker.py
@@ -4,6 +4,7 @@
 import torch
 import torch.distributed
 
+from trinity.common.constants import EXPLORER_NAME
 from trinity.utils.distributed import init_process_group, is_ipv6_address
 from trinity.utils.log import get_logger
 
@@ -60,7 +61,7 @@ def update_weight(self):
         """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
         assert self._state_dict_meta is not None
         if self._explorer_actor is None:
-            self._explorer_actor = ray.get_actor(name="explorer")
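+            # lazily resolve and cache the named Explorer actor; its get_weight()
+            # serves the tensors that rank 0 broadcasts to the vLLM workers below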
+            self._explorer_actor = ray.get_actor(name=EXPLORER_NAME)
         for name, dtype_str, shape in self._state_dict_meta:
             if self._weight_update_rank == 0:
                 weight = ray.get(self._explorer_actor.get_weight.remote(name))
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 36527f1dcd..31ade5f84b 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -14,7 +14,12 @@
 from trinity.buffer import get_buffer_writer
 from trinity.buffer.buffer import get_buffer_reader
 from trinity.common.config import Config
-from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod
+from trinity.common.constants import (
+    EXPLORER_NAME,
+    ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
+    RunningStatus,
+    SyncMethod,
+)
 from trinity.common.models import create_inference_models
 from trinity.common.models.utils import (
     get_checkpoint_dir_with_step_num,
@@ -50,7 +55,7 @@ def __init__(self, config: Config):
         self.monitor = MONITOR.get(self.config.monitor.monitor_type)(
             project=self.config.project,
             name=self.config.name,
-            role="explorer",
+            role=EXPLORER_NAME,
             config=config,
         )
         self.batch_size = config.buffer.batch_size
@@ -69,6 +74,7 @@ def __init__(self, config: Config):
             self.state_dict = {}
         else:  # nccl mode
             self.state_dict_meta = []
+        self.status = RunningStatus.RUNNING
         self.logger.info("Finished initializing Explorer.")
 
     async def setup_weight_sync_group(
@@ -162,35 +168,44 @@ async def get_weight(self, name: str) -> torch.Tensor:
         """Get the weight of the loaded model (For checkpoint weights update)."""
         return self.state_dict[name]
 
-    async def explore(self) -> None:
+    async def explore(self) -> str:
         while True:
             try:
                 explore_contionue = self.explore_step()
+                if not explore_contionue:
+                    break
                 if self.need_sync():
                     self.wait_for_workflow_done()
                     await self.sync_weight()
                 if self.explore_step_num % self.config.explorer.eval_interval == 0:
                     self.wait_for_workflow_done()
                     self.eval()
-                if not explore_contionue:
-                    break
             except Exception as e:
                 self.logger.error(f"Error in Explorer: {e}")
                 break
-        self.logger.info("--------------------\n> Explorer finished.\n--------------------\n")
+        self.logger.info("--------------------\n> Explorer finished.\n--------------------")
+        return EXPLORER_NAME
 
     def explore_step(self) -> bool:
-        self.explore_step_num += 1
-        algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num)
+        algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num + 1)
         # skip warmup
         if algo_config.algorithm_type == "sft":
+            self.explore_step_num += 1
             return True
         try:
             tasks = self.taskset.read()
         except StopIteration:
             self.logger.warning("No more tasks to explore. Stop exploring.")
+            self.cache.save_explorer(
+                current_step=self.explore_step_num,
+                current_task_index=self.explore_step_num * self.config.buffer.batch_size,
+            )
+            self.status = RunningStatus.STOPPED
+            self.wait_for_workflow_done()
+            self.experience_buffer.finish()
             return False
         self.runner_pool.run_tasks(tasks)
+        self.explore_step_num += 1
         return True
 
     def need_sync(self) -> bool:
@@ -278,20 +293,25 @@ def wait_for_workflow_done(self) -> None:
             if not status.ok:
                 self.logger.error(f"Error when running task: {status.message}")
                 # submit another task to replace the failed task
-                self.runner_pool.run_tasks(self.taskset.read(batch_size=1))
+                try:
+                    tasks = self.taskset.read(batch_size=1)
+                except StopIteration:
+                    self.logger.warning("No more tasks in taskset. Stop retrying.")
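+                    # taskset is exhausted, so the failed task cannot be replaced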
Stop retrying.") + return + self.runner_pool.run_tasks(tasks) else: for metric_name, metric_value in status.metric.items(): all_metrics[metric_name].append(metric_value) # calculate metrics log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore self.monitor.log(log_metrics, step=self.explore_step_num) - self.logger.info(f"Explore step {self.explore_step_num} finished.") async def sync_weight(self) -> None: """Synchronize model weights.""" # call this method before training start to load the latest model weights - self.logger.info(f"Explorer synchronizing weights at step {self.explore_step_num}.") + self.logger.info(f"Explorer sync weights at step {self.explore_step_num}.") + self.status = RunningStatus.WAITING_SYNC if self.use_checkpoint_weights_update: await self._checkpoint_weights_update() else: # nccl weights update @@ -301,7 +321,11 @@ async def sync_weight(self) -> None: current_step=self.explore_step_num, current_task_index=self.explore_step_num * self.config.buffer.batch_size, ) - self.logger.info(f"Explorer synchronizing at step {self.explore_step_num} finished") + self.status = RunningStatus.RUNNING + self.logger.info(f"Explorer sync at step {self.explore_step_num} finished") + + async def running_status(self) -> RunningStatus: + return self.status def flush_log(self, step: int) -> None: """Flush the log of the current step.""" diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index ff43dfbb79..216c916c69 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -7,8 +7,15 @@ import os from abc import ABC, abstractmethod +import ray + from trinity.common.config import Config -from trinity.common.constants import SyncMethod +from trinity.common.constants import ( + EXPLORER_NAME, + TRAINER_NAME, + RunningStatus, + SyncMethod, +) from trinity.utils.log import get_logger @@ -19,24 +26,26 @@ def __init__(self, config: Config) -> None: self.config = config self.logger = get_logger(__name__) self.engine = get_trainer_wrapper(config) + self.explorer_ref = None def prepare(self) -> None: """Prepare the trainer.""" self.engine.prepare() - def train(self): + def train(self) -> str: """Train the model.""" while True: try: train_continue = self.train_step() - if self.need_sync(): - self.sync_weight() if not train_continue: break + if self.need_sync(): + self.sync_weight() except Exception as e: self.logger.error(f"Error in Trainer: {e}") break - self.logger.info("--------------------\n> Trainer finished.\n--------------------\n") + self.logger.info("--------------------\n> Trainer finished.\n--------------------") + return TRAINER_NAME def train_step(self) -> bool: """Train one step. @@ -53,6 +62,12 @@ def need_sync(self) -> bool: def sync_weight(self) -> None: """Sync the model weight.""" if self.config.synchronizer.sync_method == SyncMethod.NCCL: + if self.explorer_ref is None: + self.explorer_ref = ray.get_actor(EXPLORER_NAME) + explorer_status = ray.get(self.explorer_ref.running_status.remote()) + if explorer_status == RunningStatus.STOPPED: + self.logger.warning("Explorer has already stopped. 
Skipping sync weight.") + return self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num}.") self.engine.sync_weight() diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 69e99a153d..cbc88902a0 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -71,7 +71,11 @@ from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager from trinity.common.config import AlgorithmConfig -from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod +from trinity.common.constants import ( + EXPLORER_NAME, + ROLLOUT_WEIGHT_SYNC_GROUP_NAME, + SyncMethod, +) from trinity.utils.distributed import init_process_group, is_ipv6_address logger = logging.getLogger(__file__) @@ -573,7 +577,7 @@ def setup_weight_sync_group(self): master_address, master_port = self.get_availale_master_addr_port() world_size = self.config.synchronizer.explorer_world_size + 1 print(f"Trainer init_process_group {master_address}:{master_port} ({world_size}).") - explorer = ray.get_actor("explorer") + explorer = ray.get_actor(EXPLORER_NAME) setup_ref = explorer.setup_weight_sync_group.remote( master_address, master_port, self.state_dict_meta ) diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 4243e61d17..d041bea128 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -36,6 +36,7 @@ from trinity.algorithm.algorithm_manager import AlgorithmManager from trinity.algorithm.utils import prefix_metrics from trinity.common.config import Config +from trinity.common.constants import TRAINER_NAME from trinity.common.experience import Experiences from trinity.trainer.trainer import TrainEngineWrapper from trinity.utils.monitor import MONITOR @@ -149,7 +150,7 @@ def __init__( self.logger = MONITOR.get(global_config.monitor.monitor_type)( project=config.trainer.project_name, name=config.trainer.experiment_name, - role="trainer", + role=TRAINER_NAME, config=global_config, ) self.reset_experiences_example_table()