48 changes: 48 additions & 0 deletions tests/ray/test_replay_buffer.py
@@ -0,0 +1,48 @@
import unittest
import asyncio
from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, StorageIndices, FIFOStorageBackend, StalenessStorageBackend
from xtuner.v1.data_proto.rl_data import RolloutState, Status
Comment on lines +2 to +4
Copilot AI Feb 12, 2026
tests/ray/test_replay_buffer.py has several unused imports (asyncio, StorageIndices, RolloutState). Cleaning these up avoids confusion about what the tests actually exercise.

Suggested change
- import asyncio
- from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, StorageIndices, FIFOStorageBackend, StalenessStorageBackend
- from xtuner.v1.data_proto.rl_data import RolloutState, Status
+ from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, FIFOStorageBackend, StalenessStorageBackend
+ from xtuner.v1.data_proto.rl_data import Status


class MockState:
def __init__(self, id, staleness=0):
self.id = id
self.seq_staleness = staleness

class TestReplayBuffer(unittest.IsolatedAsyncioTestCase):
async def test_fifo_backend(self):
backend = FIFOStorageBackend()
buffer = ReplayBuffer(storage_backend=backend)
states = [MockState(i) for i in range(1, 4)]

await buffer.put(states, "task1", Status.COMPLETED)
res = await buffer.get(2, "task1", Status.COMPLETED)

self.assertEqual(len(res), 2)
self.assertEqual(res[0].id, 1)
self.assertEqual(res[1].id, 2)

async def test_staleness_priority(self):
backend = StalenessStorageBackend(min_staleness=0, max_staleness=5)
buffer = ReplayBuffer(storage_backend=backend)

s1 = MockState(id="low", staleness=1)
s5 = MockState(id="high", staleness=5)

await buffer.put([s1], "task1", Status.COMPLETED)
await buffer.put([s5], "task1", Status.COMPLETED)

res = await buffer.get(2, "task1", Status.COMPLETED)
self.assertEqual(res[0].id, "high")
self.assertEqual(res[1].id, "low")

async def test_multi_task(self):
buffer = ReplayBuffer()
await buffer.put([MockState(100)], "task_a", Status.COMPLETED)
await buffer.put([MockState(200)], "task_b", Status.COMPLETED)

res_a = await buffer.get(10, "task_a", Status.COMPLETED)
res_b = await buffer.get(10, "task_b", Status.COMPLETED)
self.assertEqual(len(res_a), 1)
self.assertEqual(res_a[0].id, 100)
self.assertEqual(len(res_b), 1)
self.assertEqual(res_b[0].id, 200)
222 changes: 222 additions & 0 deletions xtuner/v1/rl/base/replay_buffer.py
@@ -0,0 +1,222 @@
import asyncio
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from dataclasses import dataclass, field

from xtuner.v1.data_proto.rl_data import RolloutState, Status


@dataclass
class StorageIndices:
# 为不同存储后段提供统一的索引接口 (provides a unified index interface for different storage backends)
Copilot AI Feb 12, 2026
StorageIndices doc comment has a typo: “存储后段” (“storage back segment”) should be “存储后端” (“storage backend”).

Suggested change
- # 为不同存储后段提供统一的索引接口
+ # 为不同存储后端提供统一的索引接口

task_name: str | None = None
group_status: Status | None = None
tags: dict = field(default_factory=dict)  # for non-equality conditions, use e.g. scores_gt > 0.8

Comment on lines +14 to +15
Copilot AI Feb 12, 2026
The tags-based partitioning logic is part of the public ReplayBuffer.put/get API (via **kwargs), but there are no tests asserting that different tag values map to different storage partitions and don’t mix. Consider adding a small test that writes items with different tag combinations and verifies isolation.

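A minimal sketch of such a test for tests/ray/test_replay_buffer.py, reusing the MockState helper defined there (the `source` tag key is a hypothetical example):

    async def test_tag_partition_isolation(self):
        buffer = ReplayBuffer()
        # Same task and status, different tag values: items must not mix.
        await buffer.put([MockState(1)], "task1", Status.COMPLETED, source="web")
        await buffer.put([MockState(2)], "task1", Status.COMPLETED, source="synthetic")

        res_web = await buffer.get(10, "task1", Status.COMPLETED, source="web")
        res_syn = await buffer.get(10, "task1", Status.COMPLETED, source="synthetic")
        self.assertEqual([s.id for s in res_web], [1])
        self.assertEqual([s.id for s in res_syn], [2])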

class StorageBackend(ABC):
@abstractmethod
async def put(self, items: list[RolloutState], storage_indices: StorageIndices): ...
@abstractmethod
async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]: ...
@abstractmethod
def __len__(self): ...


class FIFOStorageBackend(StorageBackend):
# Plain FIFO: items are dropped once consumed, nothing is persisted; synchronous access should be enough for now
def __init__(self, limit: int = 0):
self.limit = limit
if limit > 0:
self._storage = defaultdict(lambda: deque(maxlen=limit))
else:
self._storage = defaultdict(deque)

async def put(self, items: list[RolloutState], storage_indices: StorageIndices):
indices = self._hash_storage_indices(storage_indices)
self._storage[indices].extend(items)

async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]:
indices = self._hash_storage_indices(storage_indices)
target_count = min(count, len(self._storage[indices]))
target_items = []
for _ in range(target_count):
target_items.append(self._storage[indices].popleft())
return target_items

def _hash_storage_indices(self, indices: StorageIndices) -> tuple:
base = (indices.task_name, indices.group_status)

if indices.tags:
sorted_tags = tuple(sorted(indices.tags.items()))
return base + sorted_tags
return base
Comment on lines +47 to +53
Copilot AI Feb 12, 2026
_hash_storage_indices builds a tuple used as a dict key; if any tag value is unhashable (e.g., list/dict), this will raise TypeError at runtime. Since tags come from **kwargs, consider validating/coercing tag values to hashable types (e.g., str(value)/json.dumps) or restricting the accepted tag value types.

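A minimal sketch of one way to harden the key, assuming repr-based coercion is an acceptable partitioning rule (this variant is illustrative, not part of the PR):

    def _hash_storage_indices(self, indices: StorageIndices) -> tuple:
        base = (indices.task_name, indices.group_status)
        if indices.tags:
            # Coerce each tag value to a stable, hashable form; repr() keeps
            # int 1 and str "1" distinct, unlike str().
            sorted_tags = tuple((k, repr(v)) for k, v in sorted(indices.tags.items()))
            return base + sorted_tags
        return base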

def __len__(self):
return sum(len(q) for q in self._storage.values())


class StalenessStorageBackend(StorageBackend):
# Async replay buffer implementation for xtuner v1; likewise not persisted
# TODO(@duanyanhui): status transitions between completed/aborted/expired are not implemented yet; decide where that should happen
def __init__(self, limit: int = 0, max_staleness: int = 0, min_staleness: int = 0):
self.limit = limit
self.max_staleness = max_staleness
self.min_staleness = min_staleness
self._storage = defaultdict(lambda: {i: deque() for i in range(min_staleness, max_staleness + 1)})
self._bucket_counts = defaultdict(int)
Comment on lines +62 to +67
Copilot AI Feb 12, 2026
StalenessStorageBackend.__init__ accepts a limit parameter but it is never enforced (items can grow unbounded). Either implement eviction behavior consistent with FIFOStorageBackend(limit=...) or remove the parameter to avoid misleading API.

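A minimal sketch of limit enforcement that mirrors FIFO(maxlen)'s drop-oldest behavior by evicting the highest-staleness buckets first (one possible policy, not part of the PR; it would be called at the end of put):

    def _evict_if_needed(self, indices: tuple) -> None:
        # Once a partition exceeds `limit`, drop the oldest data, i.e. the
        # highest-staleness buckets, until back under the cap.
        if self.limit <= 0:
            return
        for s in range(self.max_staleness, self.min_staleness - 1, -1):
            bucket = self._storage[indices][s]
            while self._bucket_counts[indices] > self.limit and bucket:
                bucket.popleft()
                self._bucket_counts[indices] -= 1
            if self._bucket_counts[indices] <= self.limit:
                return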

async def put(self, items: list[RolloutState], storage_indices: StorageIndices):
indices = self._hash_storage_indices(storage_indices)
group_seq_staleness = max([item.seq_staleness for item in items])
Comment on lines +70 to +71
Copilot AI Feb 12, 2026
StalenessStorageBackend.put will crash if items is empty (max() on empty list). Also, if any seq_staleness falls outside [min_staleness, max_staleness] (defaults are both 0), self._storage[indices][group_seq_staleness] will raise KeyError. Consider explicitly handling empty input (no-op or clear error) and validating/clamping seq_staleness to the configured bucket range (or dynamically creating buckets).

Suggested change
- indices = self._hash_storage_indices(storage_indices)
- group_seq_staleness = max([item.seq_staleness for item in items])
+ # If there are no items, treat this as a no-op to avoid max() on an empty list.
+ if not items:
+     return
+ indices = self._hash_storage_indices(storage_indices)
+ group_seq_staleness = max(item.seq_staleness for item in items)
+ # Clamp staleness into the configured bucket range to avoid KeyError.
+ group_seq_staleness = max(self.min_staleness,
+                           min(self.max_staleness, group_seq_staleness))
self._storage[indices][group_seq_staleness].extend(items)
self._bucket_counts[indices] += len(items)

async def get(self, count: int, storage_indices: StorageIndices) -> list[RolloutState]:
indices = self._hash_storage_indices(storage_indices)
if self._bucket_counts[indices] == 0:
return []

target_items = []
needed = count

for s in range(self.max_staleness, self.min_staleness - 1, -1):
if needed <= 0:
break
cur_bucket = self._storage[indices][s]
take = min(len(cur_bucket), needed)
for _ in range(take):
target_items.append(cur_bucket.popleft())
self._bucket_counts[indices] -= take
needed -= take
return target_items

def _hash_storage_indices(self, indices: StorageIndices) -> tuple:
base = (indices.task_name, indices.group_status)

if indices.tags:
sorted_tags = tuple(sorted(indices.tags.items()))
return base + sorted_tags
return base
Comment on lines +94 to +100
Copilot AI Feb 12, 2026
Same as FIFO backend: _hash_storage_indices returns a tuple used as a dict key, so unhashable tag values (e.g., list/dict) will raise TypeError. Since this backend also accepts tags via StorageIndices, consider validating/coercing tag values or restricting tag types.


def __len__(self):
return sum(count for count in self._bucket_counts.values())


class PandasStorageBackend(StorageBackend):
def __init__(self, limit: int = 0):
raise NotImplementedError("PandasStorageBackend is under development and not yet implemented.")
import pandas as pd
Copilot AI Feb 12, 2026
This statement is unreachable.


self._df = pd.DataFrame(columns=["task_name", "group_status", "data"])

def __len__(self): ...
async def put(self, items: list[RolloutState], indices: StorageIndices):
import pandas as pd

new_rows = []
base_info = {"task_name": indices.task_name, "group_status": indices.group_status, **indices.tags}

for item in items:
row = base_info.copy()
row["data"] = item
new_rows.append(row)

new_df = pd.DataFrame(new_rows)
self._df = pd.concat([self._df, new_df], ignore_index=True, sort=False)

def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
Copilot AI Feb 12, 2026
PandasStorageBackend.get is defined as a synchronous method, but the StorageBackend interface defines async def get(...). Even though the class currently raises in __init__, keeping signatures consistent will prevent accidental misuse later and avoids confusing API expectations (callers will await this).

Suggested change
- def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
+ async def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:

if self._df.empty:
return []
mask = (self._df["task_name"] == indices.task_name) & (self._df["group_status"] == indices.group_status)
for key, value in indices.tags.items():
if key in self._df.columns:
mask &= self._df[key] == value
else:
return []
target_df = self._df[mask].head(count)
if target_df.empty:
return []
result = target_df["data"].tolist()
self._df.drop(target_df.index, inplace=True)
return result


class SQLStorageBackend(StorageBackend):
def __init__(self, db_path: str = ":memory:"):
raise NotImplementedError("SQLStorageBackend is under development and not yet implemented.")
self.db_path = db_path
Copilot AI Feb 12, 2026
This statement is unreachable.

self._init_db()

def _init_db(self): ...
def _serialize_item(self, item: RolloutState) -> bytes: ...
def _deserialize_item(self, blob: bytes) -> RolloutState: ...
def __len__(self): ...

async def put(self, items: list[RolloutState], indices: StorageIndices):
import json
import sqlite3

rows = []
tags_json = json.dumps(indices.tags)

for item in items:
data_blob = self._serialize_item(item)
rows.append((indices.task_name, indices.group_status, tags_json, data_blob))

with sqlite3.connect(self.db_path) as conn:
conn.executemany(
"INSERT INTO replay_buffer (task_name, group_status, tags, data) VALUES (?, ?, ?, ?)", rows
)

async def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
import sqlite3

# Build the query dynamically
query = "SELECT id, data FROM replay_buffer WHERE task_name = ? AND group_status = ?"
params = [indices.task_name, indices.group_status]

# SQLite's JSON query syntax (requires SQLite 3.38+; on older versions, emulate with LIKE or skip DB-level filtering).
# Trade-off: filtering tags on the Python side is inefficient, while JSON filtering on the SQL side has more complex syntax.
# For generality, one option is to query candidates by task and status only, then filter tags in Python (if tags get complex, promote them to dedicated columns).
# Alternatively, use JSON_EXTRACT (recommended, and what is done below).
for key, value in indices.tags.items():
# Note: JSON distinguishes numeric and string values; values are assumed to be simple types here.
# $.key extracts the value for that key.
query += f" AND json_extract(tags, '$.{key}') = ?"
params.append(value)

query += f" LIMIT {count}"

Comment on lines +175 to +190
Copilot AI Feb 12, 2026
SQLStorageBackend.get builds SQL using f-strings for both the JSON path ($.{key}) and LIMIT {count}. If key is user-controlled (it comes from indices.tags), this is a SQL injection risk once this backend is implemented. Prefer validating key against an allowlist/regex and using parameter binding for LIMIT (and avoid interpolating raw values into the query string).

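A minimal sketch of hardened query construction, assuming a regex allowlist for tag keys is acceptable (illustrative only, not part of the PR):

    import re

    _TAG_KEY_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")

    for key, value in indices.tags.items():
        # Reject keys that are not plain identifiers before interpolating
        # them into the JSON path; values remain parameter-bound.
        if not _TAG_KEY_RE.match(key):
            raise ValueError(f"invalid tag key: {key!r}")
        query += f" AND json_extract(tags, '$.{key}') = ?"
        params.append(value)

    # Bind LIMIT as a parameter instead of interpolating count into the SQL.
    query += " LIMIT ?"
    params.append(count)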
results = []
ids_to_delete = []

with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute(query, params)
rows = cursor.fetchall()

for row_id, data_blob in rows:
results.append(self._deserialize_item(data_blob))
ids_to_delete.append(row_id)

if ids_to_delete:
placeholders = ",".join("?" for _ in ids_to_delete)
conn.execute(f"DELETE FROM replay_buffer WHERE id IN ({placeholders})", ids_to_delete)

return results


class ReplayBuffer:
def __init__(self, storage_backend: StorageBackend = None):
Copilot AI Feb 12, 2026
ReplayBuffer.__init__ annotates storage_backend as StorageBackend but defaults it to None, which will fail type checking under mypy’s strict optional rules. Please change the annotation to StorageBackend | None (or Optional[StorageBackend]).

Suggested change
- def __init__(self, storage_backend: StorageBackend = None):
+ def __init__(self, storage_backend: StorageBackend | None = None):

self._storage = FIFOStorageBackend() if storage_backend is None else storage_backend
self._lock = asyncio.Lock()

async def put(self, items: list[RolloutState], task_name: str, group_status: Status, **kwargs) -> None:
indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs)
async with self._lock:
await self._storage.put(items, indices)

async def get(self, batch_size: int, task_name: str, group_status: Status, **kwargs) -> list[RolloutState]:
indices = StorageIndices(task_name=task_name, group_status=group_status, tags=kwargs)
async with self._lock:
return await self._storage.get(batch_size, indices)
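For reference, a minimal usage sketch of the ReplayBuffer API above; the task name "math_task" and the `source` tag are hypothetical, and MockState (from the tests) stands in for RolloutState:

    import asyncio

    async def main():
        buffer = ReplayBuffer(storage_backend=FIFOStorageBackend(limit=1024))
        states = [MockState(i) for i in range(8)]
        # Tags passed as **kwargs further partition the storage.
        await buffer.put(states, "math_task", Status.COMPLETED, source="web")
        batch = await buffer.get(4, "math_task", Status.COMPLETED, source="web")
        assert [s.id for s in batch] == [0, 1, 2, 3]

    asyncio.run(main())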