Skip to content

Commit 553b861

Browse files
committed
buffer writer support async
1 parent b63ee11 commit 553b861

File tree

10 files changed

+71
-34
lines changed

10 files changed

+71
-34
lines changed

tests/buffer/file_test.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from trinity.common.constants import StorageType
1717

1818

19-
class TestFileBuffer(unittest.TestCase):
19+
class TestFileBuffer(unittest.IsolatedAsyncioTestCase):
2020
temp_output_path = "tmp/test_file_buffer/"
2121

2222
@classmethod
@@ -93,25 +93,31 @@ def test_file_reader(self):
9393
break
9494
self.assertEqual(len(tasks), 16 * 3 - 20)
9595

96-
def test_file_writer(self):
96+
async def test_file_writer(self):
9797
writer = get_buffer_writer(
9898
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
9999
)
100-
writer.acquire()
100+
await writer.acquire()
101101
writer.write(
102102
[
103103
{"prompt": "hello world"},
104104
{"prompt": "hi"},
105105
]
106106
)
107-
writer.release()
107+
await writer.write_async(
108+
[
109+
{"prompt": "My name is"},
110+
{"prompt": "What is your name?"},
111+
]
112+
)
113+
await writer.release()
108114
file_wrapper = ray.get_actor("json-test_buffer")
109115
self.assertIsNotNone(file_wrapper)
110116
file_path = default_storage_path(
111117
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
112118
)
113119
with open(file_path, "r") as f:
114-
self.assertEqual(len(f.readlines()), 2)
120+
self.assertEqual(len(f.readlines()), 4)
115121

116122
def setUp(self):
117123
self.config = get_template_config()

tests/buffer/queue_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import torch
55

6-
from tests.tools import RayUnittestBase
6+
from tests.tools import RayUnittestBaseAysnc
77
from trinity.buffer.reader.queue_reader import QueueReader
88
from trinity.buffer.writer.queue_writer import QueueWriter
99
from trinity.common.config import BufferConfig, StorageConfig
@@ -13,8 +13,8 @@
1313
BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")
1414

1515

