32 changes: 32 additions & 0 deletions tests/buffer/file_test.py
@@ -0,0 +1,32 @@
import unittest

from tests.tools import get_template_config, get_unittest_dataset_config
from trinity.buffer.buffer import get_buffer_reader


class TestFileReader(unittest.TestCase):
def test_file_reader(self):
"""Test file reader."""
config = get_template_config()
dataset_config = get_unittest_dataset_config("countdown", "train")
config.buffer.explorer_input.taskset = dataset_config
reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)

tasks = []
while True:
try:
tasks.extend(reader.read())
except StopIteration:
break
self.assertEqual(len(tasks), 16)

config.buffer.explorer_input.taskset.total_epochs = 2
config.buffer.explorer_input.taskset.index = 4
reader = get_buffer_reader(config.buffer.explorer_input.taskset, config.buffer)
tasks = []
while True:
try:
tasks.extend(reader.read())
except StopIteration:
break
self.assertEqual(len(tasks), 16 * 2 - 4)
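Note: the second pass exercises resuming mid-stream. With the countdown train split providing 16 tasks per epoch, total_epochs = 2 yields 32 tasks in total, and a starting index of 4 appears to skip the 4 tasks already consumed, leaving 16 * 2 - 4 = 28 for the reader to drain.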
27 changes: 26 additions & 1 deletion tests/buffer/queue_test.py
@@ -1,3 +1,5 @@
import os

import torch

from tests.tools import RayUnittestBase
@@ -7,6 +9,8 @@
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience

file_path = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")


class TestQueueBuffer(RayUnittestBase):
def test_queue_buffer(self):
@@ -17,6 +21,7 @@ def test_queue_buffer(self):
name="test_buffer",
algorithm_type=AlgorithmType.PPO,
storage_type=StorageType.QUEUE,
path=file_path,
)
config = BufferConfig(
max_retry_times=3,
@@ -36,9 +41,29 @@ def test_queue_buffer(self):
]
for _ in range(total_num // put_batch_size):
writer.write(exps)
writer.finish()
for _ in range(total_num // read_batch_size):
exps = reader.read()
self.assertEqual(len(exps), read_batch_size)
print(f"finish read {read_batch_size} experience")
writer.write(
[
Experience(
tokens=torch.tensor([float(j) for j in range(i + 1)]),
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
)
for i in range(1, put_batch_size * 2 + 1)
]
)
exps = reader.read(batch_size=put_batch_size * 2)
self.assertEqual(len(exps), put_batch_size * 2)
writer.finish()
self.assertRaises(StopIteration, reader.read)
with open(file_path, "r") as f:
self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)

def setUp(self):
if os.path.exists(file_path):
os.remove(file_path)
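This test now covers two additions: reader.read(batch_size=...) overrides the configured read batch size for a single call, and a queue whose path ends in .jsonl mirrors every written experience to that file (hence the final line count of total_num + put_batch_size * 2). A minimal usage sketch, assuming a writer/reader pair built as above:

    writer.write(exps)                 # queued, and appended to file_path as JSON lines
    batch = reader.read(batch_size=8)  # per-call override of read_batch_size
    writer.finish()                    # after this, reader.read() raises StopIteration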
16 changes: 16 additions & 0 deletions tests/buffer/sql_test.py
@@ -47,5 +47,21 @@ def test_create_sql_buffer(self) -> None:
for _ in range(total_num // read_batch_size):
exps = sql_reader.read()
self.assertEqual(len(exps), read_batch_size)

# dynamic read/write
sql_writer.write(
[
Experience(
tokens=torch.tensor([float(j) for j in range(i + 1)]),
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
)
for i in range(1, put_batch_size * 2 + 1)
]
)
exps = sql_reader.read(batch_size=put_batch_size * 2)
self.assertEqual(len(exps), put_batch_size * 2)
db_wrapper = ray.get_actor("sql-test_buffer")
self.assertIsNotNone(db_wrapper)
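As in the queue test, sql_reader.read(batch_size=put_batch_size * 2) overrides the configured batch size for one call; the trailing ray.get_actor("sql-test_buffer") lookup confirms the SQL storage is registered as a named Ray actor and can therefore be retrieved by name.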
4 changes: 3 additions & 1 deletion trinity/buffer/buffer_reader.py
@@ -9,5 +9,7 @@ class BufferReader(ABC):
"""Interface of the buffer reader."""

@abstractmethod
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
def read(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
"""Read from buffer."""
9 changes: 6 additions & 3 deletions trinity/buffer/db_wrapper.py
@@ -61,7 +61,9 @@ def write(self, data: list) -> None:
experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
session.add_all(experience_models)

def read(self, strategy: Optional[ReadStrategy] = None) -> List:
def read(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
if strategy is None:
strategy = ReadStrategy.LFU

@@ -78,7 +80,8 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")

exp_list = []
while len(exp_list) < self.batch_size:
batch_size = batch_size or self.batch_size
while len(exp_list) < batch_size:
if len(exp_list):
self.logger.info("waiting for experiences...")
time.sleep(1)
@@ -90,7 +93,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
session.query(self.table_model_cls)
.filter(self.table_model_cls.reward.isnot(None))
.order_by(*sortOrder) # TODO: very slow
.limit(self.batch_size - len(exp_list))
.limit(batch_size - len(exp_list))
.with_for_update()
.all()
)
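Note the fallback batch_size = batch_size or self.batch_size: any falsy argument, including an explicit batch_size=0, falls back to the configured batch size, and the per-call value is then used consistently in the polling loop's limit(batch_size - len(exp_list)) query.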
32 changes: 25 additions & 7 deletions trinity/buffer/queue.py
@@ -5,11 +5,20 @@

import ray

from trinity.buffer.writer.file_writer import JSONWriter
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType


def is_database_url(path: str) -> bool:
return any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"])


def is_json_file(path: str) -> bool:
return path.endswith(".json") or path.endswith(".jsonl")
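A few illustrative checks for the new helpers (example paths are hypothetical):

    is_database_url("sqlite:///buffer.db")   # True
    is_database_url("mysql://user@host/db")  # True
    is_database_url("/tmp/buffer.jsonl")     # False
    is_json_file("buffer.jsonl")             # True
    is_json_file("buffer.json")              # True
    is_json_file("buffer.db")                # False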


@ray.remote
class QueueActor:
"""An asyncio.Queue based queue actor."""
@@ -21,12 +30,21 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
self.capacity = getattr(config, "capacity", 10000)
self.queue = asyncio.Queue(self.capacity)
if storage_config.path is not None and len(storage_config.path) > 0:
sql_config = deepcopy(storage_config)
sql_config.storage_type = StorageType.SQL
sql_config.wrap_in_ray = False
self.sql_writer = SQLWriter(sql_config, self.config)
if is_database_url(storage_config.path):
storage_config.storage_type = StorageType.SQL
sql_config = deepcopy(storage_config)
sql_config.storage_type = StorageType.SQL
sql_config.wrap_in_ray = False
self.writer = SQLWriter(sql_config, self.config)
elif is_json_file(storage_config.path):
storage_config.storage_type = StorageType.FILE
json_config = deepcopy(storage_config)
json_config.storage_type = StorageType.FILE
self.writer = JSONWriter(json_config, self.config)
else:
self.writer = None
else:
self.sql_writer = None
self.writer = None

def length(self) -> int:
"""The length of the queue."""
@@ -35,8 +53,8 @@ def length(self) -> int:
async def put_batch(self, exp_list: List) -> None:
"""Put batch of experience."""
await self.queue.put(exp_list)
if self.sql_writer is not None:
self.sql_writer.write(exp_list)
if self.writer is not None:
self.writer.write(exp_list)

async def finish(self) -> None:
"""Stop the queue."""