Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ Controls the rollout models and workflow execution.
```yaml
explorer:
name: explorer
runner_num: 32
runner_per_model: 8
max_timeout: 900
max_retry_times: 2
env_vars: {}
Expand All @@ -325,17 +325,22 @@ explorer:
auxiliary_models:
- model_path: /PATH/TO/MODEL
tensor_parallel_size: 1
eval_interval: 100
eval_on_startup: True
```

- `name`: Name of the explorer. This name will be used as the Ray actor's name, so it must be unique.
- `runner_num`: Number of parallel workflow runners.
- `runner_per_model`: Number of parallel workflow runners for each rollout model.
- `max_timeout`: Maximum time (in seconds) for a workflow to complete.
- `max_retry_times`: Maximum number of retries for a workflow.
- `env_vars`: Environment variables to be set for every workflow runner.
- `rollout_model.engine_type`: Type of inference engine. Options: `vllm_async` (recommended), `vllm`.
- `rollout_model.engine_num`: Number of inference engines.
- `rollout_model.tensor_parallel_size`: Degree of tensor parallelism.
- `auxiliary_models`: Additional models used for custom workflows.
- `eval_interval`: Interval (in steps) for evaluating the model.
- `eval_on_startup`: Whether to evaluate the model on startup. More precisely, at step 0 with the original model, so it will not be triggered when restarting.
- `runner_num`: (*Deprecated*) Number of parallel workflow runners.

---

Expand Down
7 changes: 5 additions & 2 deletions examples/grpo_alfworld/alfworld.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ buffer:
path: 'sqlite:///alfworld.db'
explorer:
runner_num: 32
max_timeout: 3600
rollout_model:
engine_type: vllm_async
engine_num: 2
Expand All @@ -44,10 +45,12 @@ explorer:
seed: 42
gpu_memory_utilization: 0.7
enable_chunked_prefill: true
env_vars:
TMPDIR: /PATH/TO/ALFWORLD_TMP_DIR
synchronizer:
sync_method: 'nccl'
sync_interval: 8
sync_timeout: 1200
sync_interval: 5
sync_timeout: 3600
trainer:
trainer_type: 'verl'
trainer_config_path: 'examples/grpo_alfworld/train_alfworld.yaml'
Expand Down
19 changes: 14 additions & 5 deletions tests/buffer/file_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from trinity.common.constants import StorageType


class TestFileBuffer(unittest.TestCase):
class TestFileBuffer(unittest.IsolatedAsyncioTestCase):
temp_output_path = "tmp/test_file_buffer/"

@classmethod
Expand All @@ -30,7 +30,7 @@ def tearDownClass(cls):
if os.path.exists(cls.temp_output_path):
os.system(f"rm -rf {cls.temp_output_path}")

