Skip to content

Commit 90f4e91

Browse files
authored
Add continue_from_checkpoint (#129)
1 parent 00f3b27 commit 90f4e91

File tree

4 files changed

+45
-0
lines changed

4 files changed

+45
-0
lines changed

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ project: Trinity-RFT
1111
name: example
1212
mode: both
1313
checkpoint_root_dir: /PATH/TO/CHECKPOINT
14+
continue_from_checkpoint: true
1415

1516
algorithm:
1617
# Algorithm-related parameters
@@ -68,6 +69,7 @@ checkpoint_root_dir: /PATH/TO/CHECKPOINT
6869
- `explore`: Only launches the explorer.
6970
- `bench`: Used for benchmarking.
7071
- `checkpoint_root_dir`: Root directory where all checkpoints and logs will be saved. Checkpoints for this experiment will be stored in `<checkpoint_root_dir>/<project>/<name>/`.
72+
- `continue_from_checkpoint`: If set to `true`, the experiment will resume from the latest checkpoint in its checkpoint directory (if one exists). If set to `false` and the checkpoint directory already contains data, the current experiment is renamed to `<name>_<timestamp>` and a fresh experiment is started.
7173
- `ray_namespace`: Namespace for the modules launched in the current experiment. If not specified, it will be set to `<project>/<name>`.
7274

7375
---

tests/common/config_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
# -*- coding: utf-8 -*-
22
"""Test cases for Config modules."""
3+
import datetime
34
import os
5+
import shutil
46
import unittest
57

68
from tests.tools import get_template_config
79
from trinity.common.config import InferenceModelConfig, load_config
810

11+
CHECKPOINT_ROOT_DIR = os.path.join(os.path.dirname(__file__), "temp_checkpoint_dir")
12+
913

1014
class TestConfig(unittest.TestCase):
1115
def test_load_default_config(self):
@@ -54,3 +58,26 @@ def test_all_examples_are_valid(self):
5458
except Exception as e:
5559
print(f"Error loading config {config_path}: {e}")
5660
raise e
61+
62+
def test_continue_from_checkpoint_is_valid(self):
63+
config = get_template_config()
64+
config.name = "test"
65+
config.project = "unittest"
66+
config.checkpoint_root_dir = CHECKPOINT_ROOT_DIR
67+
68+
dir_path = os.path.join(config.checkpoint_root_dir, config.project, config.name)
69+
os.makedirs(os.path.join(dir_path, "global_step_1"))
70+
71+
config.continue_from_checkpoint = True
72+
config.check_and_update()
73+
self.assertEqual(config.name, "test")
74+
75+
config.continue_from_checkpoint = False
76+
config.check_and_update()
77+
self.assertTrue(config.name.startswith("test_"))
78+
timestamp = config.name.split("_")[-1]
79+
self.assertTrue(datetime.datetime.strptime(timestamp, "%Y%m%d%H%M%S"))
80+
81+
def tearDown(self):
82+
if os.path.exists(CHECKPOINT_ROOT_DIR):
83+
shutil.rmtree(CHECKPOINT_ROOT_DIR)

trinity/common/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""Configs for RFT."""
33
import os
44
from dataclasses import dataclass, field
5+
from datetime import datetime
56
from typing import Any, Dict, List, Optional
67

78
from omegaconf import OmegaConf
@@ -400,6 +401,8 @@ class Config:
400401
checkpoint_job_dir: str = ""
401402
# If not set, automatically generated as f"{config.project}-{config.name}"
402403
ray_namespace: str = ""
404+
# whether to continue training from the last checkpoint in checkpoint_job_dir (if any)
405+
continue_from_checkpoint: bool = True
403406

404407
algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
405408
data_processor: DataProcessorConfig = field(default_factory=DataProcessorConfig)
@@ -714,6 +717,15 @@ def check_and_update(self) -> None: # noqa: C901
714717
self.checkpoint_root_dir = os.path.join(os.getcwd(), self.checkpoint_root_dir)
715718
# create a job dir at checkpoint_root_dir/project/name
716719
self.checkpoint_job_dir = os.path.join(self.checkpoint_root_dir, self.project, self.name)
720+
# rename the experiment when necessary
721+
if not self.continue_from_checkpoint and (
722+
os.path.exists(self.checkpoint_job_dir) and os.listdir(self.checkpoint_job_dir)
723+
):
724+
ori_name = self.name
725+
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
726+
self.name = f"{ori_name}_{timestamp}"
727+
self.checkpoint_job_dir = f"{self.checkpoint_job_dir}_{timestamp}"
728+
logger.warning(f"Experiment [{ori_name}] already exists, renamed as {self.name}.")
717729
os.makedirs(self.checkpoint_job_dir, exist_ok=True)
718730

719731
# check and update model path

trinity/common/verl_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,10 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
293293
self.trainer.experiment_name = config.name
294294
self.trainer.default_local_dir = config.checkpoint_job_dir
295295
self.trainer.sft_warmup_steps = config.buffer.trainer_input.sft_warmup_steps
296+
if not config.continue_from_checkpoint:
297+
self.trainer.resume_mode = "disable"
298+
else:
299+
self.trainer.resume_mode = "auto"
296300

297301
self.buffer = config.buffer
298302
# TODO: use dynamic read_batch_size to support multi-round scenarios

0 commit comments

Comments
 (0)