Skip to content

Commit 77244d8

Browse files
authored
Wrap database in ray actor (#70)
1 parent 178c992 commit 77244d8

File tree

7 files changed

+133
-90
lines changed

7 files changed

+133
-90
lines changed

tests/buffer/sql_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import unittest
33

4+
import ray
45
import torch
56

67
from trinity.buffer.reader.sql_reader import SQLReader
@@ -22,6 +23,7 @@ def test_create_sql_buffer(self) -> None:
2223
algorithm_type=AlgorithmType.PPO,
2324
path=f"sqlite:///{db_path}",
2425
storage_type=StorageType.SQL,
26+
wrap_in_ray=True,
2527
)
2628
config = BufferConfig(
2729
max_retry_times=3,
@@ -45,3 +47,5 @@ def test_create_sql_buffer(self) -> None:
4547
for _ in range(total_num // read_batch_size):
4648
exps = sql_reader.read()
4749
self.assertEqual(len(exps), read_batch_size)
50+
db_wrapper = ray.get_actor("sql-test_buffer")
51+
self.assertIsNotNone(db_wrapper)

trinity/buffer/db_wrapper.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import time
2+
from typing import List, Optional
3+
4+
import ray
5+
from sqlalchemy import asc, create_engine, desc
6+
from sqlalchemy.exc import OperationalError
7+
from sqlalchemy.orm import sessionmaker
8+
from sqlalchemy.pool import NullPool
9+
10+
from trinity.buffer.schema import Base, create_dynamic_table
11+
from trinity.buffer.utils import retry_session
12+
from trinity.common.config import BufferConfig, StorageConfig
13+
from trinity.common.constants import ReadStrategy
14+
from trinity.utils.log import get_logger
15+
16+
17+
class DBWrapper:
    """
    A wrapper of a SQL database.

    If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as a Ray Actor,
    and provide a remote interface to the local database.

    For databases that do not support multi-processing read/write (e.g. sqlite, duckdb), we
    recommend setting `wrap_in_ray` to `True`.
    """

    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
        """Open the database at `storage_config.path` and prepare the experience table.

        Args:
            storage_config: Supplies the database URL (`path`) and the
                `algorithm_type` / `name` pair used to build the table model.
            config: Supplies `read_batch_size`, `max_retry_times` and
                `max_retry_interval`.
        """
        self.logger = get_logger(__name__)
        # NullPool: no connections are pooled by the engine.
        self.engine = create_engine(storage_config.path, poolclass=NullPool)
        self.table_model_cls = create_dynamic_table(
            storage_config.algorithm_type, storage_config.name
        )

        try:
            Base.metadata.create_all(self.engine, checkfirst=True)
        except OperationalError:
            # Deliberate best-effort: a failure here is treated as "table already exists".
            self.logger.warning("Failed to create database, assuming it already exists.")

        # NOTE: despite the name, this is a session *factory*, not an open session.
        self.session = sessionmaker(bind=self.engine)
        self.batch_size = config.read_batch_size
        self.max_retry_times = config.max_retry_times
        self.max_retry_interval = config.max_retry_interval

    @classmethod
    def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
        """Build a wrapper for the given storage.

        Returns:
            A handle to a Ray actor named ``sql-<storage_config.name>`` when
            `storage_config.wrap_in_ray` is true (requested with
            `get_if_exists=True`), otherwise a plain in-process `DBWrapper`.
            Callers must use `.remote(...)` / `ray.get` on the former and direct
            method calls on the latter.
        """
        if storage_config.wrap_in_ray:
            return (
                ray.remote(cls)
                .options(
                    name=f"sql-{storage_config.name}",
                    get_if_exists=True,
                )
                .remote(storage_config, config)
            )
        else:
            return cls(storage_config, config)

    def write(self, data: list) -> None:
        """Convert each experience in `data` to a table row and add it to the database.

        Args:
            data: Experiences accepted by `table_model_cls.from_experience`.
        """
        # NOTE(review): no explicit commit here -- presumably `retry_session`
        # commits on successful exit; confirm against trinity.buffer.utils.
        with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
            experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
            session.add_all(experience_models)

    def read(self, strategy: Optional[ReadStrategy] = None) -> List:
        """Read experiences, blocking until `batch_size` of them were collected.

        Args:
            strategy: Ordering used to pick rows; `None` means `ReadStrategy.LFU`.
                LFU orders by `consumed` ascending then `id` descending, LRU by
                `id` descending, PRIORITY by `priority` descending then `id`
                descending.

        Returns:
            A list of experiences produced by `table_model_cls.to_experience`.

        Raises:
            NotImplementedError: If `strategy` is not LFU, LRU or PRIORITY.
        """
        # NOTE(review): this method polls until enough rows exist. When the
        # wrapper runs as a Ray actor with default options, a blocked `read`
        # may hold back queued `write` calls on the same actor -- confirm
        # against the Ray actor execution model.
        if strategy is None:
            strategy = ReadStrategy.LFU

        if strategy == ReadStrategy.LFU:
            sort_order = (asc(self.table_model_cls.consumed), desc(self.table_model_cls.id))
        elif strategy == ReadStrategy.LRU:
            sort_order = (desc(self.table_model_cls.id),)
        elif strategy == ReadStrategy.PRIORITY:
            sort_order = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id))
        else:
            raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")

        exp_list = []
        first_attempt = True
        while len(exp_list) < self.batch_size:
            # Back off before every poll except the first one. Gating the sleep
            # on `len(exp_list)` (as before) skipped it whenever a query returned
            # no rows, turning an empty table into a tight busy-loop.
            if not first_attempt:
                self.logger.info("waiting for experiences...")
                time.sleep(1)
            first_attempt = False
            with retry_session(
                self.session, self.max_retry_times, self.max_retry_interval
            ) as session:
                # get a batch of experiences from the database
                experiences = (
                    session.query(self.table_model_cls)
                    .filter(self.table_model_cls.reward.isnot(None))
                    .order_by(*sort_order)  # TODO: very slow
                    .limit(self.batch_size - len(exp_list))
                    .with_for_update()
                    .all()
                )
                # update the consumed field
                for exp in experiences:
                    exp.consumed += 1
                exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences])
        self.logger.info(f"get {len(exp_list)} experiences:")
        self.logger.info(f"reward = {[exp.reward for exp in exp_list]}")
        # Guard the sample logging: with a batch size of 0 the list is empty and
        # indexing element 0 would raise IndexError.
        if exp_list:
            self.logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
            self.logger.info(f"first response_text = {exp_list[0].response_text}")
        return exp_list

