
Commit ad77ffe

Wrap file writer in ray (#82)
1 parent 8c6107d commit ad77ffe

11 files changed, +215 -58 lines


tests/buffer/file_test.py

Lines changed: 72 additions & 10 deletions

@@ -1,16 +1,23 @@
+import os
 import unittest
 
-from tests.tools import get_template_config, get_unittest_dataset_config
-from trinity.buffer.buffer import get_buffer_reader
+import ray
 
+from tests.tools import (
+    get_checkpoint_path,
+    get_template_config,
+    get_unittest_dataset_config,
+)
+from trinity.buffer.buffer import get_buffer_reader, get_buffer_writer
+from trinity.buffer.utils import default_storage_path
+from trinity.common.config import StorageConfig
+from trinity.common.constants import StorageType
 
-class TestFileReader(unittest.TestCase):
+
+class TestFileBuffer(unittest.TestCase):
     def test_file_reader(self):
         """Test file reader."""
-        config = get_template_config()
-        dataset_config = get_unittest_dataset_config("countdown", "train")
-        config.buffer.explorer_input.taskset = dataset_config
-        reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
+        reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
 
         tasks = []
         while True:
@@ -20,13 +27,68 @@ def test_file_reader(self):
                 break
         self.assertEqual(len(tasks), 16)
 
-        config.buffer.explorer_input.taskset.total_epochs = 2
-        config.buffer.explorer_input.taskset.index = 4
-        reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
+        # test epoch and offset
+        self.config.buffer.explorer_input.taskset.total_epochs = 2
+        self.config.buffer.explorer_input.taskset.index = 4
+        reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
         tasks = []
         while True:
             try:
                 tasks.extend(reader.read())
             except StopIteration:
                 break
         self.assertEqual(len(tasks), 16 * 2 - 4)
+
+        # test offset > dataset_len
+        self.config.buffer.explorer_input.taskset.total_epochs = 3
+        self.config.buffer.explorer_input.taskset.index = 20
+        reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
+        tasks = []
+        while True:
+            try:
+                tasks.extend(reader.read())
+            except StopIteration:
+                break
+        self.assertEqual(len(tasks), 16 * 3 - 20)
+
+    def test_file_writer(self):
+        writer = get_buffer_writer(
+            self.config.buffer.trainer_input.experience_buffer, self.config.buffer
+        )
+        writer.write(
+            [
+                {"prompt": "hello world"},
+                {"prompt": "hi"},
+            ]
+        )
+        file_wrapper = ray.get_actor("json-test_buffer")
+        self.assertIsNotNone(file_wrapper)
+        file_path = default_storage_path(
+            self.config.buffer.trainer_input.experience_buffer, self.config.buffer
+        )
+        with open(file_path, "r") as f:
+            self.assertEqual(len(f.readlines()), 2)
+
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.checkpoint_root_dir = get_checkpoint_path()
+        dataset_config = get_unittest_dataset_config("countdown", "train")
+        self.config.buffer.explorer_input.taskset = dataset_config
+        self.config.buffer.trainer_input.experience_buffer = StorageConfig(
+            name="test_buffer", storage_type=StorageType.FILE
+        )
+        self.config.buffer.trainer_input.experience_buffer.name = "test_buffer"
+        self.config.buffer.cache_dir = os.path.join(
+            self.config.checkpoint_root_dir, self.config.project, self.config.name, "buffer"
+        )
+        os.makedirs(self.config.buffer.cache_dir, exist_ok=True)
+        if os.path.exists(
+            default_storage_path(
+                self.config.buffer.trainer_input.experience_buffer, self.config.buffer
+            )
+        ):
+            os.remove(
+                default_storage_path(
+                    self.config.buffer.trainer_input.experience_buffer, self.config.buffer
+                )
+            )
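
The writer half of this test stores plain dicts as JSON Lines, one record per line, which is why the assertion simply counts lines in the output file. Reading the records back is then just a matter of parsing that file; a small sketch (the path is illustrative, the test obtains the real one from default_storage_path):

import json

# Illustrative path; in the test it comes from default_storage_path(...).
file_path = "checkpoints/unittest/buffer/test_buffer.jsonl"

with open(file_path, "r", encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

# After test_file_writer, records == [{"prompt": "hello world"}, {"prompt": "hi"}]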

tests/buffer/queue_test.py

Lines changed: 5 additions & 5 deletions

@@ -9,7 +9,7 @@
 from trinity.common.constants import AlgorithmType, StorageType
 from trinity.common.experience import Experience
 
-file_path = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")
+BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")
 
 
 class TestQueueBuffer(RayUnittestBase):
@@ -21,7 +21,7 @@ def test_queue_buffer(self):
             name="test_buffer",
             algorithm_type=AlgorithmType.PPO,
             storage_type=StorageType.QUEUE,
-            path=file_path,
+            path=BUFFER_FILE_PATH,
         )
         config = BufferConfig(
             max_retry_times=3,
@@ -61,9 +61,9 @@ def test_queue_buffer(self):
         self.assertEqual(len(exps), put_batch_size * 2)
         writer.finish()
         self.assertRaises(StopIteration, reader.read)
-        with open(file_path, "r") as f:
+        with open(BUFFER_FILE_PATH, "r") as f:
             self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
 
     def setUp(self):
-        if os.path.exists(file_path):
-            os.remove(file_path)
+        if os.path.exists(BUFFER_FILE_PATH):
+            os.remove(BUFFER_FILE_PATH)

tests/explorer/runner_pool_test.py

Lines changed: 1 addition & 0 deletions

@@ -106,6 +106,7 @@ def setUp(self):
             name="test",
             storage_type=StorageType.QUEUE,
             algorithm_type=AlgorithmType.PPO,
+            path="",
         )
         self.queue = QueueReader(
             self.config.buffer.trainer_input.experience_buffer, self.config.buffer

trinity/buffer/buffer.py

Lines changed: 4 additions & 0 deletions

@@ -61,5 +61,9 @@ def get_buffer_writer(storage_config: StorageConfig, buffer_config: BufferConfig
         from trinity.buffer.writer.queue_writer import QueueWriter
 
         return QueueWriter(storage_config, buffer_config)
+    elif storage_config.storage_type == StorageType.FILE:
+        from trinity.buffer.writer.file_writer import JSONWriter
+
+        return JSONWriter(storage_config, buffer_config)
     else:
         raise ValueError(f"{storage_config.storage_type} not supported.")

trinity/buffer/queue.py

Lines changed: 15 additions & 13 deletions

@@ -9,6 +9,7 @@
 from trinity.buffer.writer.sql_writer import SQLWriter
 from trinity.common.config import BufferConfig, StorageConfig
 from trinity.common.constants import StorageType
+from trinity.utils.log import get_logger
 
 
 def is_database_url(path: str) -> bool:
@@ -26,25 +27,26 @@ class QueueActor:
     FINISH_MESSAGE = "$FINISH$"
 
     def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
+        self.logger = get_logger(__name__)
         self.config = config
         self.capacity = getattr(config, "capacity", 10000)
         self.queue = asyncio.Queue(self.capacity)
-        if storage_config.path is not None and len(storage_config.path) > 0:
-            if is_database_url(storage_config.path):
-                storage_config.storage_type = StorageType.SQL
-                sql_config = deepcopy(storage_config)
-                sql_config.storage_type = StorageType.SQL
-                sql_config.wrap_in_ray = False
-                self.writer = SQLWriter(sql_config, self.config)
-            elif is_json_file(storage_config.path):
-                storage_config.storage_type = StorageType.FILE
-                json_config = deepcopy(storage_config)
-                json_config.storage_type = StorageType.FILE
-                self.writer = JSONWriter(json_config, self.config)
+        st_config = deepcopy(storage_config)
+        st_config.wrap_in_ray = False
+        if st_config.path is not None:
+            if is_database_url(st_config.path):
+                st_config.storage_type = StorageType.SQL
+                self.writer = SQLWriter(st_config, self.config)
+            elif is_json_file(st_config.path):
+                st_config.storage_type = StorageType.FILE
+                self.writer = JSONWriter(st_config, self.config)
             else:
+                self.logger.warning("Unknown supported storage path: %s", st_config.path)
                 self.writer = None
         else:
-            self.writer = None
+            st_config.storage_type = StorageType.FILE
+            self.writer = JSONWriter(st_config, self.config)
+            self.logger.warning(f"Save experiences in {st_config.path}.")
 
     def length(self) -> int:
         """The length of the queue."""

trinity/buffer/ray_wrapper.py

Lines changed: 64 additions & 1 deletion

@@ -1,3 +1,5 @@
+import json
+import os
 import time
 from typing import List, Optional
 
@@ -8,9 +10,11 @@
 from sqlalchemy.pool import NullPool
 
 from trinity.buffer.schema import Base, create_dynamic_table
-from trinity.buffer.utils import retry_session
+from trinity.buffer.utils import default_storage_path, retry_session
 from trinity.common.config import BufferConfig, StorageConfig
 from trinity.common.constants import ReadStrategy
+from trinity.common.experience import Experience
+from trinity.common.workflows import Task
 from trinity.utils.log import get_logger
 
 
@@ -27,6 +31,8 @@ class DBWrapper:
 
     def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
         self.logger = get_logger(__name__)
+        if storage_config.path is None:
+            storage_config.path = default_storage_path(storage_config, config)
         self.engine = create_engine(storage_config.path, poolclass=NullPool)
         self.table_model_cls = create_dynamic_table(
             storage_config.algorithm_type, storage_config.name
@@ -106,3 +112,60 @@ def read(
         self.logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
         self.logger.info(f"first response_text = {exp_list[0].response_text}")
         return exp_list
+
+
+class _Encoder(json.JSONEncoder):
+    def default(self, o):
+        if isinstance(o, Experience):
+            return o.to_dict()
+        if isinstance(o, Task):
+            return o.to_dict()
+        return super().default(o)
+
+
+class FileWrapper:
+    """
+    A wrapper of a local jsonl file.
+
+    If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as
+    a Ray Actor, and provide a remote interface to the local file.
+
+    This wrapper is only for writing, if you want to read from the file, use
+    StorageType.QUEUE instead.
+    """
+
+    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
+        if storage_config.path is None:
+            storage_config.path = default_storage_path(storage_config, config)
+        ext = os.path.splitext(storage_config.path)[-1]
+        if ext != ".jsonl" and ext != ".json":
+            raise ValueError(
+                f"File path must end with '.json' or '.jsonl', got {storage_config.path}"
+            )
+        self.file = open(storage_config.path, "a", encoding="utf-8")
+        self.encoder = _Encoder(ensure_ascii=False)
+
+    @classmethod
+    def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
+        if storage_config.wrap_in_ray:
+            return (
+                ray.remote(cls)
+                .options(
+                    name=f"json-{storage_config.name}",
+                    get_if_exists=True,
+                )
+                .remote(storage_config, config)
+            )
+        else:
+            return cls(storage_config, config)
+
+    def write(self, data: List) -> None:
+        for item in data:
+            json_str = self.encoder.encode(item)
+            self.file.write(json_str + "\n")
+        self.file.flush()
+
+    def read(self) -> List:
+        raise NotImplementedError(
+            "read() is not implemented for FileWrapper, please use QUEUE instead"
+        )

trinity/buffer/reader/sql_reader.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 import ray
 
 from trinity.buffer.buffer_reader import BufferReader
-from trinity.buffer.db_wrapper import DBWrapper
+from trinity.buffer.ray_wrapper import DBWrapper
 from trinity.common.config import BufferConfig, StorageConfig
 from trinity.common.constants import ReadStrategy, StorageType

trinity/buffer/utils.py

Lines changed: 18 additions & 0 deletions

@@ -1,6 +1,9 @@
+import os
 import time
 from contextlib import contextmanager
 
+from trinity.common.config import BufferConfig, StorageConfig
+from trinity.common.constants import StorageType
 from trinity.utils.log import get_logger
 
 logger = get_logger(__name__)
@@ -31,3 +34,18 @@ def retry_session(session_maker, max_retry_times: int, max_retry_interval: float
             raise e
         finally:
             session.close()
+
+
+def default_storage_path(storage_config: StorageConfig, buffer_config: BufferConfig) -> str:
+    if buffer_config.cache_dir is None:
+        raise ValueError("Please call config.check_and_update() before using.")
+    if storage_config.storage_type == StorageType.SQL:
+        return "sqlite:///" + os.path.join(
+            buffer_config.cache_dir,
+            f"{storage_config.name}.db",
+        )
+    else:
+        return os.path.join(
+            buffer_config.cache_dir,
+            f"{storage_config.name}.jsonl",
+        )
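
default_storage_path therefore derives a per-buffer location under buffer_config.cache_dir: an SQLite URL for SQL storage and a plain .jsonl path for everything else. A quick self-contained check of the two shapes it produces, using the same string construction as above (values are illustrative):

import os

cache_dir = "checkpoints/demo/run1/buffer"
name = "test_buffer"

sql_path = "sqlite:///" + os.path.join(cache_dir, f"{name}.db")
file_path = os.path.join(cache_dir, f"{name}.jsonl")

print(sql_path)   # sqlite:///checkpoints/demo/run1/buffer/test_buffer.db
print(file_path)  # checkpoints/demo/run1/buffer/test_buffer.jsonl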

trinity/buffer/writer/file_writer.py

Lines changed: 13 additions & 25 deletions

@@ -1,39 +1,27 @@
-import json
-import os
 from typing import List
 
+import ray
+
 from trinity.buffer.buffer_writer import BufferWriter
+from trinity.buffer.ray_wrapper import FileWrapper
 from trinity.common.config import BufferConfig, StorageConfig
 from trinity.common.constants import StorageType
-from trinity.common.experience import Experience
-from trinity.common.workflows import Task
-
-
-class _Encoder(json.JSONEncoder):
-    def default(self, o):
-        if isinstance(o, Experience):
-            return o.to_dict()
-        if isinstance(o, Task):
-            return o.to_dict()
-        return super().default(o)
 
 
 class JSONWriter(BufferWriter):
     def __init__(self, meta: StorageConfig, config: BufferConfig):
         assert meta.storage_type == StorageType.FILE
-        if meta.path is None:
-            raise ValueError("File path cannot be None for RawFileWriter")
-        ext = os.path.splitext(meta.path)[-1]
-        if ext != ".jsonl" and ext != ".json":
-            raise ValueError(f"File path must end with .json or .jsonl, got {meta.path}")
-        self.file = open(meta.path, "a", encoding="utf-8")
-        self.encoder = _Encoder(ensure_ascii=False)
+        self.writer = FileWrapper.get_wrapper(meta, config)
+        self.wrap_in_ray = meta.wrap_in_ray
 
     def write(self, data: List) -> None:
-        for item in data:
-            json_str = self.encoder.encode(item)
-            self.file.write(json_str + "\n")
-        self.file.flush()
+        if self.wrap_in_ray:
+            ray.get(self.writer.write.remote(data))
+        else:
+            self.writer.write(data)
 
     def finish(self):
-        self.file.close()
+        if self.wrap_in_ray:
+            ray.get(self.writer.finish.remote())
+        else:
+            self.writer.finish()
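
JSONWriter is now a thin facade over FileWrapper: depending on meta.wrap_in_ray it forwards each call either through ray.get(...remote(...)) or directly to an in-process instance. The in-process branch can be exercised without a Ray cluster; a sketch (attribute names come from the diff above and the setup borrows the tests' template config, so treat it as illustrative):

from tests.tools import get_template_config
from trinity.buffer.writer.file_writer import JSONWriter  # module path taken from buffer.py's import above
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType

storage = StorageConfig(name="local_buffer", storage_type=StorageType.FILE)
storage.wrap_in_ray = False  # keep the FileWrapper in-process; no actor is registered
storage.path = "/tmp/local_buffer.jsonl"  # explicit path, so no cache_dir lookup is needed

writer = JSONWriter(storage, get_template_config().buffer)
writer.write([{"prompt": "hi"}])  # appends one JSON line to /tmp/local_buffer.jsonl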

trinity/buffer/writer/sql_writer.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 import ray
 
 from trinity.buffer.buffer_writer import BufferWriter
-from trinity.buffer.db_wrapper import DBWrapper
+from trinity.buffer.ray_wrapper import DBWrapper
 from trinity.common.config import BufferConfig, StorageConfig
 from trinity.common.constants import StorageType
 