Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ project: Trinity-RFT
name: example
mode: both
checkpoint_root_dir: /PATH/TO/CHECKPOINT
continue_from_checkpoint: true

algorithm:
# Algorithm-related parameters
Expand Down Expand Up @@ -68,6 +69,7 @@ checkpoint_root_dir: /PATH/TO/CHECKPOINT
- `explore`: Only launches the explorer.
- `bench`: Used for benchmarking.
- `checkpoint_root_dir`: Root directory where all checkpoints and logs will be saved. Checkpoints for this experiment will be stored in `<checkpoint_root_dir>/<project>/<name>/`.
- `continue_from_checkpoint`: If set to `true`, the experiment resumes from the latest checkpoint under the checkpoint path (if one exists). If set to `false` and the checkpoint directory already contains data, the current experiment is renamed to `<name>_<timestamp>` and a fresh run is started.
- `ray_namespace`: Namespace for the modules launched in the current experiment. If not specified, it will be set to `<project>/<name>`.

---
Expand Down
27 changes: 27 additions & 0 deletions tests/common/config_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# -*- coding: utf-8 -*-
"""Test cases for Config modules."""
import datetime
import os
import shutil
import unittest

from tests.tools import get_template_config
from trinity.common.config import InferenceModelConfig, load_config

CHECKPOINT_ROOT_DIR = os.path.join(os.path.dirname(__file__), "temp_checkpoint_dir")


class TestConfig(unittest.TestCase):
def test_load_default_config(self):
Expand Down Expand Up @@ -54,3 +58,26 @@ def test_all_examples_are_valid(self):
except Exception as e:
print(f"Error loading config {config_path}: {e}")
raise e

def test_continue_from_checkpoint_is_valid(self):
    """Verify `continue_from_checkpoint` keeps or renames the experiment name.

    With the flag enabled, an existing non-empty checkpoint directory must not
    change the experiment name; with the flag disabled, the experiment is
    renamed to `<name>_<timestamp>`.
    """
    cfg = get_template_config()
    cfg.project = "unittest"
    cfg.name = "test"
    cfg.checkpoint_root_dir = CHECKPOINT_ROOT_DIR

    # Simulate a previous run by creating a non-empty checkpoint job dir.
    job_dir = os.path.join(cfg.checkpoint_root_dir, cfg.project, cfg.name)
    os.makedirs(os.path.join(job_dir, "global_step_1"))

    # Continuation enabled: the experiment name is kept as-is.
    cfg.continue_from_checkpoint = True
    cfg.check_and_update()
    self.assertEqual(cfg.name, "test")

    # Continuation disabled: the name gains a parseable timestamp suffix.
    cfg.continue_from_checkpoint = False
    cfg.check_and_update()
    self.assertTrue(cfg.name.startswith("test_"))
    suffix = cfg.name.rsplit("_", 1)[-1]
    self.assertTrue(datetime.datetime.strptime(suffix, "%Y%m%d%H%M%S"))

def tearDown(self):
    """Clean up the temporary checkpoint directory created by the tests."""
    if not os.path.exists(CHECKPOINT_ROOT_DIR):
        return
    shutil.rmtree(CHECKPOINT_ROOT_DIR)
12 changes: 12 additions & 0 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Configs for RFT."""
import os
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional

from omegaconf import OmegaConf
Expand Down Expand Up @@ -399,6 +400,8 @@ class Config:
checkpoint_job_dir: str = ""
# If not set, automatically generated as f"{config.project}-{config.name}"
ray_namespace: str = ""
# whether to continue training from the last checkpoint in checkpoint_job_dir (if any)
continue_from_checkpoint: bool = True

algorithm: AlgorithmConfig = field(default_factory=AlgorithmConfig)
data_processor: DataProcessorConfig = field(default_factory=DataProcessorConfig)
Expand Down Expand Up @@ -713,6 +716,15 @@ def check_and_update(self) -> None: # noqa: C901
self.checkpoint_root_dir = os.path.join(os.getcwd(), self.checkpoint_root_dir)
# create a job dir at checkpoint_root_dir/project/name
self.checkpoint_job_dir = os.path.join(self.checkpoint_root_dir, self.project, self.name)
# rename the experiment when necessary
if not self.continue_from_checkpoint and (
os.path.exists(self.checkpoint_job_dir) and os.listdir(self.checkpoint_job_dir)
):
ori_name = self.name
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
self.name = f"{ori_name}_{timestamp}"
self.checkpoint_job_dir = f"{self.checkpoint_job_dir}_{timestamp}"
logger.warning(f"Experiment [{ori_name}] already exists, renamed as {self.name}.")
os.makedirs(self.checkpoint_job_dir, exist_ok=True)

# check and update model path
Expand Down
4 changes: 4 additions & 0 deletions trinity/common/verl_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
self.trainer.experiment_name = config.name
self.trainer.default_local_dir = config.checkpoint_job_dir
self.trainer.sft_warmup_steps = config.buffer.trainer_input.sft_warmup_steps
if not config.continue_from_checkpoint:
self.trainer.resume_mode = "disable"
else:
self.trainer.resume_mode = "auto"

self.buffer = config.buffer
# TODO: use dynamic read_batch_size to support multi-round scenarios
Expand Down