diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 74194ad57d..7a388390d8 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -10,12 +10,14 @@ global_config:
   total_epochs: 1
   batch_size: 96
   eval_interval: 1000
+  eval_on_latest_ckp: true
 ```

 - `mode`: The mode of the experiment, chosen from `both`, `train`, `explore` or `bench`. `both` means both trainer and explorer are launched; `train` means only trainer is launched; `explore` means only explorer is launched; `bench` conducts benchmark evaluation. Default is `both`.
 - `global_config.total_epochs`: The total number of epochs. It should be checked manually.
 - `global_config.batch_size`: The batch size used for training. It should be checked manually.
 - `global_config.eval_interval`: The interval steps between two evaluations. Default is `1000`.
+- `global_config.eval_on_latest_ckp`: Whether to evaluate only the latest checkpoint (`true`) or all checkpoints under the checkpoint path (`false`). Only effective in `bench` mode. Default is `true`.

 ## Monitor
diff --git a/pyproject.toml b/pyproject.toml
index 437aae189d..e678a4b6f9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ classifiers = [
 requires-python = ">=3.10"
 dependencies = [
     "verl==0.3.0.post1",
-    "ray[default]==2.43.0",
+    "ray[default]>=2.45.0",
     "vllm>=0.8.5",
     "tensordict==0.6.2",
     "wandb",
diff --git a/tests/tools.py b/tests/tools.py
index ff4488f857..be415fba75 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -36,7 +36,7 @@ def get_unittest_dataset_config(
     dataset_name: str = "countdown", split: str = "train"
 ) -> StorageConfig:
     """Countdown sample dataset for 8 steps"""
-    if dataset_name == "countdown":
+    if dataset_name == "countdown" or dataset_name == "copy_countdown":
         return StorageConfig(
             name=dataset_name,
             path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"),
@@ -86,10 +86,12 @@ def metric_exist(self, metric_name: str) -> bool:
         return metric_name in self._metrics

     def metric_max_step(self, metric_name: str) -> int:
+        return max(self.metric_steps(metric_name))
+
+    def metric_steps(self, metric_name: str) -> List[int]:
         if not self.metric_exist(metric_name):
             raise ValueError(f"Metric '{metric_name}' does not exist.")
-        steps = list(self._metrics[metric_name].keys())
-        return max(steps)
+        return list(self._metrics[metric_name].keys())

     def metric_list(self, metric_prefix: str) -> List[str]:
         return [name for name in self._metrics if name.startswith(metric_prefix)]
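The `TensorBoardParser` refactor above splits step collection out of `metric_max_step` into a new `metric_steps` helper, which the bench-mode test below relies on. A minimal self-contained sketch of the resulting behavior (the `MiniParser` class and the metric name/values are illustrative stand-ins, not project code):

```python
from typing import Dict, List


class MiniParser:
    """Illustrative stand-in for TensorBoardParser's step-tracking refactor."""

    def __init__(self, metrics: Dict[str, Dict[int, float]]) -> None:
        self._metrics = metrics

    def metric_steps(self, metric_name: str) -> List[int]:
        # All steps at which the metric was logged.
        if metric_name not in self._metrics:
            raise ValueError(f"Metric '{metric_name}' does not exist.")
        return list(self._metrics[metric_name].keys())

    def metric_max_step(self, metric_name: str) -> int:
        # Latest logged step, now expressed in terms of metric_steps().
        return max(self.metric_steps(metric_name))


parser = MiniParser({"eval/countdown/accuracy": {4: 0.25, 8: 0.31}})
assert parser.metric_steps("eval/countdown/accuracy") == [4, 8]
assert parser.metric_max_step("eval/countdown/accuracy") == 8
```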
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 35dada1074..62f4abf745 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -14,7 +14,7 @@
     get_template_config,
     get_unittest_dataset_config,
 )
-from trinity.cli.launcher import both
+from trinity.cli.launcher import bench, both
 from trinity.common.constants import MonitorType, SyncMethod

@@ -27,9 +27,11 @@ def setUp(self):
         self.config.model.model_path = get_model_path()
         self.config.explorer.engine_type = "vllm_async"
         self.config.explorer.repeat_times = 3
+        self.config.explorer.use_v1 = False
+
+        self.config.monitor.name = f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
         self.config.monitor.monitor_type = MonitorType.TENSORBOARD
         self.config.model.checkpoint_path = os.path.join(
-            get_checkpoint_path(), f"train-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+            get_checkpoint_path(), f"trainer-{datetime.now().strftime('%Y%m%d%H%M%S')}"
         )
         self.config.synchronizer.sync_interval = 2
         self.config.synchronizer.sync_method = SyncMethod.NCCL
@@ -42,15 +44,20 @@ def test_trainer(self):

 class TestTrainerCountdown(BaseTrainerCase):
     def test_trainer(self):
-        """Test the trainer."""
+        """Test both and bench modes."""
+        # test both mode
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         self.config.buffer.explorer_input.eval_tasksets.append(
             get_unittest_dataset_config("countdown", "test")
         )
+        self.config.buffer.explorer_input.eval_tasksets.append(
+            get_unittest_dataset_config("copy_countdown", "test")
+        )
+        self.config.trainer.save_interval = 4
         self.config.check_and_update()
-        self.config.trainer.trainer_config.trainer.save_freq = 8
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+        self.config.trainer.trainer_config.trainer.max_critic_ckpt_to_keep = 2
         both(self.config)
-        # check tensorboard
         parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
         rollout_metrics = parser.metric_list("rollout")
         self.assertTrue(len(rollout_metrics) > 0)
@@ -64,16 +71,41 @@ def test_trainer(self):
         response_metrics = parser.metric_list("response_length")
         self.assertTrue(len(response_metrics) > 0)
         self.assertEqual(parser.metric_max_step(response_metrics[0]), 8)
+        ray.shutdown(_exiting_interpreter=True)
         # check checkpoint
         from trinity.common.models.utils import get_checkpoint_dir_with_step_num

-        checkpoint_dir = get_checkpoint_dir_with_step_num(
+        checkpoint_step_4 = get_checkpoint_dir_with_step_num(
+            checkpoint_root_path=self.config.model.checkpoint_path,
+            trainer_type=self.config.trainer.trainer_type,
+            step_num=4,
+        )
+        checkpoint_step_8 = get_checkpoint_dir_with_step_num(
             checkpoint_root_path=self.config.model.checkpoint_path,
             trainer_type=self.config.trainer.trainer_type,
-            step_num=None,
+            step_num=8,
         )
-        self.assertTrue(os.path.exists(checkpoint_dir))
-        self.assertTrue(checkpoint_dir.endswith("step_8"))
+        self.assertTrue(os.path.exists(checkpoint_step_4))
+        self.assertTrue(os.path.exists(checkpoint_step_8))
+
+        ray.init(ignore_reinit_error=True)
+        # test bench mode
+        self.config.mode = "bench"
+        self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
+        self.config.global_config.eval_on_latest_ckp = False
+        self.config.check_and_update()
+        bench(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.job_dir, "tensorboard"))
+        countdown_metrics = parser.metric_list("eval/countdown")
+        copy_countdown_metrics = parser.metric_list("eval/copy_countdown")
+        self.assertTrue(len(countdown_metrics) > 0)
+        self.assertTrue(len(copy_countdown_metrics) > 0)
+        countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
+        countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
+        self.assertEqual(2, len(countdown_metric_steps))
+        self.assertEqual(2, len(countdown_copy_metric_steps))
+        self.assertTrue(4 in countdown_metric_steps)
+        self.assertTrue(8 in countdown_metric_steps)

     def tearDown(self):
         # remove dir only when the test passed
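In this test, checkpoints are saved every 4 steps over an 8-step run, so checkpoints land at `global_step_4` and `global_step_8`; when bench mode then runs with `eval_on_latest_ckp = False`, each eval metric is expected at exactly those two steps. A small illustrative helper (hypothetical, not part of the codebase) that reproduces this arithmetic:

```python
# Hypothetical helper: which steps get a checkpoint when saving every
# `save_interval` steps over `total_steps` training steps.
def expected_checkpoint_steps(total_steps: int, save_interval: int) -> list[int]:
    return [step for step in range(1, total_steps + 1) if step % save_interval == 0]


# Mirrors the test above: 8 training steps with save_interval=4 -> steps 4 and 8.
assert expected_checkpoint_steps(8, 4) == [4, 8]
```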
logger.info("Benchmark finished.") + ray.get(explorer.shutdown.remote()) except Exception as e: - logger.error(f"Evaluation failed: {e}") + logger.error(f"Benchmark failed: {e}") raise e @@ -35,6 +34,7 @@ def explore(config: Config) -> None: ray.get(explorer.sync_weight.remote()) ray.get(explorer.explore.remote()) logger.info("Explore finished.") + ray.get(explorer.shutdown.remote()) except Exception as e: logger.error(f"Explore failed: {e}") raise e @@ -60,6 +60,7 @@ def train(config: Config) -> None: try: ray.get(trainer.train.remote(algo_type)) logger.info("Train finished.") + ray.get(trainer.shutdown.remote()) except Exception as e: logger.error(f"Train failed {e}.") raise e @@ -133,6 +134,9 @@ def both(config: Config) -> None: ray.get(explorer.flush_log.remote(step=explore_step_num)) ray.get(trainer.flush_log.remote(step=train_step_num)) + ray.get(explorer.shutdown.remote()) + ray.get(trainer.shutdown.remote()) + def activate_data_module(data_workflow_url: str, config_path: str): """Check whether to activate data module and preprocess datasets.""" diff --git a/trinity/common/config.py b/trinity/common/config.py index 00bdb153a4..9c9ea72017 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -107,6 +107,7 @@ class GlobalConfig: total_epochs: int = 1 batch_size: int = 1 eval_interval: int = 100 + eval_on_latest_ckp: bool = True @dataclass @@ -299,7 +300,8 @@ def _check_interval(self) -> None: # check eval_interval if ( - self.trainer.algorithm_type != AlgorithmType.DPO + self.mode != "bench" + and self.trainer.algorithm_type != AlgorithmType.DPO and self.global_config.eval_interval % self.synchronizer.sync_interval != 0 ): self.global_config.eval_interval = ( @@ -311,12 +313,13 @@ def _check_interval(self) -> None: # check save_interval if ( - self.trainer.algorithm_type != AlgorithmType.DPO + self.mode != "bench" + and self.trainer.algorithm_type != AlgorithmType.DPO and self.synchronizer.sync_method == SyncMethod.CHECKPOINT ): if self.trainer.save_interval != self.synchronizer.sync_interval: logger.warning( - f"When `trainer.algorithm_type != DPO` and `synchronizer.sync_method == checkpoint`, " + f"When `trainer.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, " f"`trainer.save_interval` will be set to " f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`." 
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 00bdb153a4..9c9ea72017 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -107,6 +107,7 @@ class GlobalConfig:
     total_epochs: int = 1
     batch_size: int = 1
     eval_interval: int = 100
+    eval_on_latest_ckp: bool = True


 @dataclass
@@ -299,7 +300,8 @@ def _check_interval(self) -> None:
         # check eval_interval
         if (
-            self.trainer.algorithm_type != AlgorithmType.DPO
+            self.mode != "bench"
+            and self.trainer.algorithm_type != AlgorithmType.DPO
             and self.global_config.eval_interval % self.synchronizer.sync_interval != 0
         ):
             self.global_config.eval_interval = (
@@ -311,12 +313,13 @@ def _check_interval(self) -> None:
         # check save_interval
         if (
-            self.trainer.algorithm_type != AlgorithmType.DPO
+            self.mode != "bench"
+            and self.trainer.algorithm_type != AlgorithmType.DPO
             and self.synchronizer.sync_method == SyncMethod.CHECKPOINT
         ):
             if self.trainer.save_interval != self.synchronizer.sync_interval:
                 logger.warning(
-                    f"When `trainer.algorithm_type != DPO` and `synchronizer.sync_method == checkpoint`, "
+                    f"When `trainer.algorithm_type` != `DPO` and `synchronizer.sync_method` == `checkpoint`, "
                     f"`trainer.save_interval` will be set to "
                     f"`synchronizer.sync_interval = {self.synchronizer.sync_interval}`."
                 )
@@ -356,7 +359,7 @@ def _check_buffer(self) -> None:  # noqa: C901
                 logger.info(
                     f"Auto set `buffer.trainer_input.experience_buffer` to {self.buffer.trainer_input.experience_buffer}"
                 )
-        else:  # TODO: to be check
+        elif self.mode == "train":  # TODO: to be checked
             if self.trainer.algorithm_type.is_dpo():
                 if (
                     self.buffer.trainer_input.experience_buffer is None
@@ -365,7 +368,8 @@ def _check_buffer(self) -> None:  # noqa: C901
                     raise ValueError(
                         "`buffer.trainer_input.experience_buffer.path` is required when `trainer.algorithm_type == AlgorithmType.DPO`"
                     )
-        self.buffer.trainer_input.experience_buffer.algorithm_type = self.trainer.algorithm_type
+        if self.mode in ["both", "train"]:
+            self.buffer.trainer_input.experience_buffer.algorithm_type = self.trainer.algorithm_type

         # set buffer.explorer_output
         if self.buffer.explorer_output is None:
@@ -418,7 +422,7 @@ def check_and_update(self) -> None:  # noqa: C901
             )
             self.synchronizer.backend = self.explorer.backend
         if self.mode == "bench" and self.synchronizer.sync_method != SyncMethod.CHECKPOINT:
-            self.synchronizer.sync_method = "checkpoint"
+            self.synchronizer.sync_method = SyncMethod.CHECKPOINT
             logger.warning(
                 "Bench mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
             )
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 6bed934f33..54ad581562 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -34,10 +34,11 @@ def __init__(self, config: Config):
         self.step_num = explorer_meta.get("latest_iteration", 0)
         self.config = config
         self.models = create_rollout_models(config)
-        self.experience_buffer = get_buffer_writer(
-            self.config.buffer.explorer_output,  # type: ignore
-            self.config.buffer,
-        )
+        if self.config.mode != "bench":
+            self.experience_buffer = get_buffer_writer(
+                self.config.buffer.explorer_output,  # type: ignore
+                self.config.buffer,
+            )
         self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0)
         self.taskset = get_buffer_reader(
             self.config.buffer.explorer_input.taskset, self.config.buffer
@@ -261,6 +262,29 @@ def wait():
             self.monitor.log(log_metrics, step=self.step_num)  # type: ignore
         return True, self.step_num

+    def benchmark(self) -> bool:
+        """Benchmark the model checkpoints."""
+        # benchmark on the latest checkpoint
+        if self.config.global_config.eval_on_latest_ckp:
+            self._checkpoint_weights_update()
+            self.eval()
+            return True
+
+        # benchmark on all checkpoints
+        all_ckp_steps = sorted(
+            [
+                int(ckp.split("global_step_")[-1])
+                for ckp in os.listdir(self.config.model.checkpoint_path)
+                if os.path.isdir(os.path.join(self.config.model.checkpoint_path, ckp))
+                and ckp.startswith("global_step_")
+            ]
+        )
+        for step_num in all_ckp_steps:
+            self.step_num = step_num
+            self._checkpoint_weights_update(step_num=step_num)
+            self.eval()
+        return True
+
     def sync_weight(self) -> None:
         """Synchronize model weights."""
         # call this method before training start to load the latest model weights
@@ -272,3 +296,6 @@ def sync_weight(self) -> None:
     def flush_log(self, step: int) -> None:
         """Flush the log of the current step."""
         self.monitor.log({}, step=step, commit=True)
+
+    def shutdown(self) -> None:
+        self.monitor.close()
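`Explorer.benchmark()` discovers which checkpoints to evaluate by scanning the checkpoint root for `global_step_*` directories and sorting their step numbers. A self-contained sketch of that discovery logic (the helper name and temporary directory layout are illustrative; the real code runs inside the `Explorer` actor):

```python
import os
import tempfile


def list_checkpoint_steps(checkpoint_root: str) -> list[int]:
    """Return the sorted step numbers of global_step_* checkpoint directories."""
    return sorted(
        int(name.split("global_step_")[-1])
        for name in os.listdir(checkpoint_root)
        if os.path.isdir(os.path.join(checkpoint_root, name))
        and name.startswith("global_step_")
    )


with tempfile.TemporaryDirectory() as root:
    for step in (8, 4):
        os.makedirs(os.path.join(root, f"global_step_{step}"))
    # A stray file with a similar name is ignored because it is not a directory.
    open(os.path.join(root, "global_step_notes.txt"), "w").close()
    assert list_checkpoint_steps(root) == [4, 8]
```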
""" +import os from abc import ABC, abstractmethod from typing import Tuple @@ -59,7 +60,7 @@ def train_one_period(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tupl train_status, train_step_num = self.train_step(algo_type) if not train_status: return False, train_step_num - self.logger.info(f"Trainer steps {train_step_num} finished.") + self.logger.info(f"Train step {train_step_num} finished.") return True, train_step_num def train_step(self, algo_type: AlgorithmType = AlgorithmType.PPO) -> Tuple[bool, int]: @@ -119,6 +120,14 @@ def flush_log(self, step: int) -> None: """Flush the log of the current step.""" self.engine.logger.log({}, step=step, commit=True) + def shutdown(self) -> None: + # if checkpoint not saved, save the last checkpoint + step_num = self.engine.global_steps - 1 + path = os.path.join(self.config.model.checkpoint_path, f"global_step_{step_num}") + if not os.path.isdir(path) or len(os.listdir(path)) == 0: + self.engine.save_checkpoint() + self.engine.logger.close() + class TrainEngineWrapper(ABC): """A wrapper class to wrap various training engines.""" diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 23b96a3c11..f4c0db6372 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -54,6 +54,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None: """Log metrics.""" self.logger.log(data, step=step, commit=commit) + def close(self) -> None: + self.logger.close() + class TensorboardLogger: def __init__(self, project: str, name: str, role: str, config: Any = None) -> None: @@ -70,6 +73,9 @@ def log(self, data: dict, step: int, commit: bool = False) -> None: for key in data: self.logger.add_scalar(key, data[key], step) + def close(self) -> None: + self.logger.close() + def __del__(self) -> None: self.logger.close() @@ -95,5 +101,8 @@ def log(self, data: dict, step: int, commit: bool = False) -> None: self.logger.log(data, step=step, commit=commit) self.console_logger.info(f"Step {step}: {data}") + def close(self) -> None: + self.logger.finish() + def __del__(self) -> None: self.logger.finish()