16-
class TestQueueBuffer(RayUnittestBase):
17-
def test_queue_buffer(self):
16+
class TestQueueBuffer(RayUnittestBaseAysnc):
17+
async def test_queue_buffer(self):
1818
total_num = 8
1919
put_batch_size = 2
2020
read_batch_size = 4
@@ -32,7 +32,7 @@ def test_queue_buffer(self):
3232
)
3333
writer = QueueWriter(meta, config)
3434
reader = QueueReader(meta, config)
35-
self.assertEqual(writer.acquire(), 1)
35+
self.assertEqual(await writer.acquire(), 1)
3636
exps = [
3737
Experience(
3838
tokens=torch.tensor([float(j) for j in range(i + 1)]),
@@ -43,7 +43,7 @@ def test_queue_buffer(self):
4343
for i in range(1, put_batch_size + 1)
4444
]
4545
for _ in range(total_num // put_batch_size):
46-
writer.write(exps)
46+
await writer.write_async(exps)
4747
for _ in range(total_num // read_batch_size):
4848
exps = reader.read()
4949
self.assertEqual(len(exps), read_batch_size)
@@ -62,7 +62,7 @@ def test_queue_buffer(self):
6262
)
6363
exps = reader.read(batch_size=put_batch_size * 2)
6464
self.assertEqual(len(exps), put_batch_size * 2)
65-
self.assertEqual(writer.release(), 0)
65+
self.assertEqual(await writer.release(), 0)
6666
self.assertRaises(StopIteration, reader.read)
6767
with open(BUFFER_FILE_PATH, "r") as f:
6868
self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)

tests/buffer/sql_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
2-
import unittest
32

43
import ray
54
import torch
65

6+
from tests.tools import RayUnittestBaseAysnc
77
from trinity.buffer.reader.sql_reader import SQLReader
88
from trinity.buffer.writer.sql_writer import SQLWriter
99
from trinity.common.config import BufferConfig, StorageConfig
@@ -13,8 +13,8 @@
1313
db_path = os.path.join(os.path.dirname(__file__), "test.db")
1414

1515

16-
class TestSQLBuffer(unittest.TestCase):
17-
def test_create_sql_buffer(self) -> None:
16+
class TestSQLBuffer(RayUnittestBaseAysnc):
17+
async def test_create_sql_buffer(self) -> None:
1818
total_num = 8
1919
put_batch_size = 2
2020
read_batch_size = 4
@@ -42,9 +42,9 @@ def test_create_sql_buffer(self) -> None:
4242
)
4343
for i in range(1, put_batch_size + 1)
4444
]
45-
self.assertEqual(sql_writer.acquire(), 1)
45+
self.assertEqual(await sql_writer.acquire(), 1)
4646
for _ in range(total_num // put_batch_size):
47-
sql_writer.write(exps)
47+
await sql_writer.write_async(exps)
4848
for _ in range(total_num // read_batch_size):
4949
exps = sql_reader.read()
5050
self.assertEqual(len(exps), read_batch_size)
@@ -66,5 +66,5 @@ def test_create_sql_buffer(self) -> None:
6666
self.assertEqual(len(exps), put_batch_size * 2)
6767
db_wrapper = ray.get_actor("sql-test_buffer")
6868
self.assertIsNotNone(db_wrapper)
69-
self.assertEqual(sql_writer.release(), 0)
69+
self.assertEqual(await sql_writer.release(), 0)
7070
self.assertRaises(StopIteration, sql_reader.read)

tests/tools.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,13 @@ def setUpClass(cls):
182182
@classmethod
183183
def tearDownClass(cls):
184184
ray.shutdown(_exiting_interpreter=True)
185+
186+
187+
class RayUnittestBaseAysnc(unittest.IsolatedAsyncioTestCase):
188+
@classmethod
189+
def setUpClass(cls):
190+
ray.init(ignore_reinit_error=True, namespace="trinity_unittest")
191+
192+
@classmethod
193+
def tearDownClass(cls):
194+
ray.shutdown(_exiting_interpreter=True)

trinity/buffer/buffer_writer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,19 @@ def write(self, data: List) -> None:
1111
"""Write to buffer."""
1212

1313
@abstractmethod
14-
def acquire(self) -> int:
14+
async def write_async(self, data: List) -> None:
15+
"""Write to buffer asynchronously."""
16+
17+
@abstractmethod
18+
async def acquire(self) -> int:
1519
"""Acquire the buffer writer.
1620
1721
Returns:
1822
`int`: The reference count of the buffer after acquiring.
1923
"""
2024

2125
@abstractmethod
22-
def release(self) -> int:
26+
async def release(self) -> int:
2327
"""Release the buffer writer. After release, the buffer writer can not be used again.
2428
2529
Returns:

trinity/buffer/queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ async def release(self) -> int:
5757
self.ref_count -= 1
5858
if self.ref_count <= 0:
5959
await self.queue.put(self.FINISH_MESSAGE)
60-
self.writer.release()
60+
await self.writer.release()
6161
return self.ref_count
6262

6363
def length(self) -> int:

trinity/buffer/writer/file_writer.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,21 @@ def write(self, data: List) -> None:
2020
else:
2121
self.writer.write(data)
2222

23-
def acquire(self) -> int:
23+
async def write_async(self, data):
2424
if self.wrap_in_ray:
25-
return ray.get(self.writer.acquire.remote())
25+
await self.writer.write.remote(data)
26+
else:
27+
self.writer.write(data)
28+
29+
async def acquire(self) -> int:
30+
if self.wrap_in_ray:
31+
return await self.writer.acquire.remote()
2632
else:
2733
return 0
2834

29-
def release(self) -> int:
35+
async def release(self) -> int:
3036
if self.wrap_in_ray:
31-
return ray.get(self.writer.release.remote())
37+
return await self.writer.release.remote()
3238
else:
3339
self.writer.release()
3440
return 0

trinity/buffer/writer/queue_writer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
2323
def write(self, data: List) -> None:
2424
ray.get(self.queue.put_batch.remote(data))
2525

26-
def acquire(self) -> int:
27-
return ray.get(self.queue.acquire.remote())
26+
async def write_async(self, data):
27+
return await self.queue.put_batch.remote(data)
2828

29-
def release(self) -> int:
30-
return ray.get(self.queue.release.remote())
29+
async def acquire(self) -> int:
30+
return await self.queue.acquire.remote()
31+
32+
async def release(self) -> int:
33+
return await self.queue.release.remote()

trinity/buffer/writer/sql_writer.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,21 @@ def write(self, data: list) -> None:
2323
else:
2424
self.db_wrapper.write(data)
2525

26-
def acquire(self) -> int:
26+
async def write_async(self, data):
2727
if self.wrap_in_ray:
28-
return ray.get(self.db_wrapper.acquire.remote())
28+
await self.db_wrapper.write.remote(data)
29+
else:
30+
self.db_wrapper.write(data)
31+
32+
async def acquire(self) -> int:
33+
if self.wrap_in_ray:
34+
return await self.db_wrapper.acquire.remote()
2935
else:
3036
return 0
3137

32-
def release(self) -> int:
38+
async def release(self) -> int:
3339
if self.wrap_in_ray:
34-
return ray.get(self.db_wrapper.release.remote())
40+
return await self.db_wrapper.release.remote()
3541
else:
3642
self.db_wrapper.release()
3743
return 0

trinity/explorer/explorer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ def __init__(self, config: Config):
4343
self.config = config
4444
self.algorithm_manager = AlgorithmManager(config)
4545
self.models, self.auxiliary_models = create_inference_models(config)
46+
self.experience_buffer = None
4647
if self.config.mode != "bench":
4748
self.experience_buffer = get_buffer_writer(
4849
self.config.buffer.explorer_output, # type: ignore
4950
self.config.buffer,
5051
)
51-
self.experience_buffer.acquire()
5252
self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0)
5353
self.taskset = get_buffer_reader(
5454
self.config.buffer.explorer_input.taskset, self.config.buffer
@@ -169,6 +169,8 @@ async def prepare(self) -> None:
169169
asyncio.create_task(self.setup_weight_sync_group(master_address, master_port))
170170
)
171171
asyncio.gather(*futures, return_exceptions=True)
172+
if self.experience_buffer:
173+
await self.experience_buffer.acquire()
172174
if self.config.explorer.eval_on_startup and self.explore_step_num == 0:
173175
self.eval()
174176

@@ -217,7 +219,7 @@ async def explore_step(self) -> bool:
217219
self.logger.warning("No more tasks to explore. Stop exploring.")
218220
await self.save_checkpoint(sync_weight=False)
219221
self.status = RunningStatus.STOPPED
220-
self.experience_buffer.release()
222+
await self.experience_buffer.release()
221223
return False
222224
self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1)
223225
self.explore_step_num += 1

0 commit comments

Comments
 (0)