23 changes: 22 additions & 1 deletion tests/buffer/queue_test.py
@@ -1,3 +1,5 @@
import os

import torch

from tests.tools import RayUnittestBase
@@ -7,6 +9,8 @@
from trinity.common.constants import AlgorithmType, StorageType
from trinity.common.experience import Experience

file_path = os.path.join(os.path.dirname(__file__), "test_queue_buffer.jsonl")


class TestQueueBuffer(RayUnittestBase):
def test_queue_buffer(self):
@@ -17,6 +21,7 @@ def test_queue_buffer(self):
name="test_buffer",
algorithm_type=AlgorithmType.PPO,
storage_type=StorageType.QUEUE,
path=file_path,
)
config = BufferConfig(
max_retry_times=3,
@@ -36,9 +41,25 @@ def test_queue_buffer(self):
]
for _ in range(total_num // put_batch_size):
writer.write(exps)
writer.finish()
for _ in range(total_num // read_batch_size):
exps = reader.read()
self.assertEqual(len(exps), read_batch_size)
print(f"finish read {read_batch_size} experience")
writer.write(
[
Experience(
tokens=torch.tensor([float(j) for j in range(i + 1)]),
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
)
for i in range(1, put_batch_size * 2 + 1)
]
)
exps = reader.read(batch_size=put_batch_size * 2)
self.assertEqual(len(exps), put_batch_size * 2)
writer.finish()
self.assertRaises(StopIteration, reader.read)
with open(file_path, "r") as f:
self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
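Note: the added assertions exercise the new per-call `batch_size` override on `read()` together with the queue's JSONL backup file. A minimal sketch of the pattern under test, reusing the `Experience` fields from above (the helper name is ours, not part of the PR):

```python
import torch

from trinity.common.experience import Experience


def make_experiences(n: int) -> list:
    # Build n toy experiences with growing token tensors, mirroring the test.
    return [
        Experience(
            tokens=torch.tensor([float(j) for j in range(i + 1)]),
            prompt_length=i,
            reward=float(i),
            logprobs=torch.tensor([0.1]),
            action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
        )
        for i in range(1, n + 1)
    ]


# With a queue-backed reader/writer pair configured as in the test:
#   writer.write(make_experiences(2 * put_batch_size))
#   exps = reader.read(batch_size=2 * put_batch_size)  # per-call override
#   writer.finish()  # subsequent reads raise StopIteration
```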
16 changes: 16 additions & 0 deletions tests/buffer/sql_test.py
@@ -47,5 +47,21 @@ def test_create_sql_buffer(self) -> None:
for _ in range(total_num // read_batch_size):
exps = sql_reader.read()
self.assertEqual(len(exps), read_batch_size)

# dynamic read/write
sql_writer.write(
[
Experience(
tokens=torch.tensor([float(j) for j in range(i + 1)]),
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
)
for i in range(1, put_batch_size * 2 + 1)
]
)
exps = sql_reader.read(batch_size=put_batch_size * 2)
self.assertEqual(len(exps), put_batch_size * 2)
db_wrapper = ray.get_actor("sql-test_buffer")
self.assertIsNotNone(db_wrapper)
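The final assertion relies on Ray's named-actor registry: the SQL storage is wrapped in an actor that can be re-acquired by name from anywhere in the cluster. A sketch (the `sql-test_buffer` name matches the assertion above; a general `f"sql-{name}"` convention is an inference from it, not confirmed by this diff):

```python
import ray

# Look up the database wrapper actor registered by the SQL storage.
db_wrapper = ray.get_actor("sql-test_buffer")
```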
4 changes: 3 additions & 1 deletion trinity/buffer/buffer_reader.py
@@ -9,5 +9,7 @@ class BufferReader(ABC):
"""Interface of the buffer reader."""

@abstractmethod
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
def read(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
"""Read from buffer."""
9 changes: 6 additions & 3 deletions trinity/buffer/db_wrapper.py
@@ -61,7 +61,9 @@ def write(self, data: list) -> None:
experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
session.add_all(experience_models)

def read(self, strategy: Optional[ReadStrategy] = None) -> List:
def read(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
if strategy is None:
strategy = ReadStrategy.LFU

@@ -78,7 +80,8 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")

exp_list = []
while len(exp_list) < self.batch_size:
batch_size = batch_size or self.batch_size
while len(exp_list) < batch_size:
if len(exp_list):
self.logger.info("waiting for experiences...")
time.sleep(1)
@@ -90,7 +93,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
session.query(self.table_model_cls)
.filter(self.table_model_cls.reward.isnot(None))
.order_by(*sortOrder) # TODO: very slow
.limit(self.batch_size - len(exp_list))
.limit(batch_size - len(exp_list))
.with_for_update()
.all()
)
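With the override in place, the read path polls until a full batch of rewarded rows is available. The pattern in isolation (a simplified sketch assuming a SQLAlchemy session factory and a table model like the wrapper's; the `to_experience` converter is hypothetical, and the real code also orders by a read strategy rather than deleting rows):

```python
import time


def read_full_batch(session_factory, model_cls, batch_size: int) -> list:
    # Block until batch_size rewarded rows have been claimed.
    exp_list: list = []
    while len(exp_list) < batch_size:
        if exp_list:
            time.sleep(1)  # partial batch: wait for writers to catch up
        with session_factory() as session:
            rows = (
                session.query(model_cls)
                .filter(model_cls.reward.isnot(None))
                .limit(batch_size - len(exp_list))
                .with_for_update()  # row-lock the selection for this transaction
                .all()
            )
            # Convert before commit so the ORM objects are not expired.
            exp_list.extend(row.to_experience() for row in rows)  # hypothetical
            for row in rows:
                session.delete(row)  # consume claimed rows (simplified)
            session.commit()
    return exp_list
```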
32 changes: 25 additions & 7 deletions trinity/buffer/queue.py
@@ -5,11 +5,20 @@

import ray

from trinity.buffer.writer.file_writer import JSONWriter
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType


def is_database_url(path: str) -> bool:
return any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"])


def is_json_file(path: str) -> bool:
return path.endswith(".json") or path.endswith(".jsonl")


@ray.remote
class QueueActor:
"""An asyncio.Queue based queue actor."""
@@ -21,12 +30,21 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
self.capacity = getattr(config, "capacity", 10000)
self.queue = asyncio.Queue(self.capacity)
if storage_config.path is not None and len(storage_config.path) > 0:
sql_config = deepcopy(storage_config)
sql_config.storage_type = StorageType.SQL
sql_config.wrap_in_ray = False
self.sql_writer = SQLWriter(sql_config, self.config)
if is_database_url(storage_config.path):
storage_config.storage_type = StorageType.SQL
sql_config = deepcopy(storage_config)
sql_config.storage_type = StorageType.SQL
sql_config.wrap_in_ray = False
self.writer = SQLWriter(sql_config, self.config)
elif is_json_file(storage_config.path):
storage_config.storage_type = StorageType.FILE
json_config = deepcopy(storage_config)
json_config.storage_type = StorageType.FILE
self.writer = JSONWriter(json_config, self.config)
else:
self.writer = None
else:
self.sql_writer = None
self.writer = None

def length(self) -> int:
"""The length of the queue."""
@@ -35,8 +53,8 @@ def length(self) -> int:
async def put_batch(self, exp_list: List) -> None:
"""Put batch of experience."""
await self.queue.put(exp_list)
if self.sql_writer is not None:
self.sql_writer.write(exp_list)
if self.writer is not None:
self.writer.write(exp_list)

async def finish(self) -> None:
"""Stop the queue."""
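For reference, how the two predicates route the backup writer in `QueueActor.__init__` (illustrative values):

```python
# Database URLs get an SQLWriter; .json/.jsonl paths get a JSONWriter;
# anything else leaves the queue with no backup writer.
assert is_database_url("sqlite:///buffer.db")
assert is_database_url("postgresql://host/db")
assert is_json_file("experiences.jsonl")
assert not is_database_url("experiences.jsonl")
assert not is_json_file("sqlite:///buffer.db")
```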
151 changes: 85 additions & 66 deletions trinity/buffer/reader/file_reader.py
@@ -1,10 +1,11 @@
"""Filed based buffer reader."""

from itertools import islice
from typing import List, Optional

import datasets
import transformers
from datasets import load_dataset
from datasets import Dataset, load_dataset

from trinity.buffer.buffer_reader import BufferReader
from trinity.common.config import BufferConfig, StorageConfig
@@ -17,6 +18,22 @@
FILE_READERS = Registry("file_readers")


class _HFBatchReader:
    """Iterate a HF Dataset in fixed-size batches, with a resumable offset."""

    def __init__(self, dataset: Dataset):
        self.dataset = dataset
        self.iter = iter(dataset)

    def set_offset(self, offset: int) -> None:
        # Restart iteration from the given sample offset (used for resume).
        self.iter = islice(iter(self.dataset), offset, None)

    def read_batch(self, batch_size: int) -> List:
        batch = list(islice(self.iter, batch_size))
        if len(batch) < batch_size:
            # Not enough samples left for a full batch.
            raise StopIteration
        return batch
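A usage sketch for the helper (the data file name is hypothetical; `load_dataset` and `islice` are imported at the top of this module):

```python
ds = _HFBatchReader(load_dataset("json", data_files="tasks.jsonl", split="train"))
first = ds.read_batch(4)     # samples 0..3
ds.set_offset(100)           # resume from sample 100
resumed = ds.read_batch(4)   # samples 100..103, or StopIteration if exhausted
```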


@FILE_READERS.register_module(AlgorithmType.SFT.value)
class SFTDataReader(BufferReader):
"""Reader for SFT file data."""
@@ -29,22 +46,20 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.prompt_key = meta.format.prompt_key
self.response_key = meta.format.response_key
self.read_batch_size = config.read_batch_size
self.dataset = load_dataset(
meta.path, name=subset_name, split=self.split
self.dataset = _HFBatchReader(
load_dataset(meta.path, name=subset_name, split=self.split)
) # TODO: support resume
self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)

def read(self, strategy: Optional[ReadStrategy] = None) -> List:
try:
batch_data = next(self.data_iter)
except StopIteration:
self.dataset = self.dataset.shuffle()
self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
batch_data = next(self.data_iter)
def read(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
samples = self.dataset.read_batch(batch_size or self.read_batch_size)
exp_list = []
if self.prompt_type == PromptType.MESSAGES:
for messages in batch_data[self.messages_key]:
for sample in samples:
messages = sample[self.messages_key]
tokens = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=False, return_tensors="pt"
)[0]
@@ -58,9 +73,9 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
exp_list.append(experience)

elif self.prompt_type == PromptType.CHATPAIR:
for prompt_messages, response_messages in zip(
batch_data[self.prompt_key], batch_data[self.response_key]
):
for sample in samples:
prompt_messages = sample[self.prompt_key]
response_messages = sample[self.response_key]
if not isinstance(prompt_messages, list):
prompt_messages = [prompt_messages]
if not isinstance(response_messages, list):
@@ -83,7 +98,9 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:

elif self.prompt_type == PromptType.PLAINTEXT:
# TODO: support HF format without chat template
for prompt, response in zip(batch_data[self.prompt_key], batch_data[self.response_key]):
for sample in samples:
prompt = sample[self.prompt_key]
response = sample[self.response_key]
tokens = self.tokenizer(prompt + response, return_tensors="pt")["input_ids"][0]
prompt_tokens = self.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
experience = Experience(
@@ -106,8 +123,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.chosen_key = meta.format.chosen_key
self.rejected_key = meta.format.rejected_key
self.read_batch_size = config.read_batch_size
self.dataset = load_dataset(
meta.path, name=subset_name, split=self.split
self.dataset = _HFBatchReader(
load_dataset(meta.path, name=subset_name, split=self.split)
) # TODO: support resume
self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
@@ -120,17 +137,16 @@ def _get_assistant_message(self, item) -> dict:
else:
return item

def read(self, strategy: Optional[ReadStrategy] = None) -> List:
try:
batch_data = next(self.data_iter)
except StopIteration:
self.dataset = self.dataset.shuffle()
self.data_iter = self.dataset.iter(self.read_batch_size, drop_last_batch=True)
batch_data = next(self.data_iter)
def read(
self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
batch_data = self.dataset.read_batch(batch_size or self.read_batch_size)
exp_list = []
for prompt, chosen, rejected in zip(
batch_data[self.prompt_key], batch_data[self.chosen_key], batch_data[self.rejected_key]
):
for sample in batch_data:
prompt = sample[self.prompt_key]
chosen = sample[self.chosen_key]
rejected = sample[self.rejected_key]

if self.prompt_type == PromptType.MESSAGES:
prompt_messages = prompt

@@ -177,18 +193,13 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
self.split = meta.split
subset_name = meta.subset_name
# disable datasets caching to avoid reuse old-version dataset
self.epoch = 0
datasets.disable_caching()
self.dataset = load_dataset(
meta.path, name=subset_name, split=self.split
) # TODO: may from db_url
# if task_type != TaskType.EVAL and config.db_url != "":
# logger.info(f"Loading dataset from database with url: {config.db_url}")
# db_type = config.db_url.split(":")[0]
# db_name = config.db_url.split("/")[-1]
# dataset = Dataset.from_sql(RftDatasetModel.__tablename__, f"{db_type}:///{db_name}")
datasets.enable_caching()
self.index = meta.index # TODO: apply shuffle

self.dataset = _HFBatchReader(load_dataset(meta.path, name=subset_name, split=self.split))
if self.meta.index > 0:
# offset the dataset to the correct index
self.dataset.read_batch(self.meta.index)
self.read_batch_size = config.batch_size
self.prompt_key = meta.format.prompt_key
self.response_key = meta.format.response_key
self.workflow_key = meta.format.workflow_key
@@ -202,31 +213,39 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
def __len__(self):
return len(self.dataset)

def read(self, strategy: Optional[ReadStrategy] = None):
if self.index >= len(self.dataset) * self.total_epochs:
raise StopIteration
sample = self.dataset[self.index % len(self.dataset)]
workflow_class = (
WORKFLOWS.get(sample[self.workflow_key])
if self.workflow_key in sample
else self.default_workflow_cls
)
reward_fn = (
REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
if self.reward_fn_key in sample
else self.default_reward_fn_cls
)
assert workflow_class is not None, "`default_workflow_type` or `workflow_key` is required"
task = Task(
workflow=workflow_class,
format_args=self.meta.format,
rollout_args=self.meta.rollout_args,
workflow_args=self.meta.workflow_args,
is_eval=self.meta.task_type == TaskType.EVAL,
reward_fn=reward_fn,
raw_task=sample,
)
self.index += 1
if self.task_type == TaskType.EVAL and self.index == len(self.dataset):
self.index = 0
return task
def read(
    self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
) -> List:
    batch_size = batch_size or self.read_batch_size
    tasks = []
    try:
        samples = self.dataset.read_batch(batch_size)
    except StopIteration:
        self.epoch += 1
        if self.epoch >= self.total_epochs:
            raise StopIteration
        # Start the next epoch from the beginning and retry the read.
        self.dataset.set_offset(0)
        samples = self.dataset.read_batch(batch_size)
    for sample in samples:
        workflow_class = (
            WORKFLOWS.get(sample[self.workflow_key])
            if self.workflow_key in sample
            else self.default_workflow_cls
        )
        reward_fn = (
            REWARD_FUNCTIONS.get(sample[self.reward_fn_key])
            if self.reward_fn_key in sample
            else self.default_reward_fn_cls
        )
        assert (
            workflow_class is not None
        ), "`default_workflow_type` or `workflow_key` is required"
        task = Task(
            workflow=workflow_class,
            format_args=self.meta.format,
            rollout_args=self.meta.rollout_args,
            workflow_args=self.meta.workflow_args,
            is_eval=self.meta.task_type == TaskType.EVAL,
            reward_fn=reward_fn,
            raw_task=sample,
        )
        tasks.append(task)
    return tasks
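With the reader batched and epoch-aware, a consumer loop reduces to the following sketch (reader construction and task scheduling omitted):

```python
# Drain the task reader until total_epochs is exhausted.
while True:
    try:
        tasks = reader.read()  # one batch of Task objects
    except StopIteration:
        break  # all epochs consumed
    for task in tasks:
        ...  # hand task.workflow / task.raw_task to the rollout engine
```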