trinity/buffer/queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
2323
if storage_config.path is not None and len(storage_config.path) > 0:
2424
sql_config = deepcopy(storage_config)
2525
sql_config.storage_type = StorageType.SQL
26+
sql_config.wrap_in_ray = False
2627
self.sql_writer = SQLWriter(sql_config, self.config)
2728
else:
2829
self.sql_writer = None

trinity/buffer/reader/queue_reader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
class QueueReader(BufferReader):
1717
"""Reader of the Queue buffer."""
1818

19-
def __init__(self, meta: StorageConfig, config: BufferConfig):
20-
assert meta.storage_type == StorageType.QUEUE
19+
def __init__(self, storage_config: StorageConfig, config: BufferConfig):
20+
assert storage_config.storage_type == StorageType.QUEUE
2121
self.config = config
2222
self.queue = QueueActor.options(
23-
name=f"queue-{meta.name}",
23+
name=f"queue-{storage_config.name}",
2424
get_if_exists=True,
25-
).remote(meta, config)
25+
).remote(storage_config, config)
2626

2727
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
2828
if strategy is not None and strategy != ReadStrategy.FIFO:
Lines changed: 7 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,25 @@
11
"""Reader of the SQL buffer."""
22

3-
import time
43
from typing import List, Optional
54

6-
from sqlalchemy import asc, create_engine, desc
7-
from sqlalchemy.exc import OperationalError
8-
from sqlalchemy.orm import sessionmaker
9-
from sqlalchemy.pool import NullPool
5+
import ray
106

117
from trinity.buffer.buffer_reader import BufferReader
12-
from trinity.buffer.schema import Base, create_dynamic_table
13-
from trinity.buffer.utils import retry_session
8+
from trinity.buffer.db_wrapper import DBWrapper
149
from trinity.common.config import BufferConfig, StorageConfig
1510
from trinity.common.constants import ReadStrategy, StorageType
16-
from trinity.utils.log import get_logger
17-
18-
logger = get_logger(__name__)
1911

2012

2113
class SQLReader(BufferReader):
2214
"""Reader of the SQL buffer."""
2315

2416
def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
2517
assert meta.storage_type == StorageType.SQL
26-
self.engine = create_engine(meta.path, poolclass=NullPool)
27-
28-
self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name)
29-
try:
30-
Base.metadata.create_all(self.engine, checkfirst=True)
31-
except OperationalError:
32-
logger.warning("Failed to create database, assuming it already exists.")
33-
self.session = sessionmaker(bind=self.engine)
34-
self.batch_size = config.read_batch_size
35-
self.max_retry_times = config.max_retry_times
36-
self.max_retry_interval = config.max_retry_interval
18+
self.wrap_in_ray = meta.wrap_in_ray
19+
self.db_wrapper = DBWrapper.get_wrapper(meta, config)
3720

