46 changes: 42 additions & 4 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -68,6 +68,7 @@ checkpoint_root_dir: /PATH/TO/CHECKPOINT
- `explore`: Only launches the explorer.
- `bench`: Used for benchmarking.
- `checkpoint_root_dir`: Root directory where all checkpoints and logs will be saved. Checkpoints for this experiment will be stored in `<checkpoint_root_dir>/<project>/<name>/`.
- `ray_namespace`: Namespace for the modules launched in the current experiment. If not specified, it will be set to `<project>/<name>`.
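
A minimal sketch of these global fields, including the new `ray_namespace` option (all values below are placeholders):

```yaml
project: my_project
name: my_experiment
mode: both
checkpoint_root_dir: /PATH/TO/CHECKPOINT
ray_namespace: my_project/my_experiment  # defaults to <project>/<name> when omitted
```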

---

@@ -166,6 +167,9 @@ buffer:
    eval_tasksets:
      ...

  explorer_output:
    ...

  trainer_input:
    experience_buffer:
      ...
@@ -219,15 +223,15 @@ buffer:

The configuration for each task dataset is defined as follows:

- `name`: Name of the dataset. Name must be unique.
- `name`: Name of the dataset. This name will be used as the Ray actor's name, so it must be unique.
- `storage_type`: How the dataset is stored. Options: `file`, `queue`, `sql`.
  - `file`: The dataset is stored in `jsonl`/`parquet` files. The data files must be organized according to the Hugging Face standard. *We recommend using this storage type for most cases.*
  - `queue`: The dataset is stored in a simple FIFO queue. *Do not use this storage type for task datasets unless you know what you are doing.*
  - `sql`: The dataset is stored in a SQL database. *This type is unstable and will be optimized in future versions.*
- `path`: The path to the task dataset.
  - For `file` storage type, the path is the path to the directory that contains the task dataset files.
  - For `file` storage type, the path points to the directory that contains the task dataset files.
  - For `queue` storage type, the path is optional. You can back up the data in the queue by specifying a SQLite database path here.
  - For `sql` storage type, the path is the path to the sqlite database file.
  - For `sql` storage type, the path points to the SQLite database file.
- `subset_name`: The subset name of the task dataset. Default is `None`.
- `split`: The split of the task dataset. Default is `train`.
- `format`: Defines keys for prompts and responses in the dataset.
@@ -240,6 +244,33 @@ The configuration for each task dataset is defined as follows:
- `workflow_args`: A dictionary of arguments used to supplement dataset-level parameters.
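
A minimal taskset sketch combining the fields above, assuming the task dataset is configured under `buffer.explorer_input` (as the new test in this PR does); names and paths are placeholders, and `format`/`workflow_args` are omitted:

```yaml
buffer:
  explorer_input:
    taskset:
      name: countdown_train        # must be unique (used as the Ray actor name)
      storage_type: file           # recommended for task datasets
      path: /PATH/TO/TASK_DATASET  # directory with jsonl/parquet files
      subset_name: null
      split: train
```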


### Explorer Output

In [`explore` mode](#global-configuration), there is no trainer, so users configure the experience buffer via `buffer.explorer_output` rather than `buffer.trainer_input`, which is introduced in the next section.

> For `both` and `train` modes, users should use `buffer.trainer_input` instead of `buffer.explorer_output`.

```yaml
buffer:
  ...
  explorer_output:
    name: countdown_buffer
    storage_type: queue
    path: sqlite:///countdown_buffer.db
    wrap_in_ray: True
```

- `name`: The name of the experience buffer. This name will be used as the Ray actor's name, so it must be unique.
- `storage_type`: The storage type for the experience buffer.
  - `queue`: Experience data is stored in a queue. This storage type is recommended for most use cases.
  - `sql`: Experience data is stored in a SQL database. If your database only supports local access (e.g., SQLite), set `wrap_in_ray` to `True` to wrap the database in a Ray actor, enabling remote access from other nodes.
  - `file`: Experience data is stored in a JSON file. This storage type should be used only for debugging purposes in `explore` mode.
- `path`: The path to the experience buffer.
  - For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data.
  - For `file` storage type, the path points to the directory containing the dataset files.
  - For `sql` storage type, the path points to the SQLite database file.
- `wrap_in_ray`: Whether to wrap the experience buffer in a Ray actor. This only takes effect when `storage_type` is `sql` or `file`; the `queue` storage always uses a Ray actor.
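
For a database that only supports local access, the description above suggests wrapping it in a Ray actor; a sketch of that variant (the path is a placeholder):

```yaml
buffer:
  explorer_output:
    name: countdown_buffer
    storage_type: sql
    path: sqlite:///countdown_buffer.db
    wrap_in_ray: True  # enables remote access to a local-only database such as SQLite
```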


### Trainer Input

Defines the experience buffer and optional SFT warm-up dataset.
@@ -264,7 +295,7 @@ buffer:
    sft_warmup_steps: 0
```

- `experience_buffer`: Experience replay buffer used by the trainer.
- `experience_buffer`: Experience buffer used by the trainer, which is logically equivalent to `buffer.explorer_output`.
- `sft_warmup_dataset`: Optional dataset used for pre-training (SFT warmup).
- `sft_warmup_steps`: Number of steps to use SFT warm-up before RL begins.
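
A minimal sketch of `buffer.trainer_input`, reusing the queue buffer from the Explorer Output example; `sft_warmup_steps: 0` matches the example above, and the optional warm-up dataset is assumed to default to `null`:

```yaml
buffer:
  trainer_input:
    experience_buffer:
      name: countdown_buffer
      storage_type: queue
      path: sqlite:///countdown_buffer.db
    sft_warmup_dataset: null  # optional SFT warm-up dataset
    sft_warmup_steps: 0
```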

@@ -276,6 +307,7 @@ Controls the rollout models and workflow execution.

```yaml
explorer:
  name: explorer
  runner_num: 32
  rollout_model:
    engine_type: vllm_async
@@ -286,11 +318,13 @@ explorer:
    tensor_parallel_size: 1
```

- `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
- `runner_num`: Number of parallel workflow runners.
- `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`.
- `rollout_model.engine_num`: Number of inference engines.
- `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.
- `auxiliary_models`: Additional models used for custom workflows.
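
Because the explorer name becomes a Ray actor name, running several explorers against one trainer (as in the fully asynchronous test added in this PR) requires distinct names; a sketch based on that test's settings:

```yaml
explorer:
  name: explorer1  # a second explorer job would use a different name, e.g. explorer2
  runner_num: 4
  rollout_model:
    engine_type: vllm_async
    engine_num: 1
    tensor_parallel_size: 1
```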

---

## Synchronizer Configuration
@@ -301,13 +335,15 @@ Controls how model weights are synchronized between trainer and explorer.
synchronizer:
  sync_method: 'nccl'
  sync_interval: 10
  sync_offset: 0
  sync_timeout: 1200
```

- `sync_method`: Method of synchronization. Options:
  - `nccl`: Uses NCCL for fast synchronization. Supported for `both` mode.
  - `checkpoint`: Loads latest model from disk. Supported for `train`, `explore`, or `bench` mode.
- `sync_interval`: Interval (in steps) of model weight synchronization between trainer and explorer.
- `sync_offset`: Offset (in steps) of model weight synchronization between trainer and explorer. The explorer can run `sync_offset` steps before the trainer starts training.
- `sync_timeout`: Timeout duration for synchronization.
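
For example, the fully asynchronous test added in this PR runs the trainer and explorers as separate `train`/`explore` jobs, so it uses checkpoint-based synchronization; a sketch of that setup:

```yaml
synchronizer:
  sync_method: 'checkpoint'  # nccl is only supported in `both` mode
  sync_interval: 8
  sync_offset: 0
  sync_timeout: 1200
```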

---
@@ -318,12 +354,14 @@ Specifies the backend and behavior of the trainer.

```yaml
trainer:
  name: trainer
  trainer_type: 'verl'
  save_interval: 100
  trainer_config_path: 'examples/ppo_countdown/train_countdown.yaml'
  trainer_config: null
```

- `name`: Name of the trainer. This name will be used as the Ray actor's name, so it must be unique.
- `trainer_type`: Trainer backend implementation. Currently only supports `verl`.
- `save_interval`: Frequency (in steps) at which to save model checkpoints.
- `trainer_config_path`: The path to the trainer configuration file.
2 changes: 1 addition & 1 deletion tests/buffer/file_test.py
@@ -47,7 +47,7 @@ def test_file_buffer(self):
        # test writer
        writer = JSONWriter(meta, None)
        writer.write(data)
        writer.finish()
        writer.release()

        # test reader
        meta.path = self.temp_output_path
3 changes: 2 additions & 1 deletion tests/buffer/queue_test.py
@@ -30,6 +30,7 @@ def test_queue_buffer(self):
        )
        writer = QueueWriter(meta, config)
        reader = QueueReader(meta, config)
        self.assertEqual(writer.acquire(), 1)
        exps = [
            Experience(
                tokens=torch.tensor([float(j) for j in range(i + 1)]),
@@ -59,7 +60,7 @@
        )
        exps = reader.read(batch_size=put_batch_size * 2)
        self.assertEqual(len(exps), put_batch_size * 2)
        writer.finish()
        self.assertEqual(writer.release(), 0)
        self.assertRaises(StopIteration, reader.read)
        with open(BUFFER_FILE_PATH, "r") as f:
            self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
3 changes: 3 additions & 0 deletions tests/buffer/sql_test.py
@@ -42,6 +42,7 @@ def test_create_sql_buffer(self) -> None:
            )
            for i in range(1, put_batch_size + 1)
        ]
        self.assertEqual(sql_writer.acquire(), 1)
        for _ in range(total_num // put_batch_size):
            sql_writer.write(exps)
        for _ in range(total_num // read_batch_size):
@@ -65,3 +66,5 @@ def test_create_sql_buffer(self) -> None:
        self.assertEqual(len(exps), put_batch_size * 2)
        db_wrapper = ray.get_actor("sql-test_buffer")
        self.assertIsNotNone(db_wrapper)
        self.assertEqual(sql_writer.release(), 0)
        self.assertRaises(StopIteration, sql_reader.read)
2 changes: 1 addition & 1 deletion tests/explorer/runner_pool_test.py
@@ -253,6 +253,6 @@ def test_runner_pool_with_auxiliary_models(self):
        st = time.time()
        status = pool.get_next_unorder()
        et = time.time()
        self.assertTrue(et - st < 1)
        self.assertTrue(et - st < 1.5)
        self.assertEqual(len(status), 1)
        self.assertTrue(status[0].ok)
2 changes: 1 addition & 1 deletion tests/tools.py
@@ -42,7 +42,7 @@ def get_checkpoint_path() -> str:
def get_unittest_dataset_config(
    dataset_name: str = "countdown", split: str = "train"
) -> StorageConfig:
    """Countdown sample dataset for 8 steps"""
    """Countdown dataset with 16 samples."""
    if dataset_name == "countdown" or dataset_name == "copy_countdown":
        return StorageConfig(
            name=dataset_name,
137 changes: 134 additions & 3 deletions tests/trainer/trainer_test.py
@@ -1,7 +1,11 @@
"""Tests for trainer."""
import multiprocessing
import os
import shutil
import time
import unittest
from abc import abstractmethod
from copy import deepcopy
from datetime import datetime

import ray
@@ -14,8 +18,11 @@
    get_template_config,
    get_unittest_dataset_config,
)
from trinity.cli.launcher import bench, both, train
from trinity.common.constants import SyncMethod
from trinity.cli.launcher import bench, both, explore, train
from trinity.common.config import Config, StorageConfig
from trinity.common.constants import StorageType, SyncMethod
from trinity.common.models.utils import get_checkpoint_dir_with_step_num
from trinity.manager.manager import CacheManager


class BaseTrainerCase(RayUnittestBase):
@@ -149,7 +156,6 @@ def test_trainer(self):
        response_metrics = parser.metric_list("response_length")
        self.assertTrue(len(response_metrics) > 0)
        self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
        ray.timeline(filename="timeline.json")
        ray.shutdown(_exiting_interpreter=True)
        # check checkpoint
        from trinity.common.models.utils import get_checkpoint_dir_with_step_num
@@ -262,3 +268,128 @@ def test_trainer(self):
    def tearDown(self):
        # remove dir only when the test passed
        shutil.rmtree(self.config.checkpoint_job_dir)


def run_trainer(config: Config) -> None:
    ray.init(namespace=config.ray_namespace)
    train(config)


def run_explorer(config: Config) -> None:
    ray.init(namespace=config.ray_namespace)
    explore(config)


class TestFullyAsyncMode(unittest.TestCase):
    def setUp(self):
        if multiprocessing.get_start_method(allow_none=True) != "spawn":
            multiprocessing.set_start_method("spawn", force=True)

    def test_fully_async_mode(self):
        config = get_template_config()
        config.project = "unittest"
        config.name = f"fully_async_{datetime.now().strftime('%Y%m%d%H%M%S')}"
        config.checkpoint_root_dir = get_checkpoint_path()
        config.buffer.total_epochs = 1
        config.buffer.batch_size = 4
        config.cluster.gpu_per_node = 2
        config.cluster.node_num = 1
        config.model.model_path = get_model_path()
        config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
        config.buffer.trainer_input.experience_buffer = StorageConfig(
            name="exp_buffer",
            storage_type=StorageType.QUEUE,
            wrap_in_ray=True,
        )
        config.synchronizer.sync_method = SyncMethod.CHECKPOINT
        config.synchronizer.sync_interval = 8
        config.monitor.monitor_type = "tensorboard"
        trainer_config = deepcopy(config)
        trainer_config.mode = "train"
        trainer_config.check_and_update()

        explorer1_config = deepcopy(config)
        explorer1_config.mode = "explore"
        explorer1_config.explorer.name = "explorer1"
        explorer1_config.cluster.gpu_per_node = 1
        explorer1_config.cluster.node_num = 1
        explorer1_config.explorer.rollout_model.engine_num = 1
        explorer1_config.explorer.rollout_model.tensor_parallel_size = 1
        explorer1_config.explorer.runner_num = 4
        explorer1_config.buffer.explorer_output = StorageConfig(
            name="exp_buffer",
            storage_type=StorageType.QUEUE,
            wrap_in_ray=True,
        )
        explorer2_config = deepcopy(explorer1_config)
        explorer1_config.check_and_update()

        trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,))
        trainer_process.start()

        ray.init(ignore_reinit_error=True)
        while True:
            try:
                ray.get_actor("queue-exp_buffer", namespace=trainer_config.ray_namespace)
                break
            except ValueError:
                print("waiting for trainer to start.")
                time.sleep(5)

        explorer_process_1 = multiprocessing.Process(target=run_explorer, args=(explorer1_config,))
        explorer_process_1.start()

        time.sleep(20)
        explorer2_config.explorer.name = "explorer2"
        explorer2_config.check_and_update()
        explorer_process_2 = multiprocessing.Process(target=run_explorer, args=(explorer2_config,))
        explorer_process_2.start()

        explorer_process_1.join()
        explorer_process_2.join()

        # wait for trainer process to finish.
        trainer_process.join(timeout=200)

        # check the tensorboard
        parser = TensorBoardParser(
            os.path.join(trainer_config.monitor.cache_dir, "tensorboard", "trainer")
        )
        actor_metrics = parser.metric_list("actor")
        self.assertEqual(parser.metric_max_step(actor_metrics[0]), 8)
        parser = TensorBoardParser(
            os.path.join(explorer1_config.monitor.cache_dir, "tensorboard", "explorer1")
        )
        rollout_metrics = parser.metric_list("rollout")
        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
        parser = TensorBoardParser(
            os.path.join(explorer2_config.monitor.cache_dir, "tensorboard", "explorer2")
        )
        rollout_metrics = parser.metric_list("rollout")
        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
        # check the checkpoint
        explorer1_cache = CacheManager(explorer1_config)
        cache = explorer1_cache.load_explorer()
        self.assertEqual(cache["latest_iteration"], 4)
        explorer2_cache = CacheManager(explorer2_config)
        cache = explorer2_cache.load_explorer()
        self.assertEqual(cache["latest_iteration"], 4)
        self.assertIsNotNone(
            get_checkpoint_dir_with_step_num(
                checkpoint_root_path=explorer1_config.checkpoint_job_dir,
                trainer_type="verl",
                step_num=8,
            )
        )
        self.assertIsNotNone(
            get_checkpoint_dir_with_step_num(
                checkpoint_root_path=explorer2_config.checkpoint_job_dir,
                trainer_type="verl",
                step_num=8,
            )
        )
        ray.shutdown()

    def tearDown(self):
        checkpoint_path = get_checkpoint_path()
        shutil.rmtree(os.path.join(checkpoint_path, "unittest"))
16 changes: 14 additions & 2 deletions trinity/buffer/buffer_writer.py
@@ -11,5 +11,17 @@ def write(self, data: List) -> None:
"""Write to buffer."""

@abstractmethod
def finish(self) -> None:
"""Finish writing."""
def acquire(self) -> int:
"""Acquire the buffer writer.

Returns:
`int`: The reference count of the buffer after acquiring.
"""

@abstractmethod
def release(self) -> int:
"""Release the buffer writer. After release, the buffer writer can not be used again.

Returns:
`int`: The reference count of the buffer after releasing.
"""