diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index 2fcf1242ce..363a4939ad 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -1,16 +1,23 @@ +import os import unittest -from tests.tools import get_template_config, get_unittest_dataset_config -from trinity.buffer.buffer import get_buffer_reader +import ray +from tests.tools import ( + get_checkpoint_path, + get_template_config, + get_unittest_dataset_config, +) +from trinity.buffer.buffer import get_buffer_reader, get_buffer_writer +from trinity.buffer.utils import default_storage_path +from trinity.common.config import StorageConfig +from trinity.common.constants import StorageType -class TestFileReader(unittest.TestCase): + +class TestFileBuffer(unittest.TestCase): def test_file_reader(self): """Test file reader.""" - config = get_template_config() - dataset_config = get_unittest_dataset_config("countdown", "train") - config.buffer.explorer_input.taskset = dataset_config - reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer) + reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer) tasks = [] while True: @@ -20,9 +27,10 @@ def test_file_reader(self): break self.assertEqual(len(tasks), 16) - config.buffer.explorer_input.taskset.total_epochs = 2 - config.buffer.explorer_input.taskset.index = 4 - reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer) + # test epoch and offset + self.config.buffer.explorer_input.taskset.total_epochs = 2 + self.config.buffer.explorer_input.taskset.index = 4 + reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer) tasks = [] while True: try: @@ -30,3 +38,57 @@ def test_file_reader(self): except StopIteration: break self.assertEqual(len(tasks), 16 * 2 - 4) + + # test offset > dataset_len + self.config.buffer.explorer_input.taskset.total_epochs = 3 + self.config.buffer.explorer_input.taskset.index = 20 + reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer) + tasks = [] + while True: + try: + tasks.extend(reader.read()) + except StopIteration: + break + self.assertEqual(len(tasks), 16 * 3 - 20) + + def test_file_writer(self): + writer = get_buffer_writer( + self.config.buffer.trainer_input.experience_buffer, self.config.buffer + ) + writer.write( + [ + {"prompt": "hello world"}, + {"prompt": "hi"}, + ] + ) + file_wrapper = ray.get_actor("json-test_buffer") + self.assertIsNotNone(file_wrapper) + file_path = default_storage_path( + self.config.buffer.trainer_input.experience_buffer, self.config.buffer + ) + with open(file_path, "r") as f: + self.assertEqual(len(f.readlines()), 2) + + def setUp(self): + self.config = get_template_config() + self.config.checkpoint_root_dir = get_checkpoint_path() + dataset_config = get_unittest_dataset_config("countdown", "train") + self.config.buffer.explorer_input.taskset = dataset_config + self.config.buffer.trainer_input.experience_buffer = StorageConfig( + name="test_buffer", storage_type=StorageType.FILE + ) + self.config.buffer.trainer_input.experience_buffer.name = "test_buffer" + self.config.buffer.cache_dir = os.path.join( + self.config.checkpoint_root_dir, self.config.project, self.config.name, "buffer" + ) + os.makedirs(self.config.buffer.cache_dir, exist_ok=True) + if os.path.exists( + default_storage_path( + self.config.buffer.trainer_input.experience_buffer, self.config.buffer + ) + ): + os.remove( + default_storage_path( + self.config.buffer.trainer_input.experience_buffer, self.config.buffer + ) + ) diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index 262a2bcd3e..4f3947c795 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -9,7 +9,7 @@ from trinity.common.constants import AlgorithmType, StorageType from trinity.common.experience import Experience -file_path = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl") +BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl") class TestQueueBuffer(RayUnittestBase): @@ -21,7 +21,7 @@ def test_queue_buffer(self): name="test_buffer", algorithm_type=AlgorithmType.PPO, storage_type=StorageType.QUEUE, - path=file_path, + path=BUFFER_FILE_PATH, ) config = BufferConfig( max_retry_times=3, @@ -61,9 +61,9 @@ def test_queue_buffer(self): self.assertEqual(len(exps), put_batch_size * 2) writer.finish() self.assertRaises(StopIteration, reader.read) - with open(file_path, "r") as f: + with open(BUFFER_FILE_PATH, "r") as f: self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2) def setUp(self): - if os.path.exists(file_path): - os.remove(file_path) + if os.path.exists(BUFFER_FILE_PATH): + os.remove(BUFFER_FILE_PATH) diff --git a/tests/explorer/runner_pool_test.py b/tests/explorer/runner_pool_test.py index 285961dc37..eaee55b40a 100644 --- a/tests/explorer/runner_pool_test.py +++ b/tests/explorer/runner_pool_test.py @@ -106,6 +106,7 @@ def setUp(self): name="test", storage_type=StorageType.QUEUE, algorithm_type=AlgorithmType.PPO, + path="", ) self.queue = QueueReader( self.config.buffer.trainer_input.experience_buffer, self.config.buffer diff --git a/trinity/buffer/buffer.py b/trinity/buffer/buffer.py index 32a5fb85a8..69f40a8d70 100644 --- a/trinity/buffer/buffer.py +++ b/trinity/buffer/buffer.py @@ -61,5 +61,9 @@ def get_buffer_writer(storage_config: StorageConfig, buffer_config: BufferConfig from trinity.buffer.writer.queue_writer import QueueWriter return QueueWriter(storage_config, buffer_config) + elif storage_config.storage_type == StorageType.FILE: + from trinity.buffer.writer.file_writer import JSONWriter + + return JSONWriter(storage_config, buffer_config) else: raise ValueError(f"{storage_config.storage_type} not supported.") diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index c6a54650aa..f0ddea46c9 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -9,6 +9,7 @@ from trinity.buffer.writer.sql_writer import SQLWriter from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType +from trinity.utils.log import get_logger def is_database_url(path: str) -> bool: @@ -26,25 +27,26 @@ class QueueActor: FINISH_MESSAGE = "$FINISH$" def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: + self.logger = get_logger(__name__) self.config = config self.capacity = getattr(config, "capacity", 10000) self.queue = asyncio.Queue(self.capacity) - if storage_config.path is not None and len(storage_config.path) > 0: - if is_database_url(storage_config.path): - storage_config.storage_type = StorageType.SQL - sql_config = deepcopy(storage_config) - sql_config.storage_type = StorageType.SQL - sql_config.wrap_in_ray = False - self.writer = SQLWriter(sql_config, self.config) - elif is_json_file(storage_config.path): - storage_config.storage_type = StorageType.FILE - json_config = deepcopy(storage_config) - json_config.storage_type = StorageType.FILE - self.writer = JSONWriter(json_config, self.config) + st_config = deepcopy(storage_config) + st_config.wrap_in_ray = False + if st_config.path is not None: + if is_database_url(st_config.path): + st_config.storage_type = StorageType.SQL + self.writer = SQLWriter(st_config, self.config) + elif is_json_file(st_config.path): + st_config.storage_type = StorageType.FILE + self.writer = JSONWriter(st_config, self.config) else: + self.logger.warning("Unknown supported storage path: %s", st_config.path) self.writer = None else: - self.writer = None + st_config.storage_type = StorageType.FILE + self.writer = JSONWriter(st_config, self.config) + self.logger.warning(f"Save experiences in {st_config.path}.") def length(self) -> int: """The length of the queue.""" diff --git a/trinity/buffer/db_wrapper.py b/trinity/buffer/ray_wrapper.py similarity index 65% rename from trinity/buffer/db_wrapper.py rename to trinity/buffer/ray_wrapper.py index bbf96c176e..9de80f32ab 100644 --- a/trinity/buffer/db_wrapper.py +++ b/trinity/buffer/ray_wrapper.py @@ -1,3 +1,5 @@ +import json +import os import time from typing import List, Optional @@ -8,9 +10,11 @@ from sqlalchemy.pool import NullPool from trinity.buffer.schema import Base, create_dynamic_table -from trinity.buffer.utils import retry_session +from trinity.buffer.utils import default_storage_path, retry_session from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import ReadStrategy +from trinity.common.experience import Experience +from trinity.common.workflows import Task from trinity.utils.log import get_logger @@ -27,6 +31,8 @@ class DBWrapper: def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: self.logger = get_logger(__name__) + if storage_config.path is None: + storage_config.path = default_storage_path(storage_config, config) self.engine = create_engine(storage_config.path, poolclass=NullPool) self.table_model_cls = create_dynamic_table( storage_config.algorithm_type, storage_config.name @@ -106,3 +112,60 @@ def read( self.logger.info(f"first prompt_text = {exp_list[0].prompt_text}") self.logger.info(f"first response_text = {exp_list[0].response_text}") return exp_list + + +class _Encoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, Experience): + return o.to_dict() + if isinstance(o, Task): + return o.to_dict() + return super().default(o) + + +class FileWrapper: + """ + A wrapper of a local jsonl file. + + If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as + a Ray Actor, and provide a remote interface to the local file. + + This wrapper is only for writing, if you want to read from the file, use + StorageType.QUEUE instead. + """ + + def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: + if storage_config.path is None: + storage_config.path = default_storage_path(storage_config, config) + ext = os.path.splitext(storage_config.path)[-1] + if ext != ".jsonl" and ext != ".json": + raise ValueError( + f"File path must end with '.json' or '.jsonl', got {storage_config.path}" + ) + self.file = open(storage_config.path, "a", encoding="utf-8") + self.encoder = _Encoder(ensure_ascii=False) + + @classmethod + def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): + if storage_config.wrap_in_ray: + return ( + ray.remote(cls) + .options( + name=f"json-{storage_config.name}", + get_if_exists=True, + ) + .remote(storage_config, config) + ) + else: + return cls(storage_config, config) + + def write(self, data: List) -> None: + for item in data: + json_str = self.encoder.encode(item) + self.file.write(json_str + "\n") + self.file.flush() + + def read(self) -> List: + raise NotImplementedError( + "read() is not implemented for FileWrapper, please use QUEUE instead" + ) diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py index eb21c92b95..cc725f842c 100644 --- a/trinity/buffer/reader/sql_reader.py +++ b/trinity/buffer/reader/sql_reader.py @@ -5,7 +5,7 @@ import ray from trinity.buffer.buffer_reader import BufferReader -from trinity.buffer.db_wrapper import DBWrapper +from trinity.buffer.ray_wrapper import DBWrapper from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import ReadStrategy, StorageType diff --git a/trinity/buffer/utils.py b/trinity/buffer/utils.py index ae5927792a..aa2c0e4849 100644 --- a/trinity/buffer/utils.py +++ b/trinity/buffer/utils.py @@ -1,6 +1,9 @@ +import os import time from contextlib import contextmanager +from trinity.common.config import BufferConfig, StorageConfig +from trinity.common.constants import StorageType from trinity.utils.log import get_logger logger = get_logger(__name__) @@ -31,3 +34,18 @@ def retry_session(session_maker, max_retry_times: int, max_retry_interval: float raise e finally: session.close() + + +def default_storage_path(storage_config: StorageConfig, buffer_config: BufferConfig) -> str: + if buffer_config.cache_dir is None: + raise ValueError("Please call config.check_and_update() before using.") + if storage_config.storage_type == StorageType.SQL: + return "sqlite:///" + os.path.join( + buffer_config.cache_dir, + f"{storage_config.name}.db", + ) + else: + return os.path.join( + buffer_config.cache_dir, + f"{storage_config.name}.jsonl", + ) diff --git a/trinity/buffer/writer/file_writer.py b/trinity/buffer/writer/file_writer.py index b163b86b67..0fc4929ca5 100644 --- a/trinity/buffer/writer/file_writer.py +++ b/trinity/buffer/writer/file_writer.py @@ -1,39 +1,27 @@ -import json -import os from typing import List +import ray + from trinity.buffer.buffer_writer import BufferWriter +from trinity.buffer.ray_wrapper import FileWrapper from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType -from trinity.common.experience import Experience -from trinity.common.workflows import Task - - -class _Encoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, Experience): - return o.to_dict() - if isinstance(o, Task): - return o.to_dict() - return super().default(o) class JSONWriter(BufferWriter): def __init__(self, meta: StorageConfig, config: BufferConfig): assert meta.storage_type == StorageType.FILE - if meta.path is None: - raise ValueError("File path cannot be None for RawFileWriter") - ext = os.path.splitext(meta.path)[-1] - if ext != ".jsonl" and ext != ".json": - raise ValueError(f"File path must end with .json or .jsonl, got {meta.path}") - self.file = open(meta.path, "a", encoding="utf-8") - self.encoder = _Encoder(ensure_ascii=False) + self.writer = FileWrapper.get_wrapper(meta, config) + self.wrap_in_ray = meta.wrap_in_ray def write(self, data: List) -> None: - for item in data: - json_str = self.encoder.encode(item) - self.file.write(json_str + "\n") - self.file.flush() + if self.wrap_in_ray: + ray.get(self.writer.write.remote(data)) + else: + self.writer.write(data) def finish(self): - self.file.close() + if self.wrap_in_ray: + ray.get(self.writer.finish.remote()) + else: + self.writer.finish() diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index 113ba5773d..e251792f7c 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -3,7 +3,7 @@ import ray from trinity.buffer.buffer_writer import BufferWriter -from trinity.buffer.db_wrapper import DBWrapper +from trinity.buffer.ray_wrapper import DBWrapper from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType diff --git a/trinity/common/config.py b/trinity/common/config.py index 3a5d0ff57c..42c54c442a 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -80,6 +80,9 @@ class StorageConfig: # used for StorageType.SQL wrap_in_ray: bool = True + # used for StorageType.QUEUE + capacity: int = 10000 + # used for rollout tasks default_workflow_type: Optional[str] = None default_reward_fn_type: Optional[str] = None @@ -234,6 +237,7 @@ class BufferConfig: read_batch_size: int = 1 # automatically set tokenizer_path: Optional[str] = None # automatically set pad_token_id: Optional[int] = None # automatically set + cache_dir: Optional[str] = None # automatically set @dataclass @@ -366,6 +370,7 @@ def _check_interval(self) -> None: self.trainer.save_interval = self.synchronizer.sync_interval def _check_buffer(self) -> None: # noqa: C901 + # TODO: split this function into different buffer read/writer # check explorer_input if self.mode != "train" and not self.buffer.explorer_input.taskset.path: raise ValueError( @@ -426,6 +431,11 @@ def _check_buffer(self) -> None: # noqa: C901 logger.info( f"Auto set `buffer.trainer_input.experience_buffer` to {self.buffer.trainer_input.experience_buffer}" ) + elif self.buffer.trainer_input.experience_buffer.storage_type is StorageType.FILE: + logger.warning( + "`FILE` storage is not supported to use as experience_buffer in `both` mode, use `QUEUE` instead." + ) + self.buffer.trainer_input.experience_buffer.storage_type = StorageType.QUEUE elif self.mode == "train": # TODO: to be check if self.algorithm.algorithm_type.is_dpo(): if ( @@ -470,6 +480,15 @@ def _check_buffer(self) -> None: # noqa: C901 logger.warning(f"Failed to get pad token id from model {self.model.model_path}") self.buffer.pad_token_id = 0 self.buffer.tokenizer_path = self.model.model_path + # create buffer.cache_dir at ///buffer + self.buffer.cache_dir = os.path.abspath(os.path.join(self.checkpoint_job_dir, "buffer")) + try: + os.makedirs(self.buffer.cache_dir, exist_ok=True) + except Exception: + logger.warning( + f"Failed to create buffer dir {self.buffer.cache_dir}, please check " + f"your checkpoint directory: {self.checkpoint_job_dir}" + ) def check_and_update(self) -> None: # noqa: C901 """Check and update the config.""" @@ -532,14 +551,14 @@ def check_and_update(self) -> None: # noqa: C901 self._check_interval() - # create a job dir in /monitor + # create a job dir in ///monitor self.monitor.cache_dir = os.path.join(self.checkpoint_job_dir, "monitor") try: os.makedirs(self.monitor.cache_dir, exist_ok=True) except Exception: logger.warning( f"Failed to create monitor dir {self.monitor.cache_dir}, please check " - f"your checkpoint directory: {self.checkpoint_root_dir}" + f"your checkpoint directory: {self.checkpoint_job_dir}" ) # check buffer