Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tests/buffer/sql_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import unittest

import ray
import torch

from trinity.buffer.reader.sql_reader import SQLReader
Expand All @@ -22,6 +23,7 @@ def test_create_sql_buffer(self) -> None:
algorithm_type=AlgorithmType.PPO,
path=f"sqlite:///{db_path}",
storage_type=StorageType.SQL,
wrap_in_ray=True,
)
config = BufferConfig(
max_retry_times=3,
Expand All @@ -45,3 +47,5 @@ def test_create_sql_buffer(self) -> None:
for _ in range(total_num // read_batch_size):
exps = sql_reader.read()
self.assertEqual(len(exps), read_batch_size)
db_wrapper = ray.get_actor("sql-test_buffer")
self.assertIsNotNone(db_wrapper)
105 changes: 105 additions & 0 deletions trinity/buffer/db_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import time
from typing import List, Optional

import ray
from sqlalchemy import asc, create_engine, desc
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool

from trinity.buffer.schema import Base, create_dynamic_table
from trinity.buffer.utils import retry_session
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import ReadStrategy
from trinity.utils.log import get_logger


class DBWrapper:
"""
A wrapper of a SQL database.

If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as a Ray Actor,
and provide a remote interface to the local database.

For databases that do not support multi-processing read/write (e.g. sqlite, duckdb), we
recommend setting `wrap_in_ray` to `True`
"""

def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
self.logger = get_logger(__name__)
self.engine = create_engine(storage_config.path, poolclass=NullPool)
self.table_model_cls = create_dynamic_table(
storage_config.algorithm_type, storage_config.name
)

try:
Base.metadata.create_all(self.engine, checkfirst=True)
except OperationalError:
self.logger.warning("Failed to create database, assuming it already exists.")

self.session = sessionmaker(bind=self.engine)
self.batch_size = config.read_batch_size
self.max_retry_times = config.max_retry_times
self.max_retry_interval = config.max_retry_interval

@classmethod
def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
if storage_config.wrap_in_ray:
return (
ray.remote(cls)
.options(
name=f"sql-{storage_config.name}",
get_if_exists=True,
)
.remote(storage_config, config)
)
else:
return cls(storage_config, config)

def write(self, data: list) -> None:
with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
session.add_all(experience_models)

def read(self, strategy: Optional[ReadStrategy] = None) -> List:
if strategy is None:
strategy = ReadStrategy.LFU

if strategy == ReadStrategy.LFU:
sortOrder = (asc(self.table_model_cls.consumed), desc(self.table_model_cls.id))

elif strategy == ReadStrategy.LRU:
sortOrder = (desc(self.table_model_cls.id),)

elif strategy == ReadStrategy.PRIORITY:
sortOrder = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id))

else:
raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")

exp_list = []
while len(exp_list) < self.batch_size:
if len(exp_list):
self.logger.info("waiting for experiences...")
time.sleep(1)
with retry_session(
self.session, self.max_retry_times, self.max_retry_interval
) as session:
# get a batch of experiences from the database
experiences = (
session.query(self.table_model_cls)
.filter(self.table_model_cls.reward.isnot(None))
.order_by(*sortOrder) # TODO: very slow
.limit(self.batch_size - len(exp_list))
.with_for_update()
.all()
)
# update the consumed field
for exp in experiences:
exp.consumed += 1
exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences])
self.logger.info(f"get {len(exp_list)} experiences:")
self.logger.info(f"reward = {[exp.reward for exp in exp_list]}")
self.logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
self.logger.info(f"first response_text = {exp_list[0].response_text}")
return exp_list
1 change: 1 addition & 0 deletions trinity/buffer/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
if storage_config.path is not None and len(storage_config.path) > 0:
sql_config = deepcopy(storage_config)
sql_config.storage_type = StorageType.SQL
sql_config.wrap_in_ray = False
self.sql_writer = SQLWriter(sql_config, self.config)
else:
self.sql_writer = None
Expand Down
8 changes: 4 additions & 4 deletions trinity/buffer/reader/queue_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
class QueueReader(BufferReader):
"""Reader of the Queue buffer."""

def __init__(self, meta: StorageConfig, config: BufferConfig):
assert meta.storage_type == StorageType.QUEUE
def __init__(self, storage_config: StorageConfig, config: BufferConfig):
assert storage_config.storage_type == StorageType.QUEUE
self.config = config
self.queue = QueueActor.options(
name=f"queue-{meta.name}",
name=f"queue-{storage_config.name}",
get_if_exists=True,
).remote(meta, config)
).remote(storage_config, config)

def read(self, strategy: Optional[ReadStrategy] = None) -> List:
if strategy is not None and strategy != ReadStrategy.FIFO:
Expand Down
68 changes: 7 additions & 61 deletions trinity/buffer/reader/sql_reader.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,25 @@
"""Reader of the SQL buffer."""

