2 changes: 2 additions & 0 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -275,6 +275,8 @@ buffer:
- For `sql` storage type, the path points to the SQLite database file.
- `wrap_in_ray`: Whether to wrap the experience buffer in a Ray actor. Only takes effect when `storage_type` is `sql` or `file`; the `queue` storage always uses a Ray actor.
- `max_read_timeout`: The maximum time (in seconds) to wait when reading new experience data. If exceeded, an incomplete batch is returned directly. Only takes effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes).
- `use_priority_queue`: Only takes effect when `storage_type` is `queue`. If set to `True`, the queue becomes a priority queue, which allows prioritizing certain experiences over others. Default is `False`.
- `reuse_cooldown_time`: Only takes effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) before a read experience can be reused; a configuration sketch follows this list. Defaults to `None`, meaning experiences cannot be reused.
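
Taken together, the queue options can be combined as in the following minimal sketch. This is a hedged example assembled from the tests in this PR: the field names match the `StorageConfig` usage in `tests/buffer/queue_test.py`, the `replay_buffer_kwargs` values mirror the tests and are illustrative rather than required, and import paths are elided.

```python
# Sketch only: a priority-queue experience buffer, as exercised by the tests.
meta = StorageConfig(
    name="exp_buffer",
    storage_type=StorageType.QUEUE,   # priority queues only apply to queue storage
    use_priority_queue=True,          # turn the queue into a priority queue
    reuse_cooldown_time=0.5,          # seconds before a read experience may be served again
    replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
)
```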


### Trainer Input
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -40,6 +40,7 @@ dependencies = [
"tensorboard",
"openai",
"jsonlines",
"sortedcontainers",
]

[project.scripts]
@@ -61,6 +62,7 @@ dev = [
"mypy>=1.7.0",
"pytest>=8.0.0",
"pytest-json-ctrf",
"parameterized",
]

doc = [
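The new `sortedcontainers` dependency hints at how the priority queue keeps experiences ordered. The sketch below is only an illustration of that idea, assuming nothing about Trinity-RFT's actual class or method names; `TinyPriorityBuffer` is hypothetical.

```python
from sortedcontainers import SortedList

class TinyPriorityBuffer:
    """Hypothetical illustration, not Trinity-RFT's implementation."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        # A keyed SortedList stays sorted ascending by priority.
        self.items = SortedList(key=lambda pair: pair[0])

    def put(self, priority: float, exp) -> None:
        # When full, replace the worst item instead of blocking the writer,
        # matching the "should not be blocked" behavior in the reuse test.
        if len(self.items) >= self.capacity:
            self.items.pop(0)  # index 0 holds the lowest priority
        self.items.add((priority, exp))

    def pop_max(self):
        _, exp = self.items.pop(-1)  # index -1 holds the highest priority
        return exp
```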
2 changes: 1 addition & 1 deletion scripts/docker/Dockerfile
@@ -20,7 +20,7 @@ RUN apt update && apt install -y \

# For Aliyun users: update pip mirror to aliyun to speed up pip install
RUN pip config set global.index-url http://mirrors.cloud.aliyuncs.com/pypi/simple/ \
&& pip config set global.trusted-host mirrors.cloud.aliyuncs.com
&& pip config set install.trusted-host mirrors.cloud.aliyuncs.com

# copy the Trinity-RFT dir into the workspace
COPY . .
244 changes: 211 additions & 33 deletions tests/buffer/queue_test.py
@@ -3,6 +3,7 @@
import time

import torch
from parameterized import parameterized

from tests.tools import RayUnittestBaseAysnc
from trinity.buffer.reader.queue_reader import QueueReader
@@ -15,24 +16,29 @@


class TestQueueBuffer(RayUnittestBaseAysnc):
async def test_queue_buffer(self):
total_num = 8
put_batch_size = 2
read_batch_size = 4
@parameterized.expand(
[
(
"queue",
False,
),
(
"priority_queue",
True,
),
]
)
async def test_queue_buffer(self, name, use_priority_queue):
meta = StorageConfig(
name="test_buffer",
algorithm_type="ppo",
storage_type=StorageType.QUEUE,
max_read_timeout=3,
path=BUFFER_FILE_PATH,
use_priority_queue=use_priority_queue,
)
config = BufferConfig(
max_retry_times=3,
max_retry_interval=1,
read_batch_size=read_batch_size,
)
writer = QueueWriter(meta, config)
reader = QueueReader(meta, config)
writer = QueueWriter(meta, self.config)
reader = QueueReader(meta, self.config)
self.assertEqual(await writer.acquire(), 1)
exps = [
Experience(
@@ -41,37 +47,76 @@ async def test_queue_buffer(self):
reward=float(i),
logprobs=torch.tensor([0.1]),
)
for i in range(1, put_batch_size + 1)
for i in range(1, self.put_batch_size + 1)
]
for _ in range(total_num // put_batch_size):
for exp in exps:
exp.info = {"model_version": 0, "use_count": 0}
for _ in range(self.total_num // self.put_batch_size):
await writer.write_async(exps)
for _ in range(total_num // read_batch_size):
for _ in range(self.total_num // self.read_batch_size):
exps = reader.read()
self.assertEqual(len(exps), read_batch_size)
print(f"finish read {read_batch_size} experience")
writer.write(
[
Experience(
tokens=torch.tensor([float(j) for j in range(i + 1)]),
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
)
for i in range(1, put_batch_size * 2 + 1)
]
)
exps = reader.read(batch_size=put_batch_size * 2)
self.assertEqual(len(exps), put_batch_size * 2)
self.assertEqual(len(exps), self.read_batch_size)
print(f"finish read {self.read_batch_size} experience")
exps = [
Experience(
tokens=torch.tensor([float(j) for j in range(i + 1)]),
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
)
for i in range(1, self.put_batch_size * 2 + 1)
]
for exp in exps:
exp.info = {"model_version": 1, "use_count": 0}
writer.write(exps)
exps = reader.read(batch_size=self.put_batch_size * 2)
self.assertEqual(len(exps), self.put_batch_size * 2)
self.assertEqual(await writer.release(), 0)
self.assertRaises(StopIteration, reader.read)
with open(BUFFER_FILE_PATH, "r") as f:
self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
self.assertEqual(len(f.readlines()), self.total_num + self.put_batch_size * 2)
st = time.time()
self.assertRaises(TimeoutError, reader.read, batch_size=1)
et = time.time()
self.assertTrue(et - st > 2)

async def test_priority_queue_capacity(self):
# test priority queue capacity
meta = StorageConfig(
name="test_buffer_small",
algorithm_type="ppo",
storage_type=StorageType.QUEUE,
max_read_timeout=1,
capacity=2,
path=BUFFER_FILE_PATH,
use_priority_queue=True,
replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
)
writer = QueueWriter(meta, self.config)
reader = QueueReader(meta, self.config)

for i in range(4):
writer.write(
[
Experience(
tokens=torch.tensor([1, 2, 3]),
prompt_length=2,
info={"model_version": i, "use_count": 0},
),
]
)

exps = reader.read(batch_size=2)
self.assertEqual(exps[0].info["model_version"], 3)
self.assertEqual(exps[0].info["use_count"], 1)
self.assertEqual(exps[1].info["model_version"], 2)
self.assertEqual(exps[1].info["use_count"], 1)

with self.assertRaises(TimeoutError):
reader.read(batch_size=1)

async def test_queue_buffer_capacity(self):
# test queue capacity
meta = StorageConfig(
name="test_buffer_small",
@@ -81,8 +126,8 @@ async def test_queue_buffer(self):
capacity=4,
path=BUFFER_FILE_PATH,
)
writer = QueueWriter(meta, config)
reader = QueueReader(meta, config)
writer = QueueWriter(meta, self.config)
reader = QueueReader(meta, self.config)
writer.write([{"content": "hello"}])
writer.write([{"content": "hi"}])
writer.write([{"content": "hello"}])
@@ -100,6 +145,139 @@ def write_blocking_call():
thread.join(timeout=1)
self.assertFalse(thread.is_alive())

async def test_priority_queue_buffer_reuse(self):
# test queue reuse
meta = StorageConfig(
name="test_buffer_small",
algorithm_type="ppo",
storage_type=StorageType.QUEUE,
max_read_timeout=3,
capacity=4,
path=BUFFER_FILE_PATH,
use_priority_queue=True,
reuse_cooldown_time=0.5,
replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
)
writer = QueueWriter(meta, self.config)
reader = QueueReader(meta, self.config)
for i in range(4):
writer.write(
[
Experience(
tokens=torch.tensor([1, 2, 3]),
prompt_length=2,
info={"model_version": i, "use_count": 0},
),
Experience(
tokens=torch.tensor([1, 2, 3]),
prompt_length=2,
info={"model_version": i, "use_count": 0},
),
]
)

# should not be blocked
def replace_call():
writer.write(
[
Experience(
tokens=torch.tensor([1, 2, 3]),
prompt_length=2,
info={"model_version": 4, "use_count": 0},
),
Experience(
tokens=torch.tensor([1, 2, 3]),
prompt_length=2,
info={"model_version": 4, "use_count": 0},
),
]
)

thread = threading.Thread(target=replace_call)
thread.start()
thread.join(timeout=2)
self.assertFalse(thread.is_alive())

exps = reader.read(batch_size=4)
self.assertEqual(len(exps), 4)
self.assertEqual(exps[0].info["model_version"], 4)
self.assertEqual(exps[0].info["use_count"], 1)
self.assertEqual(exps[2].info["model_version"], 3)
self.assertEqual(exps[2].info["use_count"], 1)

# model_version 4, 3, 2, 1
# use_count 1, 1, 0, 0
# priority 3.4, 2.4, 2.0, 1.0 (priority = model_version - 0.6 * use_count)

time.sleep(1)
exps = reader.read(batch_size=4)
self.assertEqual(len(exps), 4)
self.assertEqual(exps[0].info["model_version"], 4)
self.assertEqual(exps[0].info["use_count"], 2)
self.assertEqual(exps[2].info["model_version"], 3)
self.assertEqual(exps[2].info["use_count"], 2)

# model_version 4, 3, 2, 1
# use_count 2, 2, 0, 0
# priority 2.8, 1.8, 2.0, 1.0

time.sleep(1)
exps = reader.read(batch_size=4)
self.assertEqual(len(exps), 4)
self.assertEqual(exps[0].info["model_version"], 4)
self.assertEqual(exps[0].info["use_count"], 3)
self.assertEqual(exps[2].info["model_version"], 2)
self.assertEqual(exps[2].info["use_count"], 1)

# model_version 4, 3, 2, 1
# use_count 3, 2, 1, 0
# priority 2.2, 1.8, 1.4, 1.0

time.sleep(1)
exps = reader.read(batch_size=4)
self.assertEqual(len(exps), 4)
self.assertEqual(exps[0].info["model_version"], 4)
self.assertEqual(exps[0].info["use_count"], 4)
self.assertEqual(exps[2].info["model_version"], 3)
self.assertEqual(exps[2].info["use_count"], 3)

# model_version 4, 3, 2, 1
# use_count 4, 3, 1, 0
# priority 1.6, 1.2, 1.4, 1.0

time.sleep(1)
exps = reader.read(batch_size=4)
self.assertEqual(len(exps), 4)
self.assertEqual(exps[0].info["model_version"], 4)
self.assertEqual(exps[0].info["use_count"], 5)
self.assertEqual(exps[2].info["model_version"], 2)
self.assertEqual(exps[2].info["use_count"], 2)

# model_version 4, 3, 2, 1
# use_count 5, 3, 2, 0
# priority 1.0, 1.2, 0.8, 1.0

time.sleep(1)
exps = reader.read(batch_size=4)
self.assertEqual(len(exps), 4)
self.assertEqual(exps[0].info["model_version"], 3)
self.assertEqual(exps[0].info["use_count"], 4)
self.assertEqual(exps[2].info["model_version"], 1)
self.assertEqual(exps[2].info["use_count"], 1)

# model_version 4, 3, 2, 1
# use_count 5, 4, 2, 1
# priority 1.0, 0.6, 0.8, 0.4

def setUp(self):
self.total_num = 8
self.put_batch_size = 2
self.read_batch_size = 4

self.config = BufferConfig(
max_retry_times=3,
max_retry_interval=1,
read_batch_size=self.read_batch_size,
)
if os.path.exists(BUFFER_FILE_PATH):
os.remove(BUFFER_FILE_PATH)
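
Across all six expectation blocks in `test_priority_queue_buffer_reuse`, the expected priorities fit a linear-decay rule of `model_version - decay * use_count` with `decay = 0.6`. The sketch below is an inference from those numbers, not the library's actual `linear_decay` implementation:

```python
# Priority rule consistent with every inline expectation in the test above:
#   priority = model_version - decay * use_count   (decay = 0.6 in the tests)
def linear_decay(info: dict, decay: float = 0.6) -> float:
    return info["model_version"] - decay * info["use_count"]

# Worked check against the first comment block (use counts 1, 1, 0, 0):
infos = [
    {"model_version": 4, "use_count": 1},
    {"model_version": 3, "use_count": 1},
    {"model_version": 2, "use_count": 0},
    {"model_version": 1, "use_count": 0},
]
print([round(linear_decay(i), 1) for i in infos])  # [3.4, 2.4, 2.0, 1.0]
```

Each read increments `use_count`, so a frequently reused experience loses 0.6 priority per read until fresher, higher-`model_version` experiences overtake it; that is exactly the rotation the later comment blocks walk through.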
16 changes: 15 additions & 1 deletion tests/trainer/trainer_test.py
@@ -10,6 +10,7 @@
from datetime import datetime

import ray
from parameterized import parameterized

from tests.tools import (
RayUnittestBase,
@@ -301,7 +302,19 @@ def setUp(self):
if multiprocessing.get_start_method(allow_none=True) != "spawn":
multiprocessing.set_start_method("spawn", force=True)

def test_fully_async_mode(self):
@parameterized.expand(
[
(
"queue",
False,
),
(
"priority_queue",
True,
),
]
)
def test_fully_async_mode(self, name, use_priority_queue):
config = get_template_config()
config.project = "unittest"
config.name = f"fully_async_{datetime.now().strftime('%Y%m%d%H%M%S')}"
@@ -316,6 +329,7 @@ def test_fully_async_mode(self):
name="exp_buffer",
storage_type=StorageType.QUEUE,
wrap_in_ray=True,
use_priority_queue=use_priority_queue,
)
config.synchronizer.sync_method = SyncMethod.CHECKPOINT
config.synchronizer.sync_interval = 8