Skip to content

Commit 49ff522

Browse files
authored
Keep SQL Experience Buffer behavior consistent with previous versions (#248)
1 parent d6d3be4 commit 49ff522

File tree

12 files changed

+178
-55
lines changed

12 files changed

+178
-55
lines changed

tests/buffer/experience_storage_test.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
import time
66

77
import torch
8+
from parameterized import parameterized
89

910
from tests.tools import RayUnittestBaseAysnc
1011
from trinity.buffer.reader.sql_reader import SQLReader
1112
from trinity.buffer.writer.sql_writer import SQLWriter
1213
from trinity.common.config import BufferConfig, StorageConfig
1314
from trinity.common.constants import StorageType
14-
from trinity.common.experience import Experience
15+
from trinity.common.experience import EID, Experience
1516

1617
DB_PATH = os.path.join(os.path.dirname(__file__), "test.db")
1718

@@ -28,10 +29,11 @@ def setUp(self):
2829
if os.path.exists(DB_PATH):
2930
os.remove(DB_PATH)
3031

31-
async def test_sql_storage(self):
32+
@parameterized.expand([("sft",), ("dpo",)])
33+
async def test_sql_storage(self, schema_type):
3234
meta = StorageConfig(
3335
name="test_storage",
34-
schema_type="experience",
36+
schema_type=schema_type,
3537
storage_type=StorageType.SQL,
3638
max_read_timeout=3,
3739
path=f"sqlite:///{DB_PATH}",
@@ -49,8 +51,6 @@ async def test_sql_storage(self):
4951
)
5052
for i in range(1, self.put_batch_size + 1)
5153
]
52-
for exp in exps:
53-
exp.info = {"model_version": 0, "use_count": 0}
5454
for _ in range(self.total_num // self.put_batch_size):
5555
await writer.write_async(exps)
5656
for _ in range(self.total_num // self.train_batch_size):
@@ -88,3 +88,49 @@ def thread_read(reader, result_queue):
8888
value = cursor.execute("SELECT COUNT(*) FROM test_storage;").fetchall()
8989
self.assertEqual(value[0][0], self.total_num + self.put_batch_size * 2)
9090
self.assertRaises(StopIteration, reader.read, batch_size=1)
91+
92+
async def test_sql_experience_buffer(self):
    """Check newest-first ordering and reuse for the "experience" schema.

    Experiences are written with task ids counting up from 1. Each full pass
    over the buffer must return them in descending task-id order, and a
    second pass must yield the same sequence again. The final row count
    confirms that reading did not add or remove rows.
    """
    meta = StorageConfig(
        name="test_storage",
        schema_type="experience",
        storage_type=StorageType.SQL,
        max_read_timeout=3,
        path=f"sqlite:///{DB_PATH}",
    )
    writer = SQLWriter(meta, self.config)
    reader = SQLReader(meta, self.config)
    self.assertEqual(await writer.acquire(), 1)
    for idx in range(self.total_num // self.put_batch_size):
        exps = [
            Experience(
                eid=EID(task=idx * self.put_batch_size + i),
                tokens=torch.tensor([float(j) for j in range(i + 1)]),
                prompt_length=i,
                reward=float(i),
                logprobs=torch.tensor([0.1]),
            )
            for i in range(1, self.put_batch_size + 1)
        ]
        await writer.write_async(exps)

    # The experience buffer supports experience reuse, so the second pass
    # must observe exactly the same ordering as the first one.
    for _pass in range(2):
        cnt = self.total_num
        for _batch in range(self.total_num // self.train_batch_size):
            exps = reader.read()
            self.assertEqual(len(exps), self.train_batch_size)
            for exp in exps:
                self.assertEqual(exp.eid.task, cnt)
                cnt -= 1
    self.assertEqual(await writer.release(), 0)

    # Close the connection explicitly so the database file is not left open
    # when the next test's setUp tries to remove it.
    conn = sqlite3.connect(DB_PATH)
    try:
        value = conn.cursor().execute("SELECT COUNT(*) FROM test_storage;").fetchall()
    finally:
        conn.close()
    self.assertEqual(value[0][0], self.total_num)

tests/trainer/trainer_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
SyncStyle,
3131
)
3232
from trinity.common.models.utils import get_checkpoint_dir_with_step_num
33-
from trinity.manager.manager import CacheManager
33+
from trinity.manager.state_manager import StateManager
3434

3535

3636
class BaseTrainerCase(RayUnittestBase):
@@ -266,7 +266,7 @@ def test_trainer(self):
266266
self.config.buffer.trainer_input.experience_buffer = StorageConfig(
267267
name="test_sql_storage",
268268
max_read_timeout=20,
269-
storage_type=StorageType.SQL,
269+
storage_type=StorageType.QUEUE,
270270
max_retry_times=10,
271271
)
272272
self.config.check_and_update()
@@ -516,10 +516,10 @@ def test_fully_async_mode(self):
516516
rollout_metrics = parser.metric_list("rollout")
517517
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
518518
# check the checkpoint
519-
explorer1_cache = CacheManager(explorer1_config)
519+
explorer1_cache = StateManager(explorer1_config)
520520
cache = explorer1_cache.load_explorer()
521521
self.assertEqual(cache["latest_iteration"], 4)
522-
explorer2_cache = CacheManager(explorer2_config)
522+
explorer2_cache = StateManager(explorer2_config)
523523
cache = explorer2_cache.load_explorer()
524524
self.assertEqual(cache["latest_iteration"], 4)
525525
# check the latest checkpoint

trinity/buffer/pipelines/experience_pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ async def process(self, exps: List[Experience]) -> Dict:
135135
return result_metrics
136136

137137
async def close(self) -> None:
138-
await self.output.release()
138+
try:
139+
await self.output.release()
140+
except Exception as e:
141+
self.logger.error(f"Failed to release output buffer: {e}")
139142
for operator in self.operators:
140143
operator.close()

trinity/buffer/schema/sql_schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class ExperienceModel(Base): # type: ignore
4545
reward = Column(Float, nullable=True)
4646
# serialized experience object
4747
experience_bytes = Column(LargeBinary, nullable=True)
48+
consumed = Column(Integer, default=0, index=True)
4849

4950
def to_experience(self) -> Experience:
5051
"""Load the experience from the database."""

trinity/buffer/storage/sql.py

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import ray
88
from datasets import Dataset
9-
from sqlalchemy import asc
9+
from sqlalchemy import asc, desc
1010
from sqlalchemy.orm import sessionmaker
1111

1212
from trinity.buffer.schema import init_engine
@@ -88,29 +88,33 @@ def release(self) -> int:
8888

8989

9090
class SQLExperienceStorage(SQLStorage):
91+
"""Used as trainer input."""
92+
9193
def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
9294
super().__init__(storage_config, config)
9395
self.batch_size = config.train_batch_size
9496
self.max_timeout = storage_config.max_read_timeout
97+
# TODO: optimize the following logic
98+
if storage_config.schema_type == "experience":
99+
# NOTE: consistent with the old version of experience buffer
100+
self._read_method = self._read_priority
101+
else:
102+
# SFT / DPO uses FIFO style
103+
self._read_method = self._read_fifo
95104

96105
def write(self, data: List[Experience]) -> None:
97106
with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
98107
experience_models = [self.table_model_cls.from_experience(exp) for exp in data]
99108
session.add_all(experience_models)
109+
self.logger.info(f"Write {len(experience_models)} experiences to SQL storage.")
100110

101-
def read(self, batch_size: Optional[int] = None) -> List[Experience]:
102-
if self.stopped:
103-
raise StopIteration()
104-
111+
def _read_fifo(self, batch_size: int) -> List[Experience]:
112+
"""Read experiences in FIFO order."""
105113
exp_list = []
106-
batch_size = batch_size or self.batch_size # type: ignore
107114
start_time = time.time()
108115
while len(exp_list) < batch_size:
109116
if self.stopped:
110117
raise StopIteration()
111-
if len(exp_list):
112-
self.logger.info(f"Waiting for {batch_size - len(exp_list)} more experiences...")
113-
time.sleep(1)
114118
if time.time() - start_time > self.max_timeout:
115119
self.logger.warning(
116120
f"Max read timeout reached ({self.max_timeout} s), only get {len(exp_list)} experiences, stopping..."
@@ -131,8 +135,61 @@ def read(self, batch_size: Optional[int] = None) -> List[Experience]:
131135
self.offset = experiences[-1].id
132136
start_time = time.time()
133137
exp_list.extend([self.table_model_cls.to_experience(exp) for exp in experiences])
138+
if len(exp_list) < batch_size:
139+
self.logger.info(f"Waiting for {batch_size - len(exp_list)} more experiences...")
140+
time.sleep(1)
134141
return exp_list
135142

143+
def _read_priority(self, batch_size: int) -> List[Experience]:
    """Read one batch, preferring the least-consumed and then the newest rows.

    Rows are ordered by ascending ``consumed`` and descending ``id``. A batch
    is returned only when exactly ``batch_size`` rows are available; the
    ``consumed`` counter of every returned row is then incremented, so rows
    stay in the table and can be returned again by later reads.

    Args:
        batch_size: Number of experiences to return.

    Returns:
        The experiences of one full batch (an empty list if ``batch_size``
        is not positive, because the wait loop is never entered).

    Raises:
        StopIteration: If the storage is stopped, or if the number of
            available rows has not changed for more than
            ``self.max_timeout`` seconds while waiting for a full batch.
    """
    exp_list: List[Experience] = []
    start_time = time.time()
    # Row count returned by the most recent query that fell short of a batch.
    latest_size = 0
    while latest_size < batch_size:
        if self.stopped:
            raise StopIteration()
        if time.time() - start_time > self.max_timeout:
            self.logger.warning(
                f"Max read timeout reached ({self.max_timeout} s), only get {latest_size} experiences, stopping..."
            )
            raise StopIteration()
        with retry_session(
            self.session, self.max_retry_times, self.max_retry_interval
        ) as session:
            experiences = (
                session.query(self.table_model_cls)
                .order_by(asc(self.table_model_cls.consumed), desc(self.table_model_cls.id))
                .limit(batch_size)
                .with_for_update()
                .all()
            )
            if len(experiences) == batch_size:
                ids = [exp.id for exp in experiences]
                session.query(self.table_model_cls).filter(
                    self.table_model_cls.id.in_(ids)
                ).update(
                    {self.table_model_cls.consumed: self.table_model_cls.consumed + 1},
                    synchronize_session=False,
                )
                exp_list.extend(
                    [self.table_model_cls.to_experience(exp) for exp in experiences]
                )
                break
            if latest_size != len(experiences):
                # New rows arrived since the last poll: restart the timeout.
                latest_size = len(experiences)
                start_time = time.time()

        # exp_list is still empty at this point, so report progress based on
        # the number of rows currently available rather than on exp_list.
        self.logger.info(f"Waiting for {batch_size - latest_size} more experiences...")
        time.sleep(1)
    return exp_list
185+
186+
def read(self, batch_size: Optional[int] = None) -> List[Experience]:
    """Read a batch using the read strategy stored in ``self._read_method``.

    Args:
        batch_size: Number of experiences to read. ``None`` (or any other
            falsy value) falls back to ``self.batch_size``.

    Returns:
        Whatever ``self._read_method`` returns for the resolved batch size.

    Raises:
        StopIteration: If the storage has already been stopped.
    """
    if self.stopped:
        raise StopIteration()

    effective_size = batch_size if batch_size else self.batch_size
    return self._read_method(effective_size)
192+
136193
@classmethod
137194
def load_from_dataset(
138195
cls, dataset: Dataset, storage_config: StorageConfig, config: BufferConfig
@@ -158,6 +215,8 @@ def load_from_dataset(
158215

159216

160217
class SQLTaskStorage(SQLStorage):
218+
"""Used as explorer input."""
219+
161220
def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
162221
super().__init__(storage_config, config)
163222
self.batch_size = config.batch_size

trinity/buffer/writer/sql_writer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,19 @@ def write(self, data: list) -> None:
2525

2626
async def write_async(self, data):
2727
if self.wrap_in_ray:
28-
await self.db_wrapper.write.remote(data)
28+
ray.get(self.db_wrapper.write.remote(data))
2929
else:
3030
self.db_wrapper.write(data)
3131

3232
async def acquire(self) -> int:
3333
if self.wrap_in_ray:
34-
return await self.db_wrapper.acquire.remote()
34+
return ray.get(self.db_wrapper.acquire.remote())
3535
else:
3636
return 0
3737

3838
async def release(self) -> int:
3939
if self.wrap_in_ray:
40-
return await self.db_wrapper.release.remote()
40+
return ray.get(self.db_wrapper.release.remote())
4141
else:
4242
self.db_wrapper.release()
4343
return 0

trinity/common/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,16 @@ class StorageConfig:
8181
path: Optional[str] = None
8282
repeat_times: Optional[int] = None
8383

84+
# For continuing training
85+
index: int = 0
86+
8487
# used for multi-modal data
8588
mm_data_kwargs: dict = field(default_factory=dict)
8689

8790
# used for StorageType.FILE
8891
split: str = "train"
8992
subset_name: Optional[str] = None
9093
format: FormatConfig = field(default_factory=FormatConfig)
91-
index: int = 0
9294

9395
# used for StorageType.QUEUE
9496
capacity: int = 10000

trinity/explorer/explorer.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626
from trinity.common.models import create_inference_models
2727
from trinity.explorer.scheduler import Scheduler
28-
from trinity.manager.manager import CacheManager
28+
from trinity.manager.state_manager import StateManager
2929
from trinity.manager.synchronizer import Synchronizer
3030
from trinity.utils.log import get_logger
3131
from trinity.utils.monitor import MONITOR, gather_metrics
@@ -38,16 +38,16 @@ class Explorer:
3838
def __init__(self, config: Config):
3939
self.logger = get_logger(config.explorer.name, in_ray_actor=True)
4040
load_plugins()
41-
self.cache = CacheManager(config)
42-
explorer_meta = self.cache.load_explorer()
43-
self.explore_step_num = explorer_meta.get("latest_iteration", 0)
41+
self.state = StateManager(config)
42+
explorer_state = self.state.load_explorer()
43+
self.explore_step_num = explorer_state.get("latest_iteration", 0)
4444
self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1
4545
self.synchronizer = Synchronizer.get_actor(config)
4646
self.config = config
4747
self.algorithm_manager = AlgorithmManager(config)
4848
self.models, self.auxiliary_models = create_inference_models(config)
4949
self.experience_pipeline = self._init_experience_pipeline()
50-
self.config.buffer.explorer_input.taskset.index = explorer_meta.get("latest_task_index", 0)
50+
self.config.buffer.explorer_input.taskset.index = explorer_state.get("latest_task_index", 0)
5151
self.taskset = get_buffer_reader(
5252
self.config.buffer.explorer_input.taskset, self.config.buffer
5353
)
@@ -326,7 +326,7 @@ async def save_checkpoint(self, sync_weight: bool = False) -> None:
326326
)
327327

328328
# save explore checkpoint
329-
self.cache.save_explorer(
329+
self.state.save_explorer(
330330
current_step=self.explore_step_num,
331331
current_task_index=self.explore_step_num * self.config.buffer.batch_size,
332332
)
@@ -345,7 +345,6 @@ async def _finish_steps(self, start_step: int, end_step: int, model_version: int
345345
async def _finish_explore_step(self, step: int, model_version: int) -> None:
346346
statuses, exps = await self.scheduler.get_results(batch_id=step)
347347
metric = {"rollout/model_version": model_version}
348-
# TODO: avoid blocking
349348
pipeline_metrics = await self.experience_pipeline.process.remote(exps)
350349
metric.update(pipeline_metrics)
351350
if statuses:

trinity/manager/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from trinity.manager.manager import CacheManager
1+
from trinity.manager.state_manager import StateManager
22
from trinity.manager.synchronizer import Synchronizer
33

44
__all__ = [
5-
"CacheManager",
5+
"StateManager",
66
"Synchronizer",
77
]

0 commit comments

Comments
 (0)