diff --git a/tests/template/verl_config.yaml b/tests/template/verl_config.yaml index d6dcf4a997..bb5c21612a 100644 --- a/tests/template/verl_config.yaml +++ b/tests/template/verl_config.yaml @@ -16,7 +16,7 @@ actor_rollout_ref: shuffle: False ulysses_sequence_parallel_size: 1 # sp size checkpoint: - contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + contents: ["model", "optimizer", "extra"] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space optim: lr: 1e-6 lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime @@ -72,6 +72,8 @@ critic: shuffle: ${actor_rollout_ref.actor.shuffle} grad_clip: 1.0 cliprange_value: 0.5 + checkpoint: + contents: ["model", "optimizer", "extra"] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space trainer: balance_batch: True diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index 8686a0d497..6e530d32ce 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -12,26 +12,40 @@ class SampleStrategy(ABC): - def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs) -> None: self.pad_token_id = buffer_config.pad_token_id self.trainer_type = trainer_type @abstractmethod def sample(self, step: int) -> Tuple[Any, Dict, List]: - """Sample experiences from buffer. + """Sample data from buffer. Args: step (`int`): The step number of current step. Returns: - `Any`: The sampled experiences. + `Any`: The sampled data. `Dict`: Metrics for logging. - `List`: Representative experiences for logging. + `List`: Representative data for logging. + """ + + # Experimental API + @abstractmethod + def warmup_state(self, step: int) -> Tuple[bool, bool]: + """Check the warmup state of the current step. + + Args: + step (`int`): The step number of current step. + + Returns: + `bool`: Current step is in warmup or not. + `bool`: Warmup is finished on this step or not. 
""" @classmethod + @abstractmethod def default_args(cls) -> dict: - return {} + """Get the default arguments of the sample strategy.""" @SAMPLE_STRATEGY.register_module("warmup") @@ -70,6 +84,13 @@ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: else: raise NotImplementedError(f"backend {self.trainer_type} is not supported") + def warmup_state(self, step: int) -> Tuple[bool, bool]: + return step <= self.sft_warmup_steps, step == self.sft_warmup_steps + + @classmethod + def default_args(cls) -> dict: + return {} + @SAMPLE_STRATEGY.register_module("default") class DefaultSampleStrategy(SampleStrategy): @@ -93,6 +114,13 @@ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: else: raise NotImplementedError(f"backend {self.trainer_type} is not supported") + def warmup_state(self, step: int) -> Tuple[bool, bool]: + return False, False + + @classmethod + def default_args(cls) -> dict: + return {} + @SAMPLE_STRATEGY.register_module("dpo") class DPOSampleStrategy(WarmupSampleStrategy): diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index cf4a7882aa..348c9262a0 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -2,6 +2,7 @@ import argparse import os import sys +import traceback from pathlib import Path from pprint import pprint @@ -18,44 +19,41 @@ def bench(config: Config) -> None: """Evaluate model.""" - explorer = Explorer.remote(config) + explorer = ray.remote(Explorer).options(name="explorer").remote(config) try: ray.get(explorer.prepare.remote()) ray.get(explorer.benchmark.remote()) logger.info("Benchmark finished.") ray.get(explorer.shutdown.remote()) - except Exception as e: - logger.error(f"Benchmark failed: {e}") - raise e + except Exception: + error_msg = traceback.format_exc() + logger.error(f"Benchmark failed:\n{error_msg}") def explore(config: Config) -> None: """Run explorer.""" - explorer = Explorer.remote(config) try: + explorer = ray.remote(Explorer).options(name="explorer").remote(config) ray.get(explorer.prepare.remote()) ray.get(explorer.sync_weight.remote()) ray.get(explorer.explore.remote()) - logger.info("Explore finished.") ray.get(explorer.shutdown.remote()) - except Exception as e: - logger.error(f"Explore failed: {e}") - raise e + except Exception: + error_msg = traceback.format_exc() + logger.error(f"Explorer failed:\n{error_msg}") def train(config: Config) -> None: """Run trainer.""" - - trainer = Trainer.remote(config) - ray.get(trainer.prepare.remote()) - try: + trainer = ray.remote(Trainer).options(name="trainer").remote(config) + ray.get(trainer.prepare.remote()) + ray.get(trainer.sync_weight.remote()) ray.get(trainer.train.remote()) - logger.info("Train finished.") ray.get(trainer.shutdown.remote()) - except Exception as e: - logger.error(f"Train failed {e}.") - raise e + except Exception: + error_msg = traceback.format_exc() + logger.error(f"Trainer failed:\n{error_msg}") def both(config: Config) -> None: @@ -68,54 +66,30 @@ def both(config: Config) -> None: the latest step. The specific number of experiences may vary for different algorithms and tasks. 
""" - explorer = Explorer.remote(config) - trainer = Trainer.remote(config) + explorer = ray.remote(Explorer).options(name="explorer").remote(config) + trainer = ray.remote(Trainer).options(name="trainer").remote(config) ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()]) - logger.info("Setup explorer and trainer finished.") ray.get( [ explorer.prepare.remote(), trainer.prepare.remote(), ] ) - # sync weight before training start - ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()]) - - while True: - try: - ref_explore = explorer.explore_one_period.remote() - ref_train = trainer.train_one_period.remote() - explore_continue, explore_step_num = ray.get(ref_explore) - train_continue, train_step_num = ray.get(ref_train) - if not explore_continue: - # If explore finished, the trainer may not have enough experiences to continue, - # which will cause the trainer be blocked. So we stop the training process - # immediately. - # TODO: use a more elegant way to stop the training process. - logger.info("Explorer finished, stopping...") - break - if not train_continue: - logger.info("Trainer finished, stopping...") - break - ray.get([explorer.sync_weight.remote(), trainer.sync_weight.remote()]) - logger.info("Model weight synchronized.") - except Exception as e: - logger.error(e) - logger.error("Training stopped due to exception.") - raise e - if explore_step_num % config.explorer.eval_interval == 0: - try: - ray.get(explorer.eval.remote()) - logger.info("Evaluation finished.") - except Exception as e: - logger.error(e) - logger.error("Evaluation failed.") - raise e - ray.get(explorer.flush_log.remote(step=explore_step_num)) - ray.get(trainer.flush_log.remote(step=train_step_num)) - - ray.get(explorer.shutdown.remote()) - ray.get(trainer.shutdown.remote()) + ray.get( + [ + explorer.sync_weight.remote(), + trainer.sync_weight.remote(), + ] + ) + _, _ = ray.wait( + [ + explorer.explore.remote(), + trainer.train.remote(), + ], + num_returns=1, + ) + explorer.shutdown.remote(), + trainer.shutdown.remote(), def activate_data_module(data_workflow_url: str, config_path: str): diff --git a/trinity/common/config.py b/trinity/common/config.py index 9c45627d32..1409fa33f3 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -319,8 +319,10 @@ class SynchronizerConfig: sync_method: SyncMethod = SyncMethod.NCCL # sync weights every `sync_interval` steps sync_interval: int = 1 + # allow explorer to run `sync_offset` steps before sync + sync_offset: int = 0 # waiting for `sync_timeout` seconds before timeout in `nccl` method - sync_timeout: int = 1200 + sync_timeout: int = 1800 # wait for the lastest checkpoint to be ready # TODO: to be used wait_for_checkpoint: bool = False diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index a8751e7240..5cc770e64f 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -156,7 +156,6 @@ def get_verl_checkpoint_dir(checkpoint_path: str, step_num: Optional[int] = None iteration = f.read().strip() return os.path.join(checkpoint_path, f"global_step_{iteration}") else: - logger.error(f"No iteration file found in {checkpoint_path}") raise FileNotFoundError(f"No iteration file found in {checkpoint_path}") else: # load specific iteration checkpoint diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py index 177e0e1a81..8a8a089afa 100644 --- a/trinity/common/models/vllm_async_model.py +++ b/trinity/common/models/vllm_async_model.py 
@@ -267,10 +267,9 @@ async def _collective_rpc( async def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool: """Sync model weights to vLLM.""" - if self.state_dict_meta is None: - self.state_dict_meta = update_weight_args_list - for args in self.state_dict_meta: - await self._collective_rpc("update_weight", args=args) + if update_weight_args_list is not None: + await self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,)) + await self._collective_rpc("update_weight") self.logger.info("Sync model weights to vLLM successfully.") self.ckp_version += 1 return True @@ -287,7 +286,6 @@ async def init_process_group( update_with_checkpoint: bool = True, state_dict_meta: dict = None, ): - self.state_dict_meta = state_dict_meta return await self._collective_rpc( "init_process_group", args=( @@ -299,12 +297,10 @@ async def init_process_group( backend, timeout, update_with_checkpoint, + state_dict_meta, ), ) - async def update_weight(self, name, dtype, shape, empty_cache=False): - return await self._collective_rpc("update_weight", args=(name, dtype, shape, empty_cache)) - async def run_api_server(self): """Run the OpenAI API server in a Ray actor. diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index c999a61bfa..878fe0bd9c 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -100,7 +100,6 @@ def init_process_group( update_with_checkpoint: bool = True, state_dict_meta: dict = None, ): - self.state_dict_meta = state_dict_meta return self.llm.collective_rpc( "init_process_group", args=( @@ -112,12 +111,10 @@ def init_process_group( backend, timeout, update_with_checkpoint, + state_dict_meta, ), ) - def update_weight(self, name, dtype, shape, empty_cache=False): - return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache)) - def reset_prefix_cache(self): self.llm.llm_engine.reset_prefix_cache() @@ -279,11 +276,9 @@ def has_api_server(self) -> bool: def sync_model(self, update_weight_args_list: Optional[List[Tuple]] = None) -> bool: """Sync model weights to vLLM.""" - if self.state_dict_meta is None: - self.state_dict_meta = update_weight_args_list - with self.lock: - for args in self.state_dict_meta: - self.llm.collective_rpc("update_weight", args=args) + if update_weight_args_list is not None: + self._collective_rpc("set_state_dict_meta", args=(update_weight_args_list,)) + self._collective_rpc("update_weight") self.logger.info("Sync model weights to vLLM successfully.") self.ckp_version += 1 return True diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 4293811ab7..4d5d3cf376 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -21,22 +21,21 @@ def init_process_group( backend: str = "nccl", timeout: int = 1200, update_with_checkpoint: bool = True, + state_dict_meta: list = None, ): """Init torch process group for model weights update""" assert torch.distributed.is_initialized(), "default torch process group must be initialized" assert group_name != "", "group name must not be empty" + self.set_state_dict_meta(state_dict_meta) self._update_with_checkpoint = update_with_checkpoint - if self._update_with_checkpoint: - logger.info( - f"init_process_group (checkpoint): address={master_address}:{master_port}, rank={torch.distributed.get_rank()}, rank_offset={rank_offset}, world_size={world_size}" - ) - self._weight_update_rank = torch.distributed.get_rank() + 
rank_offset - else: - logger.info( - f"init_process_group (nccl): rank={torch.distributed.get_rank()}, rank_offset={rank_offset}, world_size={world_size}" - ) - self._weight_update_rank = torch.distributed.get_rank() + rank_offset - + self._weight_update_rank = torch.distributed.get_rank() + rank_offset + logger.info( + f"vLLM starting init_process_group ({'checkpoint' if self._update_with_checkpoint else 'nccl'}):\n" + f" > address={master_address}:{master_port}\n" + f" > rank={torch.distributed.get_rank()}\n" + f" > rank_offset={rank_offset}\n" + f" > world_size={world_size}" + ) if is_ipv6_address(master_address): # using tcp://ipv6:port will lead to ValueError init_method = f"tcp://[{master_address}]:{master_port}" @@ -51,24 +50,28 @@ def init_process_group( rank=self._weight_update_rank, group_name=group_name, ) - logger.info( - f"init_process_group: master_address={master_address}, master_port={master_port}, " - f"rank={self._weight_update_rank}, world_size={world_size}, group_name={group_name}" - ) + logger.info("vLLM init_process_group finished.") self._explorer_actor = None - def update_weight(self, name: str, dtype_str: str, shape: tuple, empty_cache=False): - """Broadcast weight to all vllm workers from source rank 0 (actor model)""" - if self._weight_update_rank == 0: - if self._explorer_actor is None: - self._explorer_actor = ray.get_actor(name="explorer") - weight = ray.get(self._explorer_actor.get_weight.remote(name)) - weight = weight.to(self.device) - else: - dtype = getattr(torch, dtype_str.split(".")[-1]) - weight = torch.empty(shape, dtype=dtype, device=self.device) - torch.distributed.broadcast(weight, 0, group=self._model_update_group) - weight = weight.type(self.model_config.dtype) + def set_state_dict_meta(self, state_dict_meta): + self._state_dict_meta = state_dict_meta - self.model_runner.model.load_weights(weights=[(name, weight)]) - del weight + def update_weight(self): + """Broadcast weight to all vllm workers from source rank 0 (actor model)""" + assert self._state_dict_meta is not None + if self._explorer_actor is None: + self._explorer_actor = ray.get_actor(name="explorer") + for name, dtype_str, shape in self._state_dict_meta: + if self._weight_update_rank == 0: + weight = ray.get(self._explorer_actor.get_weight.remote(name)) + weight = weight.to(self.device) + else: + dtype = getattr(torch, dtype_str.split(".")[-1]) + weight = torch.empty(shape, dtype=dtype, device=self.device) + torch.distributed.broadcast(weight, 0, group=self._model_update_group) + weight = weight.type(self.model_config.dtype) + self.model_runner.model.load_weights(weights=[(name, weight)]) + del weight + torch.distributed.barrier() + torch.cuda.synchronize() + torch.cuda.empty_cache() diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 26ee0b53c2..36527f1dcd 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- """The explorer module""" +from __future__ import annotations + +import asyncio import os import time from collections import defaultdict from typing import List, Optional, Tuple -import ray import torch from trinity.algorithm.algorithm_manager import AlgorithmManager @@ -24,7 +26,6 @@ from trinity.utils.monitor import MONITOR -@ray.remote(name="explorer", concurrency_groups={"get_weight": 32, "setup_weight_sync_group": 1}) class Explorer: """Responsible for exploring the taskset.""" @@ -32,7 +33,7 @@ def __init__(self, config: Config): self.logger = get_logger(__name__) self.cache = 
CacheManager(config) explorer_meta = self.cache.load_explorer() - self.step_num = explorer_meta.get("latest_iteration", 0) + self.explore_step_num = explorer_meta.get("latest_iteration", 0) self.config = config self.algorithm_manager = AlgorithmManager(config) self.models, self.auxiliary_models = create_inference_models(config) @@ -70,8 +71,7 @@ def __init__(self, config: Config): self.state_dict_meta = [] self.logger.info("Finished initializing Explorer.") - @ray.method(concurrency_group="setup_weight_sync_group") - def setup_weight_sync_group( + async def setup_weight_sync_group( self, master_address: str, master_port: int, state_dict_meta: List = None ): # In checkpoint mode, we use explorer to store the model weights which has no rank @@ -100,7 +100,7 @@ def setup_weight_sync_group( ) for i, model in enumerate(self.models) ] - ray.get(refs) + await asyncio.gather(*refs) def _init_runner_pool(self) -> RunnerPool: if self.config.explorer.rollout_model.engine_type != "vllm_async": @@ -117,7 +117,7 @@ def _init_runner_pool(self) -> RunnerPool: self.logger.info(f"Setup {self.config.explorer.runner_num} WorkflowRunners") return RunnerPool(self.config, self.models, self.auxiliary_models) - def _update_model_weight(self, state_dict: dict) -> None: + async def _update_model_weight(self, state_dict: dict) -> None: # TODO: update model weight self.state_dict = state_dict if self.state_dict_meta is None: @@ -127,10 +127,12 @@ def _update_model_weight(self, state_dict: dict) -> None: self.state_dict_meta = update_weight_args_list else: update_weight_args_list = None - ray.get([model.sync_model.remote(update_weight_args_list) for model in self.models]) + await asyncio.gather( + *[model.sync_model.remote(update_weight_args_list) for model in self.models] + ) self.state_dict.clear() - def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None: + async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None: # TODO: support more checkpoint types try: checkpoint_dir = get_checkpoint_dir_with_step_num( @@ -141,104 +143,62 @@ def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> None: if checkpoint_dir == self.old_checkpoint: return model_weights = load_state_dict(os.path.join(checkpoint_dir, "actor")) - self._update_model_weight(model_weights) + await self._update_model_weight(model_weights) self.old_checkpoint = checkpoint_dir except Exception as e: - self.logger.error(f"Error when loading state_dict: {e}") + self.logger.warning(f"Fail to load checkpoint: {e}") - def _nccl_weights_update(self): + async def _nccl_weights_update(self): assert self.state_dict_meta is not None - ray.get([model.sync_model.remote() for model in self.models]) + await asyncio.gather(*[model.sync_model.remote() for model in self.models]) - def prepare(self) -> None: + async def prepare(self) -> None: """Preparation before running.""" if self.use_checkpoint_weights_update: - master_address, master_port = ray.get(self.models[0].get_available_address.remote()) - self.setup_weight_sync_group(master_address, master_port) + master_address, master_port = await self.models[0].get_available_address.remote() + await self.setup_weight_sync_group(master_address, master_port) - @ray.method(concurrency_group="get_weight") - def get_weight(self, name: str) -> torch.Tensor: + async def get_weight(self, name: str) -> torch.Tensor: """Get the weight of the loaded model (For checkpoint weights update).""" return self.state_dict[name] - def explore(self) -> None: - """Explore the entire 
dataset.""" + async def explore(self) -> None: while True: - explore_status, explore_iter = self.explore_one_period() - if not explore_status: + try: + explore_contionue = self.explore_step() + if self.need_sync(): + self.wait_for_workflow_done() + await self.sync_weight() + if self.explore_step_num % self.config.explorer.eval_interval == 0: + self.wait_for_workflow_done() + self.eval() + if not explore_contionue: + break + except Exception as e: + self.logger.error(f"Error in Explorer: {e}") break - self.sync_weight() - if explore_iter % self.config.explorer.eval_interval == 0: - self.eval() - self.logger.info("Evaluation finished.") - self.logger.info("Explorer finished.") + self.logger.info("--------------------\n> Explorer finished.\n--------------------\n") - def explore_one_period(self) -> Tuple[bool, int]: - """Explore for one period. - - Different from `explore()` which consumes all tasks in the task set, - `explore_one_period()` only consume `sync_interval * batch_size` - number of tasks. - Returns: - explore_status: whether there are more tasks to explore. - explore_step_num: the number of explore steps - """ - # skip for sft - algo_config = self.algorithm_manager.get_current_algorithm_config(self.step_num + 1) + def explore_step(self) -> bool: + self.explore_step_num += 1 + algo_config = self.algorithm_manager.get_current_algorithm_config(self.explore_step_num) + # skip warmup if algo_config.algorithm_type == "sft": - for _ in range(self.config.synchronizer.sync_interval): - self.step_num += 1 - if self.algorithm_manager.need_save(self.step_num): - break - return True, self.step_num - - st = time.time() - all_metrics = defaultdict(list) - - # submit tasks of this step + return True try: - tasks = [] - for _ in range(self.config.synchronizer.sync_interval): - tasks.extend(self.taskset.read()) - self.runner_pool.run_tasks(tasks) # type: ignore + tasks = self.taskset.read() except StopIteration: - self.experience_buffer.finish() - self.logger.warning("No more tasks in the task set. Stop exploring.") - return False, self.step_num - - # wait for all tasks of this step to finish - while self.runner_pool.has_next(): - status_list = self.runner_pool.get_next_unorder() - if not isinstance(status_list, list): - status_list = [status_list] - for status in status_list: - if not status.ok: - self.logger.error(f"Error when running task: {status.message}") - try: - # submit another task to replace the failed task - self.runner_pool.run_tasks(self.taskset.read()) - except StopIteration: - self.logger.warning("No more tasks in the task set. Stop exploring.") - return False, self.step_num - else: - for metric_name, metric_value in status.metric.items(): - all_metrics[metric_name].append(metric_value) - - # calculate metrics - log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore - log_metrics["rollout/step_time"] = time.time() - st - self.step_num += self.config.synchronizer.sync_interval - self.monitor.log(log_metrics, step=self.step_num) - - # save explore checkpoint - self.cache.save_explorer( - current_step=self.step_num, - current_task_index=self.step_num * self.config.buffer.batch_size, - # TODO: remove current_task_index - ) + self.logger.warning("No more tasks to explore. 
Stop exploring.") + return False + self.runner_pool.run_tasks(tasks) + return True - self.logger.info(f"Explore step {self.step_num} finished.") - return True, self.step_num + def need_sync(self) -> bool: + if self.explore_step_num <= self.config.synchronizer.sync_offset: + return False + return ( + self.explore_step_num - self.config.synchronizer.sync_offset + ) % self.config.synchronizer.sync_interval == 0 def eval(self) -> Tuple[bool, int]: """Evaluation on all evaluation data samples.""" @@ -247,7 +207,7 @@ def eval(self) -> Tuple[bool, int]: eval_tasksets.append(get_buffer_reader(eval_taskset_config, self.config.buffer)) if len(eval_tasksets) == 0: self.logger.warning("No evaluation data samples. Skip evaluation.") - return True, self.step_num + return True, self.explore_step_num self.logger.info("Evaluation started.") all_st = time.time() log_metrics = {} @@ -279,14 +239,15 @@ def wait(): log_metrics.update(metrics) log_metrics[f"eval/{eval_taskset.name}/time"] = time.time() - st log_metrics["eval/total_time"] = time.time() - all_st - self.monitor.log(log_metrics, step=self.step_num) # type: ignore - return True, self.step_num + self.monitor.log(log_metrics, step=self.explore_step_num) # type: ignore + self.logger.info("Evaluation finished.") + return True, self.explore_step_num - def benchmark(self) -> bool: + async def benchmark(self) -> bool: """Benchmark the model checkpoints.""" # benchmark on the latest checkpoint if self.config.explorer.eval_on_latest_checkpoint: - self._checkpoint_weights_update() + await self._checkpoint_weights_update() self.eval() return True @@ -300,18 +261,47 @@ def benchmark(self) -> bool: ] ) for step_num in all_ckp_steps: - self.step_num = step_num - self._checkpoint_weights_update(step_num=step_num) + self.explore_step_num = step_num + await self._checkpoint_weights_update(step_num=step_num) self.eval() return True - def sync_weight(self) -> None: + def wait_for_workflow_done(self) -> None: + """Wait for workflow to finish.""" + all_metrics = defaultdict(list) + # wait for all tasks of this step to finish + while self.runner_pool.has_next(): + status_list = self.runner_pool.get_next_unorder() + if not isinstance(status_list, list): + status_list = [status_list] + for status in status_list: + if not status.ok: + self.logger.error(f"Error when running task: {status.message}") + # submit another task to replace the failed task + self.runner_pool.run_tasks(self.taskset.read(batch_size=1)) + else: + for metric_name, metric_value in status.metric.items(): + all_metrics[metric_name].append(metric_value) + # calculate metrics + log_metrics = self.monitor.calculate_metrics(all_metrics, prefix="rollout") # type: ignore + self.monitor.log(log_metrics, step=self.explore_step_num) + + self.logger.info(f"Explore step {self.explore_step_num} finished.") + + async def sync_weight(self) -> None: """Synchronize model weights.""" # call this method before training start to load the latest model weights + self.logger.info(f"Explorer synchronizing weights at step {self.explore_step_num}.") if self.use_checkpoint_weights_update: - self._checkpoint_weights_update() + await self._checkpoint_weights_update() else: # nccl weights update - self._nccl_weights_update() + await self._nccl_weights_update() + # save explore checkpoint + self.cache.save_explorer( + current_step=self.explore_step_num, + current_task_index=self.explore_step_num * self.config.buffer.batch_size, + ) + self.logger.info(f"Explorer synchronizing at step {self.explore_step_num} finished") def 
flush_log(self, step: int) -> None: """Flush the log of the current step.""" diff --git a/trinity/manager/manager.py b/trinity/manager/manager.py index 3c148cbe12..baaf1242c3 100644 --- a/trinity/manager/manager.py +++ b/trinity/manager/manager.py @@ -47,7 +47,13 @@ def load_explorer(self) -> dict: try: with open(self.explorer_meta_path, "r", encoding="utf-8") as f: explorer_meta = json.load(f) - logger.info(f"Find existing explorer meta: {explorer_meta}") + logger.info( + "----------------------------------\n" + "Found existing explorer checkpoint:\n" + f" > {explorer_meta}\n" + "Continue exploring from this point.\n" + "----------------------------------" + ) return explorer_meta except Exception as e: logger.error(f"Failed to load explore meta file: {e}") @@ -62,7 +68,13 @@ def load_trainer(self) -> dict: try: with open(self.trainer_meta_path, "r", encoding="utf-8") as f: trainer_meta = json.load(f) - logger.info(f"Find existing trainer meta: {trainer_meta}") + logger.info( + "----------------------------------\n" + "Found existing trainer checkpoint:\n" + f" > {trainer_meta}\n" + "Continue training from this point.\n" + "----------------------------------" + ) return trainer_meta except Exception as e: logger.warning(f"Failed to load trainer meta file: {e}") diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 2920604fbb..ff43dfbb79 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -1,31 +1,23 @@ # -*- coding: utf-8 -*- """ Trainer Class -This file is modified from verl.trainer.main_ppo.py -And is a reproduction code of Jiayi-Pan/TinyZero. - -Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. """ +from __future__ import annotations + import os from abc import ABC, abstractmethod -from typing import Tuple - -import ray -from trinity.algorithm.algorithm_manager import AlgorithmManager from trinity.common.config import Config from trinity.common.constants import SyncMethod from trinity.utils.log import get_logger -@ray.remote(name="trainer") class Trainer: """Consume the experience and train the model.""" def __init__(self, config: Config) -> None: self.config = config self.logger = get_logger(__name__) - self.algorithm_manager = AlgorithmManager(config) self.engine = get_trainer_wrapper(config) def prepare(self) -> None: @@ -35,23 +27,18 @@ def prepare(self) -> None: def train(self): """Train the model.""" while True: - train_status, _ = self.train_step() - if not train_status: + try: + train_continue = self.train_step() + if self.need_sync(): + self.sync_weight() + if not train_continue: + break + except Exception as e: + self.logger.error(f"Error in Trainer: {e}") break + self.logger.info("--------------------\n> Trainer finished.\n--------------------\n") - def train_one_period(self) -> Tuple[bool, int]: - """Train for one period. Each period contains `sync_interval` steps. - Returns: - train_status: Whether to continue training. - train_step_num: The number of training steps""" - for _ in range(self.config.synchronizer.sync_interval): - train_status, train_step_num = self.train_step() - if not train_status: - return False, train_step_num - self.logger.info(f"Train step {train_step_num} finished.") - return True, train_step_num - - def train_step(self) -> Tuple[bool, int]: + def train_step(self) -> bool: """Train one step. 
Returns: @@ -59,9 +46,14 @@ def train_step(self) -> Tuple[bool, int]: """ return self.engine.train_step() + def need_sync(self) -> bool: + """Whether to sync the model weight.""" + return self.engine.train_step_num % self.config.synchronizer.sync_interval == 0 + def sync_weight(self) -> None: """Sync the model weight.""" if self.config.synchronizer.sync_method == SyncMethod.NCCL: + self.logger.info(f"Trainer synchronizing weights at step {self.engine.train_step_num}.") self.engine.sync_weight() def flush_log(self, step: int) -> None: @@ -90,7 +82,7 @@ def train_step_num(self) -> int: """Get the current training step number.""" @abstractmethod - def train_step(self) -> Tuple[bool, int]: + def train_step(self) -> bool: """Training.""" @abstractmethod diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index 2a8308ea62..69e99a153d 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -606,6 +606,8 @@ def sync_weight(self): continue torch.distributed.broadcast(param, 0, group=self._model_update_group) param = None + torch.distributed.barrier() + torch.cuda.synchronize() torch.cuda.empty_cache() @register(dispatch_mode=Dispatch.ONE_TO_ALL) diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index bc15a25446..4243e61d17 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -6,7 +6,7 @@ import os import sys from pprint import pprint -from typing import Dict, List, Tuple +from typing import Dict, List import pandas as pd import ray @@ -285,14 +285,14 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl # TODO: compute total training steps self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize - def train_step(self) -> Tuple[bool, int]: # noqa C901 + def train_step(self) -> bool: # noqa C901 metrics = {} try: batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1) prefix_metrics(sample_metrics, "sample", metrics) except StopIteration: print("No more data to train. Stop training.") - return False, self.global_steps + return False self.global_steps += 1 timing_raw = {} algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps) @@ -382,7 +382,7 @@ def train_step(self) -> Tuple[bool, int]: # noqa C901 ): with _timer("save_checkpoint", timing_raw): self._save_checkpoint() - return train_status, self.global_steps + return train_status def _log_single_experience( self, experiences: Experiences, idx: int, skip_special_tokens: bool