diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index f34f1dffd4..2d3cca506c 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -11,6 +11,7 @@ project: Trinity-RFT
 name: example
 mode: both
 checkpoint_root_dir: /PATH/TO/CHECKPOINT
+continue_from_checkpoint: true
 
 algorithm: # Algorithm-related parameters
@@ -68,6 +69,7 @@ checkpoint_root_dir: /PATH/TO/CHECKPOINT
 - `explore`: Only launches the explorer.
 - `bench`: Used for benchmarking.
 - `checkpoint_root_dir`: Root directory where all checkpoints and logs will be saved. Checkpoints for this experiment will be stored in `<checkpoint_root_dir>/<project>/<name>/`.
+- `continue_from_checkpoint`: If set to `true`, the experiment will continue from the latest checkpoint in the checkpoint path (if any); otherwise, it will rename the current experiment to `<name>_<timestamp>` and start a new experiment.
 - `ray_namespace`: Namespace for the modules launched in the current experiment. If not specified, it will be set to `<project>/<name>`.
 
 ---
diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index 2c5ef463b3..db7190856f 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -1,11 +1,15 @@
 # -*- coding: utf-8 -*-
 """Test cases for Config modules."""
+import datetime
 import os
+import shutil
 import unittest
 
 from tests.tools import get_template_config
 
 from trinity.common.config import InferenceModelConfig, load_config
 
+CHECKPOINT_ROOT_DIR = os.path.join(os.path.dirname(__file__), "temp_checkpoint_dir")
+
 class TestConfig(unittest.TestCase):
     def test_load_default_config(self):
@@ -54,3 +58,26 @@ def test_all_examples_are_valid(self):
             except Exception as e:
                 print(f"Error loading config {config_path}: {e}")
                 raise e
+
+    def test_continue_from_checkpoint_is_valid(self):
+        config = get_template_config()
+        config.name = "test"
+        config.project = "unittest"
+        config.checkpoint_root_dir = CHECKPOINT_ROOT_DIR
+
+        dir_path = os.path.join(config.checkpoint_root_dir, config.project, config.name)
+        os.makedirs(os.path.join(dir_path, "global_step_1"))
+
+        config.continue_from_checkpoint = True
+        config.check_and_update()
+        self.assertEqual(config.name, "test")
+
+        config.continue_from_checkpoint = False
+        config.check_and_update()
+        self.assertTrue(config.name.startswith("test_"))
+        timestamp = config.name.split("_")[-1]
+        self.assertTrue(datetime.datetime.strptime(timestamp, "%Y%m%d%H%M%S"))
+
+    def tearDown(self):
+        if os.path.exists(CHECKPOINT_ROOT_DIR):
+            shutil.rmtree(CHECKPOINT_ROOT_DIR)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 506fe2c147..3a16304a0b 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -2,6 +2,7 @@
 """Configs for RFT."""
 import os
 from dataclasses import dataclass, field
+from datetime import datetime
 from typing import Any, Dict, List, Optional
 
 from omegaconf import OmegaConf
@@ -399,6 +400,8 @@ class Config:
     checkpoint_job_dir: str = ""
     # If not set, automatically generated as f"{config.project}-{config.name}"
     ray_namespace: str = ""
+    # whether to continue training from the last checkpoint in checkpoint_job_dir (if any)
+    continue_from_checkpoint: bool = True
 
     algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
     data_processor: DataProcessorConfig = field(default_factory=DataProcessorConfig)
@@ -713,6 +716,15 @@ def check_and_update(self) -> None:  # noqa: C901
             self.checkpoint_root_dir = os.path.join(os.getcwd(), self.checkpoint_root_dir)
         # create a job dir at checkpoint_root_dir/project/name
         self.checkpoint_job_dir = os.path.join(self.checkpoint_root_dir, self.project, self.name)
+        # rename the experiment when necessary
+        if not self.continue_from_checkpoint and (
+            os.path.exists(self.checkpoint_job_dir) and os.listdir(self.checkpoint_job_dir)
+        ):
+            ori_name = self.name
+            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+            self.name = f"{ori_name}_{timestamp}"
+            self.checkpoint_job_dir = f"{self.checkpoint_job_dir}_{timestamp}"
+            logger.warning(f"Experiment [{ori_name}] already exists, renamed as {self.name}.")
         os.makedirs(self.checkpoint_job_dir, exist_ok=True)
 
         # check and update model path
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 1aaebe3aeb..14c8da52d0 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -293,6 +293,10 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         self.trainer.experiment_name = config.name
         self.trainer.default_local_dir = config.checkpoint_job_dir
         self.trainer.sft_warmup_steps = config.buffer.trainer_input.sft_warmup_steps
+        if not config.continue_from_checkpoint:
+            self.trainer.resume_mode = "disable"
+        else:
+            self.trainer.resume_mode = "auto"
 
         self.buffer = config.buffer
         # TODO: use dynamic read_batch_size to support multi-round scenarios
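For reference, the renaming behavior this diff adds to `Config.check_and_update` can be exercised in isolation. The sketch below is illustrative only: `resolve_job_dir` is a hypothetical standalone helper, not part of this diff, but it mirrors the added logic (rename only when resuming is disabled and the job directory is non-empty).

```python
import os
import tempfile
from datetime import datetime


def resolve_job_dir(root: str, project: str, name: str, continue_from_checkpoint: bool):
    """Hypothetical standalone mirror of the renaming logic in check_and_update."""
    job_dir = os.path.join(root, project, name)
    # Rename only when resuming is disabled AND the job dir already contains
    # files (e.g. global_step_* checkpoints from a previous run).
    if not continue_from_checkpoint and os.path.exists(job_dir) and os.listdir(job_dir):
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        name = f"{name}_{timestamp}"
        job_dir = f"{job_dir}_{timestamp}"
    os.makedirs(job_dir, exist_ok=True)
    return name, job_dir


with tempfile.TemporaryDirectory() as root:
    # A fresh (empty) job dir never triggers a rename.
    name, _ = resolve_job_dir(root, "unittest", "test", continue_from_checkpoint=False)
    assert name == "test"
    # A job dir with leftover checkpoints gets a timestamp suffix.
    os.makedirs(os.path.join(root, "unittest", "test", "global_step_1"))
    name, _ = resolve_job_dir(root, "unittest", "test", continue_from_checkpoint=False)
    assert name.startswith("test_")
```

Note the guard keys off directory non-emptiness, not mere existence, so the empty job dir created by a previous `check_and_update` call does not by itself force a rename.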