Skip to content

Commit 8c6107d

Browse files
authored
Refactor Buffer Reader (#80)
1 parent f862e11 commit 8c6107d

File tree

13 files changed

+294
-100
lines changed

13 files changed

+294
-100
lines changed

tests/buffer/file_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import unittest
2+
3+
from tests.tools import get_template_config, get_unittest_dataset_config
4+
from trinity.buffer.buffer import get_buffer_reader
5+
6+
7+
class TestFileReader(unittest.TestCase):
8+
def test_file_reader(self):
9+
"""Test file reader."""
10+
config = get_template_config()
11+
dataset_config = get_unittest_dataset_config("countdown", "train")
12+
config.buffer.explorer_input.taskset = dataset_config
13+
reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
14+
15+
tasks = []
16+
while True:
17+
try:
18+
tasks.extend(reader.read())
19+
except StopIteration:
20+
break
21+
self.assertEqual(len(tasks), 16)
22+
23+
config.buffer.explorer_input.taskset.total_epochs = 2
24+
config.buffer.explorer_input.taskset.index = 4
25+
reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
26+
tasks = []
27+
while True:
28+
try:
29+
tasks.extend(reader.read())
30+
except StopIteration:
31+
break
32+
self.assertEqual(len(tasks), 16 * 2 - 4)

tests/buffer/queue_test.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import torch
24

35
from tests.tools import RayUnittestBase
@@ -7,6 +9,8 @@
79
from trinity.common.constants import AlgorithmType, StorageType
810
from trinity.common.experience import Experience
911

12+
file_path = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")
13+
1014

1115
class TestQueueBuffer(RayUnittestBase):
1216
def test_queue_buffer(self):
@@ -17,6 +21,7 @@ def test_queue_buffer(self):
1721
name="test_buffer",
1822
algorithm_type=AlgorithmType.PPO,
1923
storage_type=StorageType.QUEUE,
24+
path=file_path,
2025
)
2126
config = BufferConfig(
2227
max_retry_times=3,
@@ -36,9 +41,29 @@ def test_queue_buffer(self):
3641
]
3742
for _ in range(total_num // put_batch_size):
3843
writer.write(exps)
39-
writer.finish()
4044
for _ in range(total_num // read_batch_size):
4145
exps = reader.read()
4246
self.assertEqual(len(exps), read_batch_size)
4347
print(f"finish read {read_batch_size} experience")
48+
writer.write(
49+
[
50+
Experience(
51+
tokens=torch.tensor([float(j) for j in range(i + 1)]),
52+
prompt_length=i,
53+
reward=float(i),
54+
logprobs=torch.tensor([0.1]),
55+
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
56+
)
57+
for i in range(1, put_batch_size * 2 + 1)
58+
]
59+
)
60+
exps = reader.read(batch_size=put_batch_size * 2)
61+
self.assertEqual(len(exps), put_batch_size * 2)
62+
writer.finish()
4463
self.assertRaises(StopIteration, reader.read)
64+
with open(file_path, "r") as f:
65+
self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
66+
67+
def setUp(self):
68+
if os.path.exists(file_path):
69+
os.remove(file_path)

tests/buffer/sql_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,21 @@ def test_create_sql_buffer(self) -> None:
4747
for _ in range(total_num // read_batch_size):
4848
exps = sql_reader.read()
4949
self.assertEqual(len(exps), read_batch_size)
50+
51+
# dynamic read/write
52+
sql_writer.write(
53+
[
54+
Experience(
55+
tokens=torch.tensor([float(j) for j in range(i + 1)]),
56+
prompt_length=i,
57+
reward=float(i),
58+
logprobs=torch.tensor([0.1]),
59+
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
60+
)
61+
for i in range(1, put_batch_size * 2 + 1)
62+
]
63+
)
64+
exps = sql_reader.read(batch_size=put_batch_size * 2)
65+
self.assertEqual(len(exps), put_batch_size * 2)
5066
db_wrapper = ray.get_actor("sql-test_buffer")
5167
self.assertIsNotNone(db_wrapper)

trinity/buffer/buffer_reader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,7 @@ class BufferReader(ABC):
99
"""Interface of the buffer reader."""
1010

1111
@abstractmethod
12-
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
12+
def read(
13+
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
14+
) -> List:
1315
"""Read from buffer."""

trinity/buffer/db_wrapper.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def write(self, data: list) -> None:
6161
experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
6262
session.add_all(experience_models)
6363

64-
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
64+
def read(
65+
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
66+
) -> List:
6567
if strategy is None:
6668
strategy = ReadStrategy.LFU
6769

@@ -78,7 +80,8 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
7880
raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")
7981

8082
exp_list = []
81-
while len(exp_list) < self.batch_size:
83+
batch_size = batch_size or self.batch_size
84+
while len(exp_list) < batch_size:
8285
if len(exp_list):
8386
self.logger.info("waiting for experiences...")
8487
time.sleep(1)
@@ -90,7 +93,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
9093
session.query(self.table_model_cls)
9194
.filter(self.table_model_cls.reward.isnot(None))
9295
.order_by(*sortOrder) # TODO: very slow
93-
.limit(self.batch_size - len(exp_list))
96+
.limit(batch_size - len(exp_list))
9497
.with_for_update()
9598
.all()
9699
)

trinity/buffer/queue.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,20 @@
55

66
import ray
77

8+
from trinity.buffer.writer.file_writer import JSONWriter
89
from trinity.buffer.writer.sql_writer import SQLWriter
910
from trinity.common.config import BufferConfig, StorageConfig
1011
from trinity.common.constants import StorageType
1112

1213

14+
def is_database_url(path: str) -> bool:
15+
return any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"])
16+
17+
18+
def is_json_file(path: str) -> bool:
19+
return path.endswith(".json") or path.endswith(".jsonl")
20+
21+
1322
@ray.remote
1423
class QueueActor:
1524
"""An asyncio.Queue based queue actor."""
@@ -21,12 +30,21 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
2130
self.capacity = getattr(config, "capacity", 10000)
2231
self.queue = asyncio.Queue(self.capacity)
2332
if storage_config.path is not None and len(storage_config.path) > 0:
24-
sql_config = deepcopy(storage_config)
25-
sql_config.storage_type = StorageType.SQL
26-
sql_config.wrap_in_ray = False
27-
self.sql_writer = SQLWriter(sql_config, self.config)
33+
if is_database_url(storage_config.path):
34+
storage_config.storage_type = StorageType.SQL
35+
sql_config = deepcopy(storage_config)
36+
sql_config.storage_type = StorageType.SQL
37+
sql_config.wrap_in_ray = False
38+
self.writer = SQLWriter(sql_config, self.config)
39+
elif is_json_file(storage_config.path):
40+
storage_config.storage_type = StorageType.FILE
41+
json_config = deepcopy(storage_config)
42+
json_config.storage_type = StorageType.FILE
43+
self.writer = JSONWriter(json_config, self.config)
44+
else:
45+
self.writer = None
2846
else:
29-
self.sql_writer = None
47+
self.writer = None
3048

3149
def length(self) -> int:
3250
"""The length of the queue."""
@@ -35,8 +53,8 @@ def length(self) -> int:
3553
async def put_batch(self, exp_list: List) -> None:
3654
"""Put batch of experience."""
3755
await self.queue.put(exp_list)
38-
if self.sql_writer is not None:
39-
self.sql_writer.write(exp_list)
56+
if self.writer is not None:
57+
self.writer.write(exp_list)
4058

4159
async def finish(self) -> None:
4260
"""Stop the queue."""

0 commit comments

Comments
 (0)