Skip to content

Commit ab28d0c

Browse files
committed
merge main
2 parents a592af7 + ad77ffe commit ab28d0c

File tree

18 files changed

+460
-106
lines changed

18 files changed

+460
-106
lines changed

tests/buffer/file_test.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import os
2+
import unittest
3+
4+
import ray
5+
6+
from tests.tools import (
7+
get_checkpoint_path,
8+
get_template_config,
9+
get_unittest_dataset_config,
10+
)
11+
from trinity.buffer.buffer import get_buffer_reader, get_buffer_writer
12+
from trinity.buffer.utils import default_storage_path
13+
from trinity.common.config import StorageConfig
14+
from trinity.common.constants import StorageType
15+
16+
17+
class TestFileBuffer(unittest.TestCase):
18+
def test_file_reader(self):
19+
"""Test file reader."""
20+
reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
21+
22+
tasks = []
23+
while True:
24+
try:
25+
tasks.extend(reader.read())
26+
except StopIteration:
27+
break
28+
self.assertEqual(len(tasks), 16)
29+
30+
# test epoch and offset
31+
self.config.buffer.explorer_input.taskset.total_epochs = 2
32+
self.config.buffer.explorer_input.taskset.index = 4
33+
reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
34+
tasks = []
35+
while True:
36+
try:
37+
tasks.extend(reader.read())
38+
except StopIteration:
39+
break
40+
self.assertEqual(len(tasks), 16 * 2 - 4)
41+
42+
# test offset > dataset_len
43+
self.config.buffer.explorer_input.taskset.total_epochs = 3
44+
self.config.buffer.explorer_input.taskset.index = 20
45+
reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
46+
tasks = []
47+
while True:
48+
try:
49+
tasks.extend(reader.read())
50+
except StopIteration:
51+
break
52+
self.assertEqual(len(tasks), 16 * 3 - 20)
53+
54+
def test_file_writer(self):
55+
writer = get_buffer_writer(
56+
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
57+
)
58+
writer.write(
59+
[
60+
{"prompt": "hello world"},
61+
{"prompt": "hi"},
62+
]
63+
)
64+
file_wrapper = ray.get_actor("json-test_buffer")
65+
self.assertIsNotNone(file_wrapper)
66+
file_path = default_storage_path(
67+
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
68+
)
69+
with open(file_path, "r") as f:
70+
self.assertEqual(len(f.readlines()), 2)
71+
72+
def setUp(self):
73+
self.config = get_template_config()
74+
self.config.checkpoint_root_dir = get_checkpoint_path()
75+
dataset_config = get_unittest_dataset_config("countdown", "train")
76+
self.config.buffer.explorer_input.taskset = dataset_config
77+
self.config.buffer.trainer_input.experience_buffer = StorageConfig(
78+
name="test_buffer", storage_type=StorageType.FILE
79+
)
80+
self.config.buffer.trainer_input.experience_buffer.name = "test_buffer"
81+
self.config.buffer.cache_dir = os.path.join(
82+
self.config.checkpoint_root_dir, self.config.project, self.config.name, "buffer"
83+
)
84+
os.makedirs(self.config.buffer.cache_dir, exist_ok=True)
85+
if os.path.exists(
86+
default_storage_path(
87+
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
88+
)
89+
):
90+
os.remove(
91+
default_storage_path(
92+
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
93+
)
94+
)

tests/buffer/queue_test.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import torch
24

35
from tests.tools import RayUnittestBase
@@ -7,6 +9,8 @@
79
from trinity.common.constants import StorageType
810
from trinity.common.experience import Experience
911

12+
BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")
13+
1014

