1+ import os
12import unittest
23
3- from tests .tools import get_template_config , get_unittest_dataset_config
4- from trinity .buffer .buffer import get_buffer_reader
4+ import ray
55
6+ from tests .tools import (
7+ get_checkpoint_path ,
8+ get_template_config ,
9+ get_unittest_dataset_config ,
10+ )
11+ from trinity .buffer .buffer import get_buffer_reader , get_buffer_writer
12+ from trinity .buffer .utils import default_storage_path
13+ from trinity .common .config import StorageConfig
14+ from trinity .common .constants import StorageType
615
7- class TestFileReader (unittest .TestCase ):
16+
17+ class TestFileBuffer (unittest .TestCase ):
818 def test_file_reader (self ):
919 """Test file reader."""
10- config = get_template_config ()
11- dataset_config = get_unittest_dataset_config ("countdown" , "train" )
12- config .buffer .explorer_input .taskset = dataset_config
13- reader = get_buffer_reader (config .buffer .explorer_input .taskset , config .buffer )
20+ reader = get_buffer_reader (self .config .buffer .explorer_input .taskset , self .config .buffer )
1421
1522 tasks = []
1623 while True :
@@ -20,13 +27,68 @@ def test_file_reader(self):
2027 break
2128 self .assertEqual (len (tasks ), 16 )
2229
23- config .buffer .explorer_input .taskset .total_epochs = 2
24- config .buffer .explorer_input .taskset .index = 4
25- reader = get_buffer_reader (config .buffer .explorer_input .taskset , config .buffer )
30+ # test epoch and offset
31+ self .config .buffer .explorer_input .taskset .total_epochs = 2
32+ self .config .buffer .explorer_input .taskset .index = 4
33+ reader = get_buffer_reader (self .config .buffer .explorer_input .taskset , self .config .buffer )
2634 tasks = []
2735 while True :
2836 try :
2937 tasks .extend (reader .read ())
3038 except StopIteration :
3139 break
3240 self .assertEqual (len (tasks ), 16 * 2 - 4 )
41+
42+ # test offset > dataset_len
43+ self .config .buffer .explorer_input .taskset .total_epochs = 3
44+ self .config .buffer .explorer_input .taskset .index = 20
45+ reader = get_buffer_reader (self .config .buffer .explorer_input .taskset , self .config .buffer )
46+ tasks = []
47+ while True :
48+ try :
49+ tasks .extend (reader .read ())
50+ except StopIteration :
51+ break
52+ self .assertEqual (len (tasks ), 16 * 3 - 20 )
53+
def test_file_writer(self):
    """Write two records through the buffer writer and check the output file.

    Obtains a writer for the experience buffer configured in ``setUp``,
    writes two prompt records, looks up the Ray actor named
    "json-test_buffer", and asserts that the file at the default storage
    path contains exactly two lines.
    """
    buffer_config = self.config.buffer
    storage_config = buffer_config.trainer_input.experience_buffer

    records = [
        {"prompt": "hello world"},
        {"prompt": "hi"},
    ]
    # Keep the writer bound to a local for the whole test so its lifetime
    # matches the original code.
    buffer_writer = get_buffer_writer(storage_config, buffer_config)
    buffer_writer.write(records)

    # NOTE(review): presumably this actor is created by the file writer
    # above -- confirm against the writer implementation.
    wrapper_actor = ray.get_actor("json-test_buffer")
    self.assertIsNotNone(wrapper_actor)

    output_path = default_storage_path(storage_config, buffer_config)
    with open(output_path, "r") as f:
        written_lines = f.readlines()
    self.assertEqual(len(written_lines), 2)
71+
def setUp(self):
    """Build a fresh config and clear any leftover buffer file before each test.

    Populates ``self.config`` from the template config, sets the checkpoint
    root, points the explorer taskset at the "countdown"/"train" unittest
    dataset, installs a FILE-type experience buffer named "test_buffer",
    creates the buffer cache directory, and finally deletes the buffer's
    storage file if one exists.
    """
    self.config = get_template_config()
    self.config.checkpoint_root_dir = get_checkpoint_path()
    self.config.buffer.explorer_input.taskset = get_unittest_dataset_config(
        "countdown", "train"
    )
    self.config.buffer.trainer_input.experience_buffer = StorageConfig(
        name="test_buffer", storage_type=StorageType.FILE
    )
    # NOTE(review): repeats the name already passed to StorageConfig above;
    # kept to avoid changing behavior -- confirm whether it can be dropped.
    self.config.buffer.trainer_input.experience_buffer.name = "test_buffer"
    self.config.buffer.cache_dir = os.path.join(
        self.config.checkpoint_root_dir, self.config.project, self.config.name, "buffer"
    )
    os.makedirs(self.config.buffer.cache_dir, exist_ok=True)

    # Resolve the storage path once, after cache_dir has been assigned
    # (same ordering as the original calls).
    storage_path = default_storage_path(
        self.config.buffer.trainer_input.experience_buffer, self.config.buffer
    )
    # EAFP: remove the stale file directly instead of exists()-then-remove();
    # a missing file simply means there is nothing to clean up.
    try:
        os.remove(storage_path)
    except FileNotFoundError:
        pass
0 commit comments