diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 61ebc46315..222beebbb3 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -1,6 +1,7 @@ import os import unittest +import ray import torch from trinity.buffer.reader.sql_reader import SQLReader @@ -22,6 +23,7 @@ def test_create_sql_buffer(self) -> None: algorithm_type=AlgorithmType.PPO, path=f"sqlite:///{db_path}", storage_type=StorageType.SQL, + wrap_in_ray=True, ) config = BufferConfig( max_retry_times=3, @@ -45,3 +47,5 @@ def test_create_sql_buffer(self) -> None: for _ in range(total_num // read_batch_size): exps = sql_reader.read() self.assertEqual(len(exps), read_batch_size) + db_wrapper = ray.get_actor("sql-test_buffer") + self.assertIsNotNone(db_wrapper) diff --git a/trinity/buffer/db_wrapper.py b/trinity/buffer/db_wrapper.py new file mode 100644 index 0000000000..977aaae493 --- /dev/null +++ b/trinity/buffer/db_wrapper.py @@ -0,0 +1,105 @@ +import time +from typing import List, Optional + +import ray +from sqlalchemy import asc, create_engine, desc +from sqlalchemy.exc import OperationalError +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import NullPool + +from trinity.buffer.schema import Base, create_dynamic_table +from trinity.buffer.utils import retry_session +from trinity.common.config import BufferConfig, StorageConfig +from trinity.common.constants import ReadStrategy +from trinity.utils.log import get_logger + + +class DBWrapper: + """ + A wrapper of a SQL database. + + If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as a Ray Actor, + and provide a remote interface to the local database. + + For databases that do not support multi-processing read/write (e.g. sqlite, duckdb), we + recommend setting `wrap_in_ray` to `True` + """ + + def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: + self.logger = get_logger(__name__) + self.engine = create_engine(storage_config.path, poolclass=NullPool) + self.table_model_cls = create_dynamic_table( + storage_config.algorithm_type, storage_config.name + ) + + try: + Base.metadata.create_all(self.engine, checkfirst=True) + except OperationalError: + self.logger.warning("Failed to create database, assuming it already exists.") + + self.session = sessionmaker(bind=self.engine) + self.batch_size = config.read_batch_size + self.max_retry_times = config.max_retry_times + self.max_retry_interval = config.max_retry_interval + + @classmethod + def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig): + if storage_config.wrap_in_ray: + return ( + ray.remote(cls) + .options( + name=f"sql-{storage_config.name}", + get_if_exists=True, + ) + .remote(storage_config, config) + ) + else: + return cls(storage_config, config) + + def write(self, data: list) -> None: + with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session: + experience_models = [self.table_model_cls.from_experience(exp) for exp in data] + session.add_all(experience_models) + + def read(self, strategy: Optional[ReadStrategy] = None) -> List: + if strategy is None: + strategy = ReadStrategy.LFU + + if strategy == ReadStrategy.LFU: + sortOrder = (asc(self.table_model_cls.consumed), desc(self.table_model_cls.id)) + + elif strategy == ReadStrategy.LRU: + sortOrder = (desc(self.table_model_cls.id),) + + elif strategy == ReadStrategy.PRIORITY: + sortOrder = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id)) + + else: + raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage") + + exp_list = [] + while len(exp_list) < self.batch_size: + if len(exp_list): + self.logger.info("waiting for experiences...") + time.sleep(1) + with retry_session( + self.session, self.max_retry_times, self.max_retry_interval + ) as session: + # get a batch of experiences from the database + experiences = ( + session.query(self.table_model_cls) + .filter(self.table_model_cls.reward.isnot(None)) + .order_by(*sortOrder) # TODO: very slow + .limit(self.batch_size - len(exp_list)) + .with_for_update() + .all() + ) + # update the consumed field + for exp in experiences: + exp.consumed += 1 + exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences]) + self.logger.info(f"get {len(exp_list)} experiences:") + self.logger.info(f"reward = {[exp.reward for exp in exp_list]}") + self.logger.info(f"first prompt_text = {exp_list[0].prompt_text}") + self.logger.info(f"first response_text = {exp_list[0].response_text}") + return exp_list diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index a360182f07..8490c44506 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -23,6 +23,7 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: if storage_config.path is not None and len(storage_config.path) > 0: sql_config = deepcopy(storage_config) sql_config.storage_type = StorageType.SQL + sql_config.wrap_in_ray = False self.sql_writer = SQLWriter(sql_config, self.config) else: self.sql_writer = None diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index ffd013d4ef..3b26014fc4 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -16,13 +16,13 @@ class QueueReader(BufferReader): """Reader of the Queue buffer.""" - def __init__(self, meta: StorageConfig, config: BufferConfig): - assert meta.storage_type == StorageType.QUEUE + def __init__(self, storage_config: StorageConfig, config: BufferConfig): + assert storage_config.storage_type == StorageType.QUEUE self.config = config self.queue = QueueActor.options( - name=f"queue-{meta.name}", + name=f"queue-{storage_config.name}", get_if_exists=True, - ).remote(meta, config) + ).remote(storage_config, config) def read(self, strategy: Optional[ReadStrategy] = None) -> List: if strategy is not None and strategy != ReadStrategy.FIFO: diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py index 4da2920816..dcd9d942bb 100644 --- a/trinity/buffer/reader/sql_reader.py +++ b/trinity/buffer/reader/sql_reader.py @@ -1,21 +1,13 @@ """Reader of the SQL buffer.""" -import time from typing import List, Optional -from sqlalchemy import asc, create_engine, desc -from sqlalchemy.exc import OperationalError -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import NullPool +import ray from trinity.buffer.buffer_reader import BufferReader -from trinity.buffer.schema import Base, create_dynamic_table -from trinity.buffer.utils import retry_session +from trinity.buffer.db_wrapper import DBWrapper from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import ReadStrategy, StorageType -from trinity.utils.log import get_logger - -logger = get_logger(__name__) class SQLReader(BufferReader): @@ -23,57 +15,11 @@ class SQLReader(BufferReader): def __init__(self, meta: StorageConfig, config: BufferConfig) -> None: assert meta.storage_type == StorageType.SQL - self.engine = create_engine(meta.path, poolclass=NullPool) - - self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name) - try: - Base.metadata.create_all(self.engine, checkfirst=True) - except OperationalError: - logger.warning("Failed to create database, assuming it already exists.") - self.session = sessionmaker(bind=self.engine) - self.batch_size = config.read_batch_size - self.max_retry_times = config.max_retry_times - self.max_retry_interval = config.max_retry_interval + self.wrap_in_ray = meta.wrap_in_ray + self.db_wrapper = DBWrapper.get_wrapper(meta, config) def read(self, strategy: Optional[ReadStrategy] = None) -> List: - if strategy is None: - strategy = ReadStrategy.LFU - - if strategy == ReadStrategy.LFU: - sortOrder = (asc(self.table_model_cls.consumed), desc(self.table_model_cls.id)) - - elif strategy == ReadStrategy.LRU: - sortOrder = (desc(self.table_model_cls.id),) - - elif strategy == ReadStrategy.PRIORITY: - sortOrder = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id)) - + if self.wrap_in_ray: + return ray.get(self.db_wrapper.read.remote(strategy)) else: - raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage") - - exp_list = [] - while len(exp_list) < self.batch_size: - if len(exp_list): - logger.info("waiting for experiences...") - time.sleep(1) - with retry_session( - self.session, self.max_retry_times, self.max_retry_interval - ) as session: - # get a batch of experiences from the database - experiences = ( - session.query(self.table_model_cls) - .filter(self.table_model_cls.reward.isnot(None)) - .order_by(*sortOrder) # TODO: very slow - .limit(self.batch_size - len(exp_list)) - .with_for_update() - .all() - ) - # update the consumed field - for exp in experiences: - exp.consumed += 1 - exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences]) - logger.info(f"get {len(exp_list)} experiences:") - logger.info(f"reward = {[exp.reward for exp in exp_list]}") - logger.info(f"first prompt_text = {exp_list[0].prompt_text}") - logger.info(f"first response_text = {exp_list[0].response_text}") - return exp_list + return self.db_wrapper.read(strategy) diff --git a/trinity/buffer/writer/sql_writer.py b/trinity/buffer/writer/sql_writer.py index 7464064037..113ba5773d 100644 --- a/trinity/buffer/writer/sql_writer.py +++ b/trinity/buffer/writer/sql_writer.py @@ -1,18 +1,11 @@ """Writer of the SQL buffer.""" -from sqlalchemy import create_engine -from sqlalchemy.exc import OperationalError -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import NullPool +import ray from trinity.buffer.buffer_writer import BufferWriter -from trinity.buffer.schema import Base, create_dynamic_table -from trinity.buffer.utils import retry_session +from trinity.buffer.db_wrapper import DBWrapper from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType -from trinity.utils.log import get_logger - -logger = get_logger(__name__) class SQLWriter(BufferWriter): @@ -22,24 +15,15 @@ def __init__(self, meta: StorageConfig, config: BufferConfig) -> None: assert meta.storage_type == StorageType.SQL # we only support write RFT algorithm buffer for now # TODO: support other algorithms - assert meta.algorithm_type.is_rft, "Only RFT buffer is supported for writing." - self.engine = create_engine(meta.path, poolclass=NullPool) - self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name) - - try: - Base.metadata.create_all(self.engine, checkfirst=True) - except OperationalError: - logger.warning("Failed to create database, assuming it already exists.") - - self.session = sessionmaker(bind=self.engine) - self.batch_size = config.read_batch_size - self.max_retry_times = config.max_retry_times - self.max_retry_interval = config.max_retry_interval + assert meta.algorithm_type.is_rft(), "Only RFT buffer is supported for writing." + self.wrap_in_ray = meta.wrap_in_ray + self.db_wrapper = DBWrapper.get_wrapper(meta, config) def write(self, data: list) -> None: - with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session: - experience_models = [self.table_model_cls.from_experience(exp) for exp in data] - session.add_all(experience_models) + if self.wrap_in_ray: + ray.get(self.db_wrapper.write.remote(data)) + else: + self.db_wrapper.write(data) def finish(self) -> None: # TODO: implement this diff --git a/trinity/common/config.py b/trinity/common/config.py index 3feec8e2ea..02215cbd6b 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -77,6 +77,9 @@ class StorageConfig: format: FormatConfig = field(default_factory=FormatConfig) index: int = 0 + # used for StorageType.SQL + wrap_in_ray: bool = True + # used for rollout tasks default_workflow_type: Optional[str] = None default_reward_fn_type: Optional[str] = None