1115
class TestQueueBuffer(RayUnittestBase):
1216
def test_queue_buffer(self):
@@ -17,6 +21,7 @@ def test_queue_buffer(self):
1721
name="test_buffer",
1822
algorithm_type="ppo",
1923
storage_type=StorageType.QUEUE,
24+
path=BUFFER_FILE_PATH,
2025
)
2126
config = BufferConfig(
2227
max_retry_times=3,
@@ -36,9 +41,29 @@ def test_queue_buffer(self):
3641
]
3742
for _ in range(total_num // put_batch_size):
3843
writer.write(exps)
39-
writer.finish()
4044
for _ in range(total_num // read_batch_size):
4145
exps = reader.read()
4246
self.assertEqual(len(exps), read_batch_size)
4347
print(f"finish read {read_batch_size} experience")
48+
writer.write(
49+
[
50+
Experience(
51+
tokens=torch.tensor([float(j) for j in range(i + 1)]),
52+
prompt_length=i,
53+
reward=float(i),
54+
logprobs=torch.tensor([0.1]),
55+
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
56+
)
57+
for i in range(1, put_batch_size * 2 + 1)
58+
]
59+
)
60+
exps = reader.read(batch_size=put_batch_size * 2)
61+
self.assertEqual(len(exps), put_batch_size * 2)
62+
writer.finish()
4463
self.assertRaises(StopIteration, reader.read)
64+
with open(BUFFER_FILE_PATH, "r") as f:
65+
self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
66+
67+
def setUp(self):
68+
if os.path.exists(BUFFER_FILE_PATH):
69+
os.remove(BUFFER_FILE_PATH)

tests/buffer/sql_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,21 @@ def test_create_sql_buffer(self) -> None:
4747
for _ in range(total_num // read_batch_size):
4848
exps = sql_reader.read()
4949
self.assertEqual(len(exps), read_batch_size)
50+
51+
# dynamic read/write
52+
sql_writer.write(
53+
[
54+
Experience(
55+
tokens=torch.tensor([float(j) for j in range(i + 1)]),
56+
prompt_length=i,
57+
reward=float(i),
58+
logprobs=torch.tensor([0.1]),
59+
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
60+
)
61+
for i in range(1, put_batch_size * 2 + 1)
62+
]
63+
)
64+
exps = sql_reader.read(batch_size=put_batch_size * 2)
65+
self.assertEqual(len(exps), put_batch_size * 2)
5066
db_wrapper = ray.get_actor("sql-test_buffer")
5167
self.assertIsNotNone(db_wrapper)

tests/explorer/runner_pool_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def setUp(self):
106106
name="test",
107107
storage_type=StorageType.QUEUE,
108108
algorithm_type="ppo",
109+
path="",
109110
)
110111
self.queue = QueueReader(
111112
self.config.buffer.trainer_input.experience_buffer, self.config.buffer

trinity/buffer/buffer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,5 +61,9 @@ def get_buffer_writer(storage_config: StorageConfig, buffer_config: BufferConfig
6161
from trinity.buffer.writer.queue_writer import QueueWriter
6262

6363
return QueueWriter(storage_config, buffer_config)
64+
elif storage_config.storage_type == StorageType.FILE:
65+
from trinity.buffer.writer.file_writer import JSONWriter
66+
67+
return JSONWriter(storage_config, buffer_config)
6468
else:
6569
raise ValueError(f"{storage_config.storage_type} not supported.")

trinity/buffer/buffer_reader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,7 @@ class BufferReader(ABC):
99
"""Interface of the buffer reader."""
1010

1111
@abstractmethod
12-
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
12+
def read(
13+
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
14+
) -> List:
1315
"""Read from buffer."""

trinity/buffer/queue.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,19 @@
55

66
import ray
77

8+
from trinity.buffer.writer.file_writer import JSONWriter
89
from trinity.buffer.writer.sql_writer import SQLWriter
910
from trinity.common.config import BufferConfig, StorageConfig
1011
from trinity.common.constants import StorageType
12+
from trinity.utils.log import get_logger
13+
14+
15+
def is_database_url(path: str) -> bool:
16+
return any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"])
17+
18+
19+
def is_json_file(path: str) -> bool:
20+
return path.endswith(".json") or path.endswith(".jsonl")
1121

1222

1323
@ray.remote
@@ -17,16 +27,26 @@ class QueueActor:
1727
FINISH_MESSAGE = "$FINISH$"
1828

1929
def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
30+
self.logger = get_logger(__name__)
2031
self.config = config
2132
self.capacity = getattr(config, "capacity", 10000)
2233
self.queue = asyncio.Queue(self.capacity)
23-
if storage_config.path is not None and len(storage_config.path) > 0:
24-
sql_config = deepcopy(storage_config)
25-
sql_config.storage_type = StorageType.SQL
26-
sql_config.wrap_in_ray = False
27-
self.sql_writer = SQLWriter(sql_config, self.config)
34+
st_config = deepcopy(storage_config)
35+
st_config.wrap_in_ray = False
36+
if st_config.path is not None:
37+
if is_database_url(st_config.path):
38+
st_config.storage_type = StorageType.SQL
39+
self.writer = SQLWriter(st_config, self.config)
40+
elif is_json_file(st_config.path):
41+
st_config.storage_type = StorageType.FILE
42+
self.writer = JSONWriter(st_config, self.config)
43+
else:
44+
self.logger.warning("Unsupported storage path: %s", st_config.path)
45+
self.writer = None
2846
else:
29-
self.sql_writer = None
47+
st_config.storage_type = StorageType.FILE
48+
self.writer = JSONWriter(st_config, self.config)
49+
self.logger.warning(f"Saving experiences to {st_config.path}.")
3050

3151
def length(self) -> int:
3252
"""The length of the queue."""
@@ -35,8 +55,8 @@ def length(self) -> int:
3555
async def put_batch(self, exp_list: List) -> None:
3656
"""Put batch of experience."""
3757
await self.queue.put(exp_list)
38-
if self.sql_writer is not None:
39-
self.sql_writer.write(exp_list)
58+
if self.writer is not None:
59+
self.writer.write(exp_list)
4060

4161
async def finish(self) -> None:
4262
"""Stop the queue."""
Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
import os
13
import time
24
from typing import List, Optional
35

@@ -8,9 +10,11 @@
810
from sqlalchemy.pool import NullPool
911

1012
from trinity.buffer.schema import Base, create_dynamic_table
11-
from trinity.buffer.utils import retry_session
13+
from trinity.buffer.utils import default_storage_path, retry_session
1214
from trinity.common.config import BufferConfig, StorageConfig
1315
from trinity.common.constants import ReadStrategy
16+
from trinity.common.experience import Experience
17+
from trinity.common.workflows import Task
1418
from trinity.utils.log import get_logger
1519

1620

@@ -27,6 +31,8 @@ class DBWrapper:
2731

2832
def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
2933
self.logger = get_logger(__name__)
34+
if storage_config.path is None:
35+
storage_config.path = default_storage_path(storage_config, config)
3036
self.engine = create_engine(storage_config.path, poolclass=NullPool)
3137
self.table_model_cls = create_dynamic_table(
3238
storage_config.algorithm_type, storage_config.name
@@ -61,7 +67,9 @@ def write(self, data: list) -> None:
6167
experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
6268
session.add_all(experience_models)
6369

64-
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
70+
def read(
71+
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
72+
) -> List:
6573
if strategy is None:
6674
strategy = ReadStrategy.LFU
6775

@@ -78,7 +86,8 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
7886
raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")
7987

8088
exp_list = []
81-
while len(exp_list) < self.batch_size:
89+
batch_size = batch_size or self.batch_size
90+
while len(exp_list) < batch_size:
8291
if len(exp_list):
8392
self.logger.info("waiting for experiences...")
8493
time.sleep(1)
@@ -90,7 +99,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
9099
session.query(self.table_model_cls)
91100
.filter(self.table_model_cls.reward.isnot(None))
92101
.order_by(*sortOrder) # TODO: very slow
93-
.limit(self.batch_size - len(exp_list))
102+
.limit(batch_size - len(exp_list))
94103
.with_for_update()
95104
.all()
96105
)
@@ -103,3 +112,63 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
103112
self.logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
104113
self.logger.info(f"first response_text = {exp_list[0].response_text}")
105114
return exp_list
115+
116+
117+
class _Encoder(json.JSONEncoder):
118+
def default(self, o):
119+
if isinstance(o, Experience):
120+
return o.to_dict()
121+
if isinstance(o, Task):
122+
return o.to_dict()
123+
return super().default(o)
124+
125+
126+
class FileWrapper:
127+
"""
128+
A wrapper of a local jsonl file.
129+
130+
If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as
131+
a Ray Actor, and provide a remote interface to the local file.
132+
133+
This wrapper is only for writing; if you want to read from the file, use
134+
StorageType.QUEUE instead.
135+
"""
136+
137+
def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
138+
if storage_config.path is None:
139+
storage_config.path = default_storage_path(storage_config, config)
140+
ext = os.path.splitext(storage_config.path)[-1]
141+
if ext != ".jsonl" and ext != ".json":
142+
raise ValueError(
143+
f"File path must end with '.json' or '.jsonl', got {storage_config.path}"
144+
)
145+
self.file = open(storage_config.path, "a", encoding="utf-8")
146+
self.encoder = _Encoder(ensure_ascii=False)
147+
148+
@classmethod
149+
def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
150+
if storage_config.wrap_in_ray:
151+
return (
152+
ray.remote(cls)
153+
.options(
154+
name=f"json-{storage_config.name}",
155+
get_if_exists=True,
156+
)
157+
.remote(storage_config, config)
158+
)
159+
else:
160+
return cls(storage_config, config)
161+
162+
def write(self, data: List) -> None:
163+
for item in data:
164+
json_str = self.encoder.encode(item)
165+
self.file.write(json_str + "\n")
166+
self.file.flush()
167+
168+
def read(self) -> List:
169+
raise NotImplementedError(
170+
"read() is not implemented for FileWrapper, please use QUEUE instead"
171+
)
172+
173+
def finish(self) -> None:
174+
self.file.close()

0 commit comments

Comments
 (0)