|
1 | | -import os |
2 | | - |
3 | | -from tests.tools import RayUnittestBase |
4 | | -from trinity.buffer.reader.file_reader import RawDataReader |
5 | | -from trinity.buffer.writer.file_writer import RawFileWriter |
6 | | -from trinity.common.config import StorageConfig |
7 | | -from trinity.common.constants import StorageType |
8 | | - |
9 | | - |
class TestFileBuffer(RayUnittestBase):
    """Round-trip test: write raw records with RawFileWriter, read them back
    with RawDataReader."""

    # Scratch directory (relative to the current working directory) that
    # holds the file written by the test; removed again in tearDownClass.
    temp_output_path = 'tmp/test_file_buffer/'

    @classmethod
    def setUpClass(cls):
        """Run the base-class fixture, then create the scratch directory."""
        # unittest does not chain class fixtures automatically; call the base
        # explicitly so setup mirrors the super().tearDownClass() call below.
        super().setUpClass()
        os.makedirs(cls.temp_output_path, exist_ok=True)

    @classmethod
    def tearDownClass(cls):
        """Run the base-class teardown, then delete the scratch directory."""
        # Function-scope import keeps this block self-contained.
        import shutil

        try:
            super().tearDownClass()
        finally:
            # Best-effort removal without spawning a shell; a missing
            # directory is not an error, and cleanup still runs if the base
            # teardown raises.
            shutil.rmtree(cls.temp_output_path, ignore_errors=True)

    def test_file_buffer(self):
        """Write four records to a file, then read them back unchanged."""
        meta = StorageConfig(
            name="test_buffer",
            path=os.path.join(self.temp_output_path, "buffer.jsonl"),
            storage_type=StorageType.FILE,
            raw=True,
        )
        data = [
            {'key1': 1, 'key2': 2},
            {'key1': 3, 'key2': 4},
            {'key1': 5, 'key2': 6},
            {'key1': 7, 'key2': 8},
        ]

        # test writer
        writer = RawFileWriter(meta, None)
        writer.write(data)
        writer.finish()

        # test reader
        # The reader is pointed at the scratch directory, not at the
        # individual file that the writer was given.
        meta.path = self.temp_output_path
        reader = RawDataReader(meta, None)
        loaded_data = reader.read()
        self.assertEqual(len(loaded_data), 4)
        self.assertEqual(loaded_data, data)
        # A further read after all records were returned must raise.
        self.assertRaises(StopIteration, reader.read)
| 1 | +import unittest |
| 2 | + |
| 3 | +from tests.tools import get_template_config, get_unittest_dataset_config |
| 4 | +from trinity.buffer.buffer import get_buffer_reader |
| 5 | + |
| 6 | + |
class TestFileReader(unittest.TestCase):
    """Exercise the buffer reader built for the unittest "countdown" taskset."""

    @staticmethod
    def _read_until_exhausted(reader):
        """Call ``reader.read()`` repeatedly and collect every returned item.

        Stops, and returns the accumulated list, as soon as a read raises
        StopIteration.
        """
        collected = []
        try:
            while True:
                collected.extend(reader.read())
        except StopIteration:
            return collected

    def test_file_reader(self):
        """Test file reader."""
        config = get_template_config()
        config.buffer.explorer_input.taskset = get_unittest_dataset_config(
            "countdown", "train"
        )

        def make_reader():
            # Re-read the taskset from the config on every call so settings
            # changed between the two phases below are picked up.
            return get_buffer_reader(
                config.buffer.explorer_input.taskset, config.buffer
            )

        # Phase 1: with the taskset as configured, the test expects 16 tasks.
        first_pass = self._read_until_exhausted(make_reader())
        self.assertEqual(len(first_pass), 16)

        # Phase 2: two epochs starting from index 4; the test expects the
        # two-epoch total minus the four skipped entries.
        config.buffer.explorer_input.taskset.total_epochs = 2
        config.buffer.explorer_input.taskset.index = 4
        second_pass = self._read_until_exhausted(make_reader())
        self.assertEqual(len(second_pass), 16 * 2 - 4)
0 commit comments