def test_file_buffer(self):
async def test_file_buffer(self):
meta = StorageConfig(
name="test_buffer",
path=os.path.join(self.temp_output_path, "buffer.jsonl"),
Expand All @@ -46,8 +46,9 @@ def test_file_buffer(self):

# test writer
writer = JSONWriter(meta, None)
await writer.acquire()
writer.write(data)
writer.release()
await writer.release()

# test reader
meta.path = self.temp_output_path
Expand Down Expand Up @@ -119,23 +120,31 @@ def test_file_reader(self): # noqa: C901
break
self.assertEqual(len(tasks), 40 - 24)

def test_file_writer(self):
async def test_file_writer(self):
writer = get_buffer_writer(
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
)
await writer.acquire()
writer.write(
[
{"prompt": "hello world"},
{"prompt": "hi"},
]
)
await writer.write_async(
[
{"prompt": "My name is"},
{"prompt": "What is your name?"},
]
)
await writer.release()
file_wrapper = ray.get_actor("json-test_buffer")
self.assertIsNotNone(file_wrapper)
file_path = default_storage_path(
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
)
with open(file_path, "r") as f:
self.assertEqual(len(f.readlines()), 2)
self.assertEqual(len(f.readlines()), 4)

def setUp(self):
self.config = get_template_config()
Expand Down
41 changes: 35 additions & 6 deletions tests/buffer/queue_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import threading
import time

import torch

from tests.tools import RayUnittestBase
from tests.tools import RayUnittestBaseAysnc
from trinity.buffer.reader.queue_reader import QueueReader
from trinity.buffer.writer.queue_writer import QueueWriter
from trinity.common.config import BufferConfig, StorageConfig
Expand All @@ -13,8 +14,8 @@
BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")


class TestQueueBuffer(RayUnittestBase):
def test_queue_buffer(self):
class TestQueueBuffer(RayUnittestBaseAysnc):
async def test_queue_buffer(self):
total_num = 8
put_batch_size = 2
read_batch_size = 4
Expand All @@ -32,7 +33,7 @@ def test_queue_buffer(self):
)
writer = QueueWriter(meta, config)
reader = QueueReader(meta, config)
self.assertEqual(writer.acquire(), 1)
self.assertEqual(await writer.acquire(), 1)
exps = [
Experience(
tokens=torch.tensor([float(j) for j in range(i + 1)]),
Expand All @@ -43,7 +44,7 @@ def test_queue_buffer(self):
for i in range(1, put_batch_size + 1)
]
for _ in range(total_num // put_batch_size):
writer.write(exps)
await writer.write_async(exps)
for _ in range(total_num // read_batch_size):
exps = reader.read()
self.assertEqual(len(exps), read_batch_size)
Expand All @@ -62,7 +63,7 @@ def test_queue_buffer(self):
)
exps = reader.read(batch_size=put_batch_size * 2)
self.assertEqual(len(exps), put_batch_size * 2)
self.assertEqual(writer.release(), 0)
self.assertEqual(await writer.release(), 0)
self.assertRaises(StopIteration, reader.read)
with open(BUFFER_FILE_PATH, "r") as f:
self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
Expand All @@ -71,6 +72,34 @@ def test_queue_buffer(self):
et = time.time()
self.assertTrue(et - st > 2)

# test queue capacity
meta = StorageConfig(
name="test_buffer_small",
algorithm_type="ppo",
storage_type=StorageType.QUEUE,
max_read_timeout=3,
capacity=4,
path=BUFFER_FILE_PATH,
)
writer = QueueWriter(meta, config)
reader = QueueReader(meta, config)
writer.write([{"content": "hello"}])
writer.write([{"content": "hi"}])
writer.write([{"content": "hello"}])
writer.write([{"content": "hi"}])

# should be blocked
def write_blocking_call():
writer.write([{"content": "blocked"}])

thread = threading.Thread(target=write_blocking_call)
thread.start()
thread.join(timeout=2)
self.assertTrue(thread.is_alive(), "write() did not block as expected")
reader.read()
thread.join(timeout=1)
self.assertFalse(thread.is_alive())

def setUp(self):
if os.path.exists(BUFFER_FILE_PATH):
os.remove(BUFFER_FILE_PATH)
12 changes: 6 additions & 6 deletions tests/buffer/sql_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import unittest

import ray
import torch

from tests.tools import RayUnittestBaseAysnc
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import BufferConfig, StorageConfig
Expand All @@ -13,8 +13,8 @@
db_path = os.path.join(os.path.dirname(__file__), "test.db")


class TestSQLBuffer(unittest.TestCase):
def test_create_sql_buffer(self) -> None:
class TestSQLBuffer(RayUnittestBaseAysnc):
async def test_create_sql_buffer(self) -> None:
total_num = 8
put_batch_size = 2
read_batch_size = 4
Expand Down Expand Up @@ -42,9 +42,9 @@ def test_create_sql_buffer(self) -> None:
)
for i in range(1, put_batch_size + 1)
]
self.assertEqual(sql_writer.acquire(), 1)
self.assertEqual(await sql_writer.acquire(), 1)
for _ in range(total_num // put_batch_size):
sql_writer.write(exps)
await sql_writer.write_async(exps)
for _ in range(total_num // read_batch_size):
exps = sql_reader.read()
self.assertEqual(len(exps), read_batch_size)
Expand All @@ -66,5 +66,5 @@ def test_create_sql_buffer(self) -> None:
self.assertEqual(len(exps), put_batch_size * 2)
db_wrapper = ray.get_actor("sql-test_buffer")
self.assertIsNotNone(db_wrapper)
self.assertEqual(sql_writer.release(), 0)
self.assertEqual(await sql_writer.release(), 0)
self.assertRaises(StopIteration, sql_reader.read)
Loading