3821
def read(self, strategy: Optional[ReadStrategy] = None) -> List:
39-
if strategy is None:
40-
strategy = ReadStrategy.LFU
41-
42-
if strategy == ReadStrategy.LFU:
43-
sortOrder = (asc(self.table_model_cls.consumed), desc(self.table_model_cls.id))
44-
45-
elif strategy == ReadStrategy.LRU:
46-
sortOrder = (desc(self.table_model_cls.id),)
47-
48-
elif strategy == ReadStrategy.PRIORITY:
49-
sortOrder = (desc(self.table_model_cls.priority), desc(self.table_model_cls.id))
50-
22+
if self.wrap_in_ray:
23+
return ray.get(self.db_wrapper.read.remote(strategy))
5124
else:
52-
raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage")
53-
54-
exp_list = []
55-
while len(exp_list) < self.batch_size:
56-
if len(exp_list):
57-
logger.info("waiting for experiences...")
58-
time.sleep(1)
59-
with retry_session(
60-
self.session, self.max_retry_times, self.max_retry_interval
61-
) as session:
62-
# get a batch of experiences from the database
63-
experiences = (
64-
session.query(self.table_model_cls)
65-
.filter(self.table_model_cls.reward.isnot(None))
66-
.order_by(*sortOrder) # TODO: very slow
67-
.limit(self.batch_size - len(exp_list))
68-
.with_for_update()
69-
.all()
70-
)
71-
# update the consumed field
72-
for exp in experiences:
73-
exp.consumed += 1
74-
exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences])
75-
logger.info(f"get {len(exp_list)} experiences:")
76-
logger.info(f"reward = {[exp.reward for exp in exp_list]}")
77-
logger.info(f"first prompt_text = {exp_list[0].prompt_text}")
78-
logger.info(f"first response_text = {exp_list[0].response_text}")
79-
return exp_list
25+
return self.db_wrapper.read(strategy)

trinity/buffer/writer/sql_writer.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,11 @@
11
"""Writer of the SQL buffer."""
22

3-
from sqlalchemy import create_engine
4-
from sqlalchemy.exc import OperationalError
5-
from sqlalchemy.orm import sessionmaker
6-
from sqlalchemy.pool import NullPool
3+
import ray
74

85
from trinity.buffer.buffer_writer import BufferWriter
9-
from trinity.buffer.schema import Base, create_dynamic_table
10-
from trinity.buffer.utils import retry_session
6+
from trinity.buffer.db_wrapper import DBWrapper
117
from trinity.common.config import BufferConfig, StorageConfig
128
from trinity.common.constants import StorageType
13-
from trinity.utils.log import get_logger
14-
15-
logger = get_logger(__name__)
169

1710

1811
class SQLWriter(BufferWriter):
@@ -22,24 +15,15 @@ def __init__(self, meta: StorageConfig, config: BufferConfig) -> None:
2215
assert meta.storage_type == StorageType.SQL
2316
# we only support write RFT algorithm buffer for now
2417
# TODO: support other algorithms
25-
assert meta.algorithm_type.is_rft, "Only RFT buffer is supported for writing."
26-
self.engine = create_engine(meta.path, poolclass=NullPool)
27-
self.table_model_cls = create_dynamic_table(meta.algorithm_type, meta.name)
28-
29-
try:
30-
Base.metadata.create_all(self.engine, checkfirst=True)
31-
except OperationalError:
32-
logger.warning("Failed to create database, assuming it already exists.")
33-
34-
self.session = sessionmaker(bind=self.engine)
35-
self.batch_size = config.read_batch_size
36-
self.max_retry_times = config.max_retry_times
37-
self.max_retry_interval = config.max_retry_interval
18+
assert meta.algorithm_type.is_rft(), "Only RFT buffer is supported for writing."
19+
self.wrap_in_ray = meta.wrap_in_ray
20+
self.db_wrapper = DBWrapper.get_wrapper(meta, config)
3821

3922
def write(self, data: list) -> None:
40-
with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
41-
experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
42-
session.add_all(experience_models)
23+
if self.wrap_in_ray:
24+
ray.get(self.db_wrapper.write.remote(data))
25+
else:
26+
self.db_wrapper.write(data)
4327

4428
def finish(self) -> None:
4529
# TODO: implement this

trinity/common/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ class StorageConfig:
7777
format: FormatConfig = field(default_factory=FormatConfig)
7878
index: int = 0
7979

80+
# used for StorageType.SQL
81+
wrap_in_ray: bool = True
82+
8083
# used for rollout tasks
8184
default_workflow_type: Optional[str] = None
8285
default_reward_fn_type: Optional[str] = None

0 commit comments

Comments
 (0)