Merged

Changes from all commits
21 commits (all by chenyushuo):

6b3ad37  add synchronize v2 (Jul 16, 2025)
e2b149f  add config (Jul 16, 2025)
0b5a8f3  rename `WANT_SYNC` to `REQUIRE_SYNC` (Jul 17, 2025)
6338b1e  refactor on checkpoint saving (Jul 18, 2025)
ad60ac4  Merge branch 'main' of github.com:modelscope/Trinity-RFT into dev/add… (Jul 18, 2025)
bbd8f70  Add `state_dict` and `checkpoint` update method to `synchronizer`; (Jul 21, 2025)
1623c97  1. Add `group` into `checkpoint_job_dir` and change `group` in `wandb… (Jul 22, 2025)
a689262  add doc string for `Synchronizer` and `FSDPCheckpointManager` (Jul 22, 2025)
65a27b7  Merge branch 'main' of github.com:modelscope/Trinity-RFT into dev/add… (Jul 22, 2025)
ac0b2f0  1. Bug fix in `trainer_test` and `explorer_test` (Jul 24, 2025)
41dc7c7  Merge branch 'main' of github.com:modelscope/Trinity-RFT into dev/add… (Jul 24, 2025)
c7e1e4c  1. Remove `update_with_checkpoint` (Jul 24, 2025)
a0b8320  1. Add `block_until_saved` for checkpoint saving. (Jul 24, 2025)
efcaf63  Add synchronizer test (Jul 25, 2025)
3724919  bug fix for queue and sync test (Jul 29, 2025)
162e736  Merge branch 'main' of github.com:modelscope/Trinity-RFT into dev/add… (Jul 29, 2025)
1327873  add `lifetime="detached"` to Synchronizer (Jul 29, 2025)
7f2ebc3  doc fix and fix test (Jul 29, 2025)
041a56f  doc fix and bug fix in unittest (Jul 30, 2025)
982ee73  Merge branch 'main' of github.com:modelscope/Trinity-RFT into dev/add… (Jul 30, 2025)
e3bf395  Bug fix in EID and `ray.get` in `explorer` (Jul 30, 2025)
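Commit 1327873 adds `lifetime="detached"` to the Synchronizer actor. As a hedged illustration only (a hypothetical Counter actor, not Trinity's Synchronizer), this is the standard Ray pattern that option enables: a named detached actor outlives the driver that created it and can be looked up by name from another process in the same namespace. Named-actor lookup is also how the new test waits for the `queue-exp_buffer` actor from a separate process.

import ray

ray.init(namespace="trinity-demo", ignore_reinit_error=True)


@ray.remote
class Counter:
    """Hypothetical stand-in for a long-lived, shared actor."""

    def __init__(self):
        self.n = 0

    def bump(self):
        self.n += 1
        return self.n


# Named + detached: the actor survives this driver's exit and stays registered.
counter = Counter.options(name="sync-demo", lifetime="detached").remote()
print(ray.get(counter.bump.remote()))  # 1

# Any other driver connected to the same cluster and namespace can find it by name.
same_counter = ray.get_actor("sync-demo", namespace="trinity-demo")
print(ray.get(same_counter.bump.remote()))  # 2

ray.kill(same_counter)  # clean up the detached actor when done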
4 changes: 0 additions & 4 deletions tests/common/config_test.py
@@ -35,10 +35,6 @@ def test_load_default_config(self):
)
self.assertEqual(config.model.model_path, config.model.critic_model_path)
self.assertEqual(config.model.model_path, config.explorer.rollout_model.model_path)
self.assertEqual(
config.trainer.trainer_config.trainer.save_freq,
config.synchronizer.sync_interval,
)

def test_all_examples_are_valid(self):
example_dir = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
296 changes: 296 additions & 0 deletions tests/common/synchronizer_test.py
@@ -0,0 +1,296 @@
# -*- coding: utf-8 -*-
"""Test cases for Synchronizer modules."""

import asyncio
import multiprocessing
import os
import shutil
import time
import unittest
from copy import deepcopy
from datetime import datetime
from typing import List

import ray
from parameterized import parameterized_class

from tests.tools import (
TensorBoardParser,
get_checkpoint_path,
get_model_path,
get_template_config,
get_unittest_dataset_config,
)
from trinity.algorithm.algorithm import ALGORITHM_TYPE
from trinity.cli.launcher import both, explore, train
from trinity.common.config import Config, StorageConfig
from trinity.common.constants import StorageType, SyncMethod, SyncStyle
from trinity.explorer.explorer import Explorer
from trinity.trainer.trainer import Trainer
from trinity.utils.log import get_logger

logger = get_logger(__name__)
CHECKPOINT_ROOT_DIR = os.path.join(os.path.dirname(__file__), "temp_checkpoint_dir")


def trainer_monkey_patch(config: Config, max_steps: int, intervals: List[int]):
def new_train_step(self):
self.engine.algorithm = ALGORITHM_TYPE.get(config.algorithm.algorithm_type)
self.engine.global_steps += 1
self.logger.info(f"Training at step {self.engine.global_steps} started.")
time.sleep(intervals[self.engine.global_steps - 1])
metrics = {"actor/step": self.engine.global_steps}
self.monitor.log(data=metrics, step=self.engine.global_steps)
self.logger.info(f"Training at step {self.engine.global_steps} finished.")
return self.engine.global_steps < max_steps

Trainer.train_step = new_train_step


def explorer_monkey_patch(config: Config, max_steps: int, intervals: List[float]):
async def new_explore_step(self):
if self.explore_step_num == max_steps:
await self.save_checkpoint(sync_weight=False)
self.explore_step_num += 1
return self.explore_step_num <= max_steps

def wrapper(old_save_checkpoint):
async def new_save_checkpoint(self, sync_weight: bool = False):
await asyncio.sleep(intervals.pop(0))
await old_save_checkpoint(self, sync_weight)

return new_save_checkpoint

async def new_finish_explore_step(self, step: int, model_version: int) -> None:
metric = {"rollout/model_version": model_version}
self.monitor.log(metric, step=step)

Explorer.explore_step = new_explore_step
Explorer.save_checkpoint = wrapper(Explorer.save_checkpoint)
Explorer._finish_explore_step = new_finish_explore_step


def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
trainer_monkey_patch(config, max_steps, intervals)
train(config)
ray.shutdown(_exiting_interpreter=True)


def run_explorer(config: Config, max_steps: int, intervals: List[float]) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
explorer_monkey_patch(config, max_steps, intervals)
explore(config)
ray.shutdown(_exiting_interpreter=True)


def run_both(
    config: Config, max_steps: int, trainer_intervals: List[int], explorer_intervals: List[float]
) -> None:
ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
trainer_monkey_patch(config, max_steps, trainer_intervals)
explorer_monkey_patch(config, max_steps, explorer_intervals)
both(config)
ray.shutdown(_exiting_interpreter=True)


class BaseTestSynchronizer(unittest.TestCase):
def setUp(self):
if multiprocessing.get_start_method(allow_none=True) != "spawn":
multiprocessing.set_start_method("spawn", force=True)

def tearDown(self):
checkpoint_path = get_checkpoint_path()
shutil.rmtree(os.path.join(checkpoint_path, "unittest"))


@parameterized_class(
(
"sync_method",
"sync_style",
"max_steps",
"trainer_intervals",
"explorer1_intervals",
"explorer2_intervals",
),
[
(
SyncMethod.CHECKPOINT,
SyncStyle.FIXED,
8,
[2, 1, 2, 1, 2, 1, 2, 1],
[0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
[0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
),
(
SyncMethod.CHECKPOINT,
SyncStyle.DYNAMIC_BY_EXPLORER,
8,
[2, 1, 2, 1, 2, 1, 2, 1],
[0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
[0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
),
(
SyncMethod.MEMORY,
SyncStyle.FIXED,
8,
[2, 1, 2, 1, 2, 1, 2, 1],
[0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
[0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
),
(
SyncMethod.MEMORY,
SyncStyle.DYNAMIC_BY_EXPLORER,
8,
[2, 1, 2, 1, 2, 1, 2, 1],
[0, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 2.5],
[0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
),
],
)
class TestStateDictBasedSynchronizer(BaseTestSynchronizer):
def test_synchronizer(self):
config = get_template_config()
config.project = "unittest"
config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}"
config.checkpoint_root_dir = get_checkpoint_path()
config.buffer.total_epochs = 1
config.buffer.batch_size = 4
config.cluster.gpu_per_node = 2
config.cluster.node_num = 1
config.model.model_path = get_model_path()
config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
config.buffer.trainer_input.experience_buffer = StorageConfig(
name="exp_buffer",
storage_type=StorageType.QUEUE,
wrap_in_ray=True,
)
config.synchronizer.sync_method = self.sync_method
config.synchronizer.sync_style = self.sync_style
config.synchronizer.sync_interval = 2
config.trainer.save_interval = 100
config.monitor.monitor_type = "tensorboard"
trainer_config = deepcopy(config)
trainer_config.mode = "train"
trainer_config.check_and_update()

explorer1_config = deepcopy(config)
explorer1_config.mode = "explore"
explorer1_config.explorer.name = "explorer1"
explorer1_config.explorer.rollout_model.engine_num = 1
explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
explorer1_config.buffer.explorer_output = StorageConfig(
name="exp_buffer",
storage_type=StorageType.QUEUE,
wrap_in_ray=True,
)
explorer2_config = deepcopy(explorer1_config)
explorer2_config.explorer.name = "explorer2"
explorer1_config.check_and_update()
explorer2_config.check_and_update()

trainer_process = multiprocessing.Process(
target=run_trainer, args=(trainer_config, self.max_steps, self.trainer_intervals)
)
trainer_process.start()
ray.init(ignore_reinit_error=True)
while True:
try:
ray.get_actor("queue-exp_buffer", namespace=trainer_config.ray_namespace)
break
except ValueError:
print("waiting for trainer to start.")
time.sleep(5)
explorer_process_1 = multiprocessing.Process(
target=run_explorer,
args=(explorer1_config, self.max_steps, self.explorer1_intervals),
)
explorer_process_1.start()
explorer_process_2 = multiprocessing.Process(
target=run_explorer, args=(explorer2_config, self.max_steps, self.explorer2_intervals)
)
explorer_process_2.start()

explorer_process_1.join(timeout=200)
explorer_process_2.join(timeout=200)
trainer_process.join(timeout=200)

# check the tensorboard
parser = TensorBoardParser(
os.path.join(trainer_config.monitor.cache_dir, "tensorboard", "trainer")
)
actor_metrics = parser.metric_list("actor")
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8)
parser = TensorBoardParser(
os.path.join(explorer1_config.monitor.cache_dir, "tensorboard", "explorer1")
)
rollout_metrics = parser.metric_list("rollout")
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
parser = TensorBoardParser(
os.path.join(explorer2_config.monitor.cache_dir, "tensorboard", "explorer2")
)
rollout_metrics = parser.metric_list("rollout")
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)


@parameterized_class(
("sync_style", "max_steps", "trainer_intervals", "explorer_intervals"),
[
(
SyncStyle.FIXED,
8,
[2, 1, 2, 1, 2, 1, 2, 1],
[0, 2.5, 2.5, 2.5, 2.5, 0],
),
(
SyncStyle.DYNAMIC_BY_EXPLORER,
8,
[2, 1, 2, 1, 2, 1, 2, 1],
[0, 0.5, 0.5, 0.5, 0.5, 0],
),
],
)
class TestNCCLBasedSynchronizer(BaseTestSynchronizer):
def test_synchronizer(self):
config = get_template_config()
config.project = "unittest"
config.name = f"test_synchronizer_{datetime.now().strftime('%Y%m%d%H%M%S')}"
config.checkpoint_root_dir = get_checkpoint_path()
config.buffer.total_epochs = 1
config.buffer.batch_size = 4
config.model.model_path = get_model_path()
config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
config.buffer.trainer_input.experience_buffer = StorageConfig(
name="exp_buffer",
storage_type=StorageType.QUEUE,
wrap_in_ray=True,
)
config.synchronizer.sync_method = SyncMethod.NCCL
config.synchronizer.sync_style = self.sync_style
config.synchronizer.sync_interval = 2
config.trainer.save_interval = 100
config.monitor.monitor_type = "tensorboard"
config.mode = "both"
config.check_and_update()

# TODO: test more interval cases
both_process = multiprocessing.Process(
target=run_both,
args=(config, self.max_steps, self.trainer_intervals, self.explorer_intervals),
)
both_process.start()
both_process.join(timeout=200)

# check the tensorboard
parser = TensorBoardParser(os.path.join(config.monitor.cache_dir, "tensorboard", "trainer"))
actor_metrics = parser.metric_list("actor")
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8)
parser = TensorBoardParser(
os.path.join(config.monitor.cache_dir, "tensorboard", "explorer")
)
rollout_metrics = parser.metric_list("rollout")
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)

def tearDown(self):
if os.path.exists(CHECKPOINT_ROOT_DIR):
shutil.rmtree(CHECKPOINT_ROOT_DIR)
2 changes: 0 additions & 2 deletions tests/explorer/scheduler_test.py
@@ -69,7 +69,6 @@ def init_process_group(
group_name: str,
backend: str = "nccl",
timeout: int = 1200,
update_with_checkpoint: bool = True,
) -> None:
pass

@@ -91,7 +90,6 @@ def init_process_group(
group_name: str,
backend: str = "nccl",
timeout: int = 1200,
update_with_checkpoint: bool = True,
) -> None:
pass

4 changes: 3 additions & 1 deletion tests/explorer/workflow_test.py
@@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-
"""Test for the workflow module"""
import unittest
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict, Optional
from unittest.mock import MagicMock

from torch import Tensor

from tests.tools import get_unittest_dataset_config
from trinity.common.experience import EID
from trinity.common.rewards import RMGalleryFn
from trinity.common.workflows import (
MathBoxedWorkflow,
@@ -27,6 +28,7 @@ class MockResponse:
unique_id: Optional[str] = "0"
tokens: Optional[Tensor] = Tensor([0, 0])
prompt_length: int = 1
eid: EID = field(default_factory=EID)


class DummyWorkflow(Workflow):
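In the MockResponse change above, the new `eid` field uses `field(default_factory=EID)` rather than a plain `eid: EID = EID()` default. A standalone illustration (generic dataclass, unrelated to Trinity) of why a factory is preferable: it builds a fresh object for every instance instead of sharing a single default object across all of them.

from dataclasses import dataclass, field


@dataclass
class WithFactory:
    ids: list = field(default_factory=list)  # a new list for every instance


a, b = WithFactory(), WithFactory()
a.ids.append(1)
print(b.ids)  # [] -- the instances do not share state

For mutable built-ins such as list, dataclasses even reject a plain default outright and raise a ValueError telling you to use default_factory.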
4 changes: 3 additions & 1 deletion tests/tools.py
@@ -18,7 +18,9 @@

def get_template_config() -> Config:
config_path = os.path.join(os.path.dirname(__file__), "template", "config.yaml")
return load_config(config_path)
config = load_config(config_path)
config.ray_namespace = ray.get_runtime_context().namespace
return config


def get_model_path() -> str:
5 changes: 3 additions & 2 deletions tests/trainer/trainer_test.py
@@ -22,7 +22,7 @@
)
from trinity.cli.launcher import bench, both, explore, train
from trinity.common.config import Config, StorageConfig
from trinity.common.constants import StorageType, SyncMethod
from trinity.common.constants import StorageType, SyncMethod, SyncStyle
from trinity.common.models.utils import get_checkpoint_dir_with_step_num
from trinity.manager.manager import CacheManager

@@ -99,7 +99,7 @@ def test_trainer(self):
self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_8, "actor"))) > 0)
self.assertEqual(step_num, 8)
# TODO: Reinit will fail when using v1 engine, find a way to fix it
ray.init(ignore_reinit_error=True)
ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
# test bench mode
self.config.mode = "bench"
self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
Expand Down Expand Up @@ -362,6 +362,7 @@ def test_fully_async_mode(self, name, use_priority_queue):
use_priority_queue=use_priority_queue,
)
config.synchronizer.sync_method = SyncMethod.CHECKPOINT
config.synchronizer.sync_style = SyncStyle.DYNAMIC_BY_EXPLORER
config.synchronizer.sync_interval = 8
config.monitor.monitor_type = "tensorboard"
trainer_config = deepcopy(config)
4 changes: 4 additions & 0 deletions trinity/buffer/queue.py
@@ -71,6 +71,10 @@ def __init__(self, capacity: int):
async def close(self) -> None:
"""Close the queue."""
self._closed = True
for getter in self._getters:
if not getter.done():
getter.set_exception(StopAsyncIteration())
self._getters.clear()

def stopped(self) -> bool:
"""Check if there is no more data to read."""
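The `close()` change above fails every pending getter future with `StopAsyncIteration`, so readers blocked on the queue wake up and stop instead of hanging once it closes. A minimal, self-contained sketch of that pattern (a hypothetical ClosableQueue, not Trinity's actual queue class):

import asyncio
from collections import deque


class ClosableQueue:
    def __init__(self):
        self._items = deque()
        self._getters = deque()  # futures of readers currently blocked in get()
        self._closed = False

    async def get(self):
        if self._items:
            return self._items.popleft()
        if self._closed:
            raise StopAsyncIteration()
        fut = asyncio.get_running_loop().create_future()
        self._getters.append(fut)
        return await fut  # resolved by put(), or failed by close()

    def put(self, item):
        while self._getters:
            fut = self._getters.popleft()
            if not fut.done():
                fut.set_result(item)  # hand the item to a waiting reader
                return
        self._items.append(item)

    def close(self):
        self._closed = True
        for fut in self._getters:
            if not fut.done():
                # Without this, readers awaiting get() would block forever.
                fut.set_exception(StopAsyncIteration())
        self._getters.clear()


async def main():
    q = ClosableQueue()
    reader = asyncio.create_task(q.get())
    await asyncio.sleep(0)  # let the reader register its future
    q.close()
    try:
        await reader
    except StopAsyncIteration:
        print("reader unblocked cleanly after close()")


asyncio.run(main())

Running the sketch prints that the blocked reader exits cleanly once `close()` is called; without the wake-up loop in `close()`, the `await reader` line would block indefinitely.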