33
44import torch
55
6- from tests .tools import RayUnittestBase
6+ from tests .tools import RayUnittestBaseAysnc
77from trinity .buffer .reader .queue_reader import QueueReader
88from trinity .buffer .writer .queue_writer import QueueWriter
99from trinity .common .config import BufferConfig , StorageConfig
1313BUFFER_FILE_PATH = os .path .join (os .path .dirname (__file__ ), "test_queue_buffer.jsonl" )
1414
1515
16- class TestQueueBuffer (RayUnittestBase ):
17- def test_queue_buffer (self ):
16+ class TestQueueBuffer (RayUnittestBaseAysnc ):
17+ async def test_queue_buffer (self ):
1818 total_num = 8
1919 put_batch_size = 2
2020 read_batch_size = 4
@@ -32,7 +32,7 @@ def test_queue_buffer(self):
3232 )
3333 writer = QueueWriter (meta , config )
3434 reader = QueueReader (meta , config )
35- self .assertEqual (writer .acquire (), 1 )
35+ self .assertEqual (await writer .acquire (), 1 )
3636 exps = [
3737 Experience (
3838 tokens = torch .tensor ([float (j ) for j in range (i + 1 )]),
@@ -43,7 +43,7 @@ def test_queue_buffer(self):
4343 for i in range (1 , put_batch_size + 1 )
4444 ]
4545 for _ in range (total_num // put_batch_size ):
46- writer .write (exps )
46+ await writer .write_async (exps )
4747 for _ in range (total_num // read_batch_size ):
4848 exps = reader .read ()
4949 self .assertEqual (len (exps ), read_batch_size )
@@ -62,7 +62,7 @@ def test_queue_buffer(self):
6262 )
6363 exps = reader .read (batch_size = put_batch_size * 2 )
6464 self .assertEqual (len (exps ), put_batch_size * 2 )
65- self .assertEqual (writer .release (), 0 )
65+ self .assertEqual (await writer .release (), 0 )
6666 self .assertRaises (StopIteration , reader .read )
6767 with open (BUFFER_FILE_PATH , "r" ) as f :
6868 self .assertEqual (len (f .readlines ()), total_num + put_batch_size * 2 )
0 commit comments