82 changes: 72 additions & 10 deletions tests/buffer/file_test.py
@@ -1,16 +1,23 @@
import os
import unittest

from tests.tools import get_template_config, get_unittest_dataset_config
from trinity.buffer.buffer import get_buffer_reader
import ray

from tests.tools import (
get_checkpoint_path,
get_template_config,
get_unittest_dataset_config,
)
from trinity.buffer.buffer import get_buffer_reader, get_buffer_writer
from trinity.buffer.utils import default_storage_path
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType

class TestFileReader(unittest.TestCase):

class TestFileBuffer(unittest.TestCase):
def test_file_reader(self):
"""Test file reader."""
config = get_template_config()
dataset_config = get_unittest_dataset_config("countdown", "train")
config.buffer.explorer_input.taskset = dataset_config
reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)

tasks = []
while True:
@@ -20,13 +27,68 @@ def test_file_reader(self):
break
self.assertEqual(len(tasks), 16)

config.buffer.explorer_input.taskset.total_epochs = 2
config.buffer.explorer_input.taskset.index = 4
reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
# test epoch and offset
self.config.buffer.explorer_input.taskset.total_epochs = 2
self.config.buffer.explorer_input.taskset.index = 4
reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
tasks = []
while True:
try:
tasks.extend(reader.read())
except StopIteration:
break
self.assertEqual(len(tasks), 16 * 2 - 4)

# test offset > dataset_len
self.config.buffer.explorer_input.taskset.total_epochs = 3
self.config.buffer.explorer_input.taskset.index = 20
reader = get_buffer_reader(self.config.buffer.explorer_input.taskset, self.config.buffer)
tasks = []
while True:
try:
tasks.extend(reader.read())
except StopIteration:
break
self.assertEqual(len(tasks), 16 * 3 - 20)

def test_file_writer(self):
writer = get_buffer_writer(
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
)
writer.write(
[
{"prompt": "hello world"},
{"prompt": "hi"},
]
)
file_wrapper = ray.get_actor("json-test_buffer")
self.assertIsNotNone(file_wrapper)
file_path = default_storage_path(
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
)
with open(file_path, "r") as f:
self.assertEqual(len(f.readlines()), 2)

def setUp(self):
self.config = get_template_config()
self.config.checkpoint_root_dir = get_checkpoint_path()
dataset_config = get_unittest_dataset_config("countdown", "train")
self.config.buffer.explorer_input.taskset = dataset_config
self.config.buffer.trainer_input.experience_buffer = StorageConfig(
name="test_buffer", storage_type=StorageType.FILE
)
self.config.buffer.trainer_input.experience_buffer.name = "test_buffer"
self.config.buffer.cache_dir = os.path.join(
self.config.checkpoint_root_dir, self.config.project, self.config.name, "buffer"
)
os.makedirs(self.config.buffer.cache_dir, exist_ok=True)
if os.path.exists(
default_storage_path(
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
)
):
os.remove(
default_storage_path(
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
)
)
10 changes: 5 additions & 5 deletions tests/buffer/queue_test.py
@@ -9,7 +9,7 @@
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience

file_path = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")
BUFFER_FILE_PATH = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")


class TestQueueBuffer(RayUnittestBase):
@@ -21,7 +21,7 @@ def test_queue_buffer(self):
name="test_buffer",
algorithm_type=AlgorithmType.PPO,
storage_type=StorageType.QUEUE,
path=file_path,
path=BUFFER_FILE_PATH,
)
config = BufferConfig(
max_retry_times=3,
@@ -61,9 +61,9 @@ def test_queue_buffer(self):
self.assertEqual(len(exps), put_batch_size * 2)
writer.finish()
self.assertRaises(StopIteration, reader.read)
with open(file_path, "r") as f:
with open(BUFFER_FILE_PATH, "r") as f:
self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)

def setUp(self):
if os.path.exists(file_path):
os.remove(file_path)
if os.path.exists(BUFFER_FILE_PATH):
os.remove(BUFFER_FILE_PATH)
1 change: 1 addition & 0 deletions tests/explorer/runner_pool_test.py
@@ -106,6 +106,7 @@ def setUp(self):
name="test",
storage_type=StorageType.QUEUE,
algorithm_type=AlgorithmType.PPO,
path="",
)
self.queue = QueueReader(
self.config.buffer.trainer_input.experience_buffer, self.config.buffer
4 changes: 4 additions & 0 deletions trinity/buffer/buffer.py
@@ -61,5 +61,9 @@ def get_buffer_writer(storage_config: StorageConfig, buffer_config: BufferConfig
from trinity.buffer.writer.queue_writer import QueueWriter

return QueueWriter(storage_config, buffer_config)
elif storage_config.storage_type == StorageType.FILE:
from trinity.buffer.writer.file_writer import JSONWriter

return JSONWriter(storage_config, buffer_config)
else:
raise ValueError(f"{storage_config.storage_type} not supported.")
28 changes: 15 additions & 13 deletions trinity/buffer/queue.py
@@ -9,6 +9,7 @@
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
from trinity.utils.log import get_logger


def is_database_url(path: str) -> bool:
@@ -26,25 +27,26 @@ class QueueActor:
FINISH_MESSAGE = "$FINISH$"

def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
self.logger = get_logger(__name__)
self.config = config
self.capacity = getattr(config, "capacity", 10000)
self.queue = asyncio.Queue(self.capacity)
if storage_config.path is not None and len(storage_config.path) > 0:
if is_database_url(storage_config.path):
storage_config.storage_type = StorageType.SQL
sql_config = deepcopy(storage_config)
sql_config.storage_type = StorageType.SQL
sql_config.wrap_in_ray = False
self.writer = SQLWriter(sql_config, self.config)
elif is_json_file(storage_config.path):
storage_config.storage_type = StorageType.FILE
json_config = deepcopy(storage_config)
json_config.storage_type = StorageType.FILE
self.writer = JSONWriter(json_config, self.config)
st_config = deepcopy(storage_config)
st_config.wrap_in_ray = False
if st_config.path is not None:
if is_database_url(st_config.path):
st_config.storage_type = StorageType.SQL
self.writer = SQLWriter(st_config, self.config)
elif is_json_file(st_config.path):
st_config.storage_type = StorageType.FILE
self.writer = JSONWriter(st_config, self.config)
else:
self.logger.warning("Unknown supported storage path: %s", st_config.path)
self.writer = None
else:
self.writer = None
st_config.storage_type = StorageType.FILE
self.writer = JSONWriter(st_config, self.config)
self.logger.warning(f"Save experiences in {st_config.path}.")

def length(self) -> int:
"""The length of the queue."""
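
With this refactor a QUEUE buffer always tries to persist what passes through it, and the backing writer is chosen from the path. A sketch of the dispatch with hypothetical configs (the predicates is_database_url/is_json_file are referenced by name only):

from trinity.common.config import StorageConfig
from trinity.common.constants import AlgorithmType, StorageType

# Hypothetical QUEUE configs; only the path differs.
sql_backed = StorageConfig(
    name="exp_sql",
    storage_type=StorageType.QUEUE,
    algorithm_type=AlgorithmType.PPO,
    path="sqlite:///exp.db",  # database URL -> SQLWriter backs the queue
)
file_backed = StorageConfig(
    name="exp_file",
    storage_type=StorageType.QUEUE,
    algorithm_type=AlgorithmType.PPO,
    path="exp.jsonl",  # JSON/JSONL file -> JSONWriter backs the queue
)
no_path = StorageConfig(
    name="exp_default",
    storage_type=StorageType.QUEUE,
    algorithm_type=AlgorithmType.PPO,
    path=None,  # falls back to a JSONWriter at default_storage_path(...)
)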
65 changes: 64 additions & 1 deletion trinity/buffer/db_wrapper.py → trinity/buffer/ray_wrapper.py
@@ -1,3 +1,5 @@
import json
import os
import time
from typing import List, Optional

@@ -8,9 +8,11 @@
from sqlalchemy.pool import NullPool

from trinity.buffer.schema import Base, create_dynamic_table
from trinity.buffer.utils import retry_session
from trinity.buffer.utils import default_storage_path, retry_session
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import ReadStrategy
from trinity.common.experience import Experience
from trinity.common.workflows import Task
from trinity.utils.log import get_logger


@@ -27,6 +31,8 @@ class DBWrapper:

def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
self.logger = get_logger(__name__)
if storage_config.path is None:
storage_config.path = default_storage_path(storage_config, config)
self.engine = create_engine(storage_config.path, poolclass=NullPool)
self.table_model_cls = create_dynamic_table(
storage_config.algorithm_type, storage_config.name
@@ -106,3 +112,60 @@ def read(
self.logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
self.logger.info(f"first response_text = {exp_list[0].response_text}")
return exp_list


class _Encoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, Experience):
return o.to_dict()
if isinstance(o, Task):
return o.to_dict()
return super().default(o)


class FileWrapper:
"""
A wrapper of a local jsonl file.

If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as
a Ray Actor, and provide a remote interface to the local file.

This wrapper is only for writing; to read from the file, use
StorageType.QUEUE instead.
"""

def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
if storage_config.path is None:
storage_config.path = default_storage_path(storage_config, config)
ext = os.path.splitext(storage_config.path)[-1]
if ext != ".jsonl" and ext != ".json":
raise ValueError(
f"File path must end with '.json' or '.jsonl', got {storage_config.path}"
)
self.file = open(storage_config.path, "a", encoding="utf-8")
self.encoder = _Encoder(ensure_ascii=False)

@classmethod
def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
if storage_config.wrap_in_ray:
return (
ray.remote(cls)
.options(
name=f"json-{storage_config.name}",
get_if_exists=True,
)
.remote(storage_config, config)
)
else:
return cls(storage_config, config)

def write(self, data: List) -> None:
for item in data:
json_str = self.encoder.encode(item)
self.file.write(json_str + "\n")
self.file.flush()

def read(self) -> List:
raise NotImplementedError(
"read() is not implemented for FileWrapper, please use QUEUE instead"
)
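
A short sketch of the two modes get_wrapper supports (the file path and Ray setup are assumptions; get_template_config comes from the repo's test helpers):

import ray

from tests.tools import get_template_config
from trinity.buffer.ray_wrapper import FileWrapper
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType

buffer_cfg = get_template_config().buffer
storage = StorageConfig(name="demo", storage_type=StorageType.FILE, path="/tmp/demo.jsonl")

# Local mode: a plain in-process wrapper writing straight to the file.
storage.wrap_in_ray = False
local_writer = FileWrapper.get_wrapper(storage, buffer_cfg)
local_writer.write([{"prompt": "hi"}])

# Ray mode: a named actor ("json-demo"); later callers with the same name reuse it.
ray.init(ignore_reinit_error=True)
storage.wrap_in_ray = True
remote_writer = FileWrapper.get_wrapper(storage, buffer_cfg)
ray.get(remote_writer.write.remote([{"prompt": "hello"}]))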
2 changes: 1 addition & 1 deletion trinity/buffer/reader/sql_reader.py
@@ -5,7 +5,7 @@
import ray

from trinity.buffer.buffer_reader import BufferReader
from trinity.buffer.db_wrapper import DBWrapper
from trinity.buffer.ray_wrapper import DBWrapper
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import ReadStrategy, StorageType

18 changes: 18 additions & 0 deletions trinity/buffer/utils.py
@@ -1,6 +1,9 @@
import os
import time
from contextlib import contextmanager

from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
from trinity.utils.log import get_logger

logger = get_logger(__name__)
@@ -31,3 +34,18 @@ def retry_session(session_maker, max_retry_times: int, max_retry_interval: float
raise e
finally:
session.close()


def default_storage_path(storage_config: StorageConfig, buffer_config: BufferConfig) -> str:
if buffer_config.cache_dir is None:
raise ValueError("Please call config.check_and_update() before using.")
if storage_config.storage_type == StorageType.SQL:
return "sqlite:///" + os.path.join(
buffer_config.cache_dir,
f"{storage_config.name}.db",
)
else:
return os.path.join(
buffer_config.cache_dir,
f"{storage_config.name}.jsonl",
)
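
For reference, a sketch of the paths this helper produces (the cache directory value is illustrative):

# Assuming buffer_config.cache_dir == "/ckpt/my_project/my_run/buffer":
#   StorageType.SQL, name="exp"        -> "sqlite:////ckpt/my_project/my_run/buffer/exp.db"
#   any other storage type, name="exp" -> "/ckpt/my_project/my_run/buffer/exp.jsonl"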
38 changes: 13 additions & 25 deletions trinity/buffer/writer/file_writer.py
@@ -1,39 +1,27 @@
import json
import os
from typing import List

import ray

from trinity.buffer.buffer_writer import BufferWriter
from trinity.buffer.ray_wrapper import FileWrapper
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
from trinity.common.experience import Experience
from trinity.common.workflows import Task


class _Encoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, Experience):
return o.to_dict()
if isinstance(o, Task):
return o.to_dict()
return super().default(o)


class JSONWriter(BufferWriter):
def __init__(self, meta: StorageConfig, config: BufferConfig):
assert meta.storage_type == StorageType.FILE
if meta.path is None:
raise ValueError("File path cannot be None for RawFileWriter")
ext = os.path.splitext(meta.path)[-1]
if ext != ".jsonl" and ext != ".json":
raise ValueError(f"File path must end with .json or .jsonl, got {meta.path}")
self.file = open(meta.path, "a", encoding="utf-8")
self.encoder = _Encoder(ensure_ascii=False)
self.writer = FileWrapper.get_wrapper(meta, config)
self.wrap_in_ray = meta.wrap_in_ray

def write(self, data: List) -> None:
for item in data:
json_str = self.encoder.encode(item)
self.file.write(json_str + "\n")
self.file.flush()
if self.wrap_in_ray:
ray.get(self.writer.write.remote(data))
else:
self.writer.write(data)

def finish(self):
self.file.close()
if self.wrap_in_ray:
ray.get(self.writer.finish.remote())
else:
self.writer.finish()
2 changes: 1 addition & 1 deletion trinity/buffer/writer/sql_writer.py
@@ -3,7 +3,7 @@
import ray

from trinity.buffer.buffer_writer import BufferWriter
from trinity.buffer.db_wrapper import DBWrapper
from trinity.buffer.ray_wrapper import DBWrapper
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