import time
from typing import List, Optional

from sqlalchemy import asc, create_engine, desc
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
import ray

from trinity.buffer.buffer_reader import BufferReader
from trinity.buffer.schema import Base, create_dynamic_table
from trinity.buffer.utils import retry_session
from trinity.buffer.db_wrapper import DBWrapper
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import ReadStrategy, StorageType
from trinity.utils.log import get_logger

logger = get_logger(__name__)


class SQLReader(BufferReader):
"""Reader of the SQL buffer."""

def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
assert meta.storage_type == StorageType.SQL
self.engine = create_engine(meta.path, poolclass=NullPool)

self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name)
try:
Base.metadata.create_all(self.engine, checkfirst=True)
except OperationalError:
logger.warning("Failed to create database, assuming it already exists.")
self.session = sessionmaker(bind=self.engine)
self.batch_size = config.read_batch_size
self.max_retry_times = config.max_retry_times
self.max_retry_interval = config.max_retry_interval
self.wrap_in_ray = meta.wrap_in_ray
self.db_wrapper = DBWrapper.get_wrapper(meta, config)

def read(self, strategy: Optional[ReadStrategy] = None) -> List:
if strategy is None:
strategy = ReadStrategy.LFU

if strategy == ReadStrategy.LFU:
sortOrder = (asc(self.table_model_cls.consumed), desc(self.table_model_cls.id))

elif strategy == ReadStrategy.LRU:
sortOrder = (desc(self.table_model_cls.id),)

elif strategy == ReadStrategy.PRIORITY:
sortOrder = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id))

if self.wrap_in_ray:
return ray.get(self.db_wrapper.read.remote(strategy))
else:
raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")

exp_list = []
while len(exp_list) < self.batch_size:
if len(exp_list):
logger.info("waiting for experiences...")
time.sleep(1)
with retry_session(
self.session, self.max_retry_times, self.max_retry_interval
) as session:
# get a batch of experiences from the database
experiences = (
session.query(self.table_model_cls)
.filter(self.table_model_cls.reward.isnot(None))
.order_by(*sortOrder) # TODO: very slow
.limit(self.batch_size - len(exp_list))
.with_for_update()
.all()
)
# update the consumed field
for exp in experiences:
exp.consumed += 1
exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences])
logger.info(f"get {len(exp_list)} experiences:")
logger.info(f"reward = {[exp.reward for exp in exp_list]}")
logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
logger.info(f"first response_text = {exp_list[0].response_text}")
return exp_list
return self.db_wrapper.read(strategy)
34 changes: 9 additions & 25 deletions trinity/buffer/writer/sql_writer.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
"""Writer of the SQL buffer."""

from sqlalchemy import create_engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
import ray

from trinity.buffer.buffer_writer import BufferWriter
from trinity.buffer.schema import Base, create_dynamic_table
from trinity.buffer.utils import retry_session
from trinity.buffer.db_wrapper import DBWrapper
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.constants import StorageType
from trinity.utils.log import get_logger

logger = get_logger(__name__)


class SQLWriter(BufferWriter):
Expand All @@ -22,24 +15,15 @@ def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
assert meta.storage_type == StorageType.SQL
# we only support write RFT algorithm buffer for now
# TODO: support other algorithms
assert meta.algorithm_type.is_rft, "Only RFT buffer is supported for writing."
self.engine = create_engine(meta.path, poolclass=NullPool)
self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name)

try:
Base.metadata.create_all(self.engine, checkfirst=True)
except OperationalError:
logger.warning("Failed to create database, assuming it already exists.")

self.session = sessionmaker(bind=self.engine)
self.batch_size = config.read_batch_size
self.max_retry_times = config.max_retry_times
self.max_retry_interval = config.max_retry_interval
assert meta.algorithm_type.is_rft(), "Only RFT buffer is supported for writing."
self.wrap_in_ray = meta.wrap_in_ray
self.db_wrapper = DBWrapper.get_wrapper(meta, config)

def write(self, data: list) -> None:
with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
session.add_all(experience_models)
if self.wrap_in_ray:
ray.get(self.db_wrapper.write.remote(data))
else:
self.db_wrapper.write(data)

def finish(self) -> None:
# TODO: implement this
Expand Down
3 changes: 3 additions & 0 deletions trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ class StorageConfig:
format: FormatConfig = field(default_factory=FormatConfig)
index: int = 0

# used for StorageType.SQL
wrap_in_ray: bool = True

# used for rollout tasks
default_workflow_type: Optional[str] = None
default_reward_fn_type: Optional[str] = None
Expand Down