[ReplayBuffer] add ReplayBuffer with various StorageBackend #1490
base: rl_design
tests/ray/test_replay_buffer.py (new file, +48 lines):

```python
import unittest
import asyncio
from xtuner.v1.rl.base.replay_buffer import ReplayBuffer, StorageIndices, FIFOStorageBackend, StalenessStorageBackend
from xtuner.v1.data_proto.rl_data import RolloutState, Status


class MockState:
    def __init__(self, id, staleness=0):
        self.id = id
        self.seq_staleness = staleness


class TestReplayBuffer(unittest.IsolatedAsyncioTestCase):
    async def test_fifo_backend(self):
        backend = FIFOStorageBackend()
        buffer = ReplayBuffer(storage_backend=backend)
        states = [MockState(i) for i in range(1, 4)]

        await buffer.put(states, "task1", Status.COMPLETED)
        res = await buffer.get(2, "task1", Status.COMPLETED)

        self.assertEqual(len(res), 2)
        self.assertEqual(res[0].id, 1)
        self.assertEqual(res[1].id, 2)

    async def test_staleness_priority(self):
        backend = StalenessStorageBackend(min_staleness=0, max_staleness=5)
        buffer = ReplayBuffer(storage_backend=backend)

        s1 = MockState(id="low", staleness=1)
        s5 = MockState(id="high", staleness=5)

        await buffer.put([s1], "task1", Status.COMPLETED)
        await buffer.put([s5], "task1", Status.COMPLETED)

        res = await buffer.get(2, "task1", Status.COMPLETED)
        self.assertEqual(res[0].id, "high")
        self.assertEqual(res[1].id, "low")

    async def test_multi_task(self):
        buffer = ReplayBuffer()
        await buffer.put([MockState(100)], "task_a", Status.COMPLETED)
        await buffer.put([MockState(200)], "task_b", Status.COMPLETED)

        res_a = await buffer.get(10, "task_a", Status.COMPLETED)
        res_b = await buffer.get(10, "task_b", Status.COMPLETED)
        self.assertEqual(len(res_a), 1)
        self.assertEqual(res_a[0].id, 100)
        self.assertEqual(len(res_b), 1)
        self.assertEqual(res_b[0].id, 200)
```
xtuner/v1/rl/base/replay_buffer.py (new file, +222 lines, excerpt):

```python
import asyncio
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from dataclasses import dataclass, field

from xtuner.v1.data_proto.rl_data import RolloutState, Status


@dataclass
class StorageIndices:
    # 为不同存储后段提供统一的索引接口
    # (English: provides a unified index interface for different storage backends)
```

Copilot suggested fixing a typo in the comment (后段 → 后端, "backend"):

```diff
- # 为不同存储后段提供统一的索引接口
+ # 为不同存储后端提供统一的索引接口
```
Copilot AI, Feb 12, 2026:
The tags-based partitioning logic is part of the public ReplayBuffer.put/get API (via **kwargs), but there are no tests asserting that different tag values map to different storage partitions and don’t mix. Consider adding a small test that writes items with different tag combinations and verifies isolation.
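One way to exercise that isolation is sketched below with a minimal stand-in buffer rather than the real ReplayBuffer (whose put/get signature and **kwargs tag handling are assumptions from the review comment, not confirmed API):

```python
from collections import defaultdict, deque

class TagPartitionedBuffer:
    """Minimal illustrative stand-in for a tags-partitioned storage backend."""

    def __init__(self):
        self._partitions = defaultdict(deque)

    def _key(self, **tags):
        # Sort tag items so {"a": 1, "b": 2} and {"b": 2, "a": 1} share a partition.
        return tuple(sorted(tags.items()))

    def put(self, items, **tags):
        self._partitions[self._key(**tags)].extend(items)

    def get(self, count, **tags):
        part = self._partitions[self._key(**tags)]
        return [part.popleft() for _ in range(min(count, len(part)))]

# The isolation property the missing test should assert:
buf = TagPartitionedBuffer()
buf.put(["a1", "a2"], env="math", difficulty="hard")
buf.put(["b1"], env="math", difficulty="easy")

assert buf.get(10, env="math", difficulty="hard") == ["a1", "a2"]
assert buf.get(10, env="math", difficulty="easy") == ["b1"]
# A tag combination never written to is empty, not a mix of the others.
assert buf.get(10, env="code", difficulty="hard") == []
```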
Copilot AI, Feb 12, 2026:
_hash_storage_indices builds a tuple used as a dict key; if any tag value is unhashable (e.g., list/dict), this will raise TypeError at runtime. Since tags come from **kwargs, consider validating/coercing tag values to hashable types (e.g., str(value)/json.dumps) or restricting the accepted tag value types.
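A minimal sketch of such coercion (the helper name and its placement are hypothetical, not part of the PR):

```python
import json

def coerce_hashable(value):
    """Coerce an arbitrary tag value to a hashable key component.

    Unhashable values (lists, dicts) are serialized to a canonical JSON
    string; already-hashable values pass through unchanged.
    """
    try:
        hash(value)
        return value
    except TypeError:
        # sort_keys makes dict ordering irrelevant to the resulting key.
        return json.dumps(value, sort_keys=True, default=str)

# Unhashable tag values now produce stable dict keys instead of TypeError:
key = tuple(coerce_hashable(v) for v in ["task1", {"b": 2, "a": 1}, [3, 1]])
assert key == ("task1", '{"a": 1, "b": 2}', "[3, 1]")
assert {key: "bucket"}  # usable as a dict key without raising
```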
Copilot AI, Feb 12, 2026:
StalenessStorageBackend.__init__ accepts a limit parameter but it is never enforced (items can grow unbounded). Either implement eviction behavior consistent with FIFOStorageBackend(limit=...) or remove the parameter to avoid misleading API.
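A sketch of what enforcing the limit could look like, using `collections.deque(maxlen=...)` for oldest-first eviction; the class and method names below are illustrative stand-ins, not the PR's actual code:

```python
from collections import defaultdict, deque

class BoundedFIFOStorage:
    """Illustrative backend that enforces `limit` per partition.

    deque(maxlen=limit) silently drops the oldest item on append once the
    partition is full, i.e. FIFO eviction.
    """

    def __init__(self, limit=None):
        self.limit = limit
        self._storage = defaultdict(lambda: deque(maxlen=limit))

    def put(self, items, key):
        self._storage[key].extend(items)

    def get(self, count, key):
        part = self._storage[key]
        return [part.popleft() for _ in range(min(count, len(part)))]

store = BoundedFIFOStorage(limit=3)
store.put([1, 2, 3, 4, 5], key="task1")   # 1 and 2 are evicted
assert store.get(10, key="task1") == [3, 4, 5]
```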
Copilot AI, Feb 12, 2026:
StalenessStorageBackend.put will crash if items is empty (max() on empty list). Also, if any seq_staleness falls outside [min_staleness, max_staleness] (defaults are both 0), self._storage[indices][group_seq_staleness] will raise KeyError. Consider explicitly handling empty input (no-op or clear error) and validating/clamping seq_staleness to the configured bucket range (or dynamically creating buckets).
Suggested change:

```diff
- indices = self._hash_storage_indices(storage_indices)
- group_seq_staleness = max([item.seq_staleness for item in items])
+ # If there are no items, treat this as a no-op to avoid max() on an empty list.
+ if not items:
+     return
+ indices = self._hash_storage_indices(storage_indices)
+ group_seq_staleness = max(item.seq_staleness for item in items)
+ # Clamp staleness into the configured bucket range to avoid KeyError.
+ group_seq_staleness = max(self.min_staleness,
+                           min(self.max_staleness, group_seq_staleness))
```
Copilot AI, Feb 12, 2026:
Same as FIFO backend: _hash_storage_indices returns a tuple used as a dict key, so unhashable tag values (e.g., list/dict) will raise TypeError. Since this backend also accepts tags via StorageIndices, consider validating/coercing tag values or restricting tag types.
Copilot AI, Feb 12, 2026:
This statement is unreachable.
Copilot AI, Feb 12, 2026:
PandasStorageBackend.get is defined as a synchronous method, but the StorageBackend interface defines async def get(...). Even though the class currently raises in __init__, keeping signatures consistent will prevent accidental misuse later and avoids confusing API expectations (callers will await this).
Suggested change:

```diff
- def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
+ async def get(self, count: int, indices: StorageIndices) -> list[RolloutState]:
```
Copilot AI, Feb 12, 2026:
This statement is unreachable.
Copilot AI, Feb 12, 2026:
SQLStorageBackend.get builds SQL using f-strings for both the JSON path ($.{key}) and LIMIT {count}. If key is user-controlled (it comes from indices.tags), this is a SQL injection risk once this backend is implemented. Prefer validating key against an allowlist/regex and using parameter binding for LIMIT (and avoid interpolating raw values into the query string).
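A sketch of that hardening against an in-memory SQLite table: the `buffer`/`payload`/`tags` schema is hypothetical, and the example assumes the bundled SQLite exposes the JSON1 `json_extract` function. Only an allowlist-validated key is interpolated; the value and LIMIT are bound as parameters:

```python
import re
import sqlite3

_TAG_KEY_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")

def build_tag_query(key: str) -> str:
    """Validate `key` as a plain identifier before it reaches the JSON path."""
    if not _TAG_KEY_RE.fullmatch(key):
        raise ValueError(f"invalid tag key: {key!r}")
    # Value and LIMIT use ? placeholders; only the validated key is interpolated.
    return f"SELECT payload FROM buffer WHERE json_extract(tags, '$.{key}') = ? LIMIT ?"

# A malicious key is rejected instead of being spliced into the SQL string:
try:
    build_tag_query("task') OR 1=1 --")
    raise AssertionError("should have been rejected")
except ValueError:
    pass

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE buffer (payload TEXT, tags TEXT)")
conn.execute("INSERT INTO buffer VALUES ('s1', '{\"task\": \"a\"}')")
conn.execute("INSERT INTO buffer VALUES ('s2', '{\"task\": \"b\"}')")
rows = conn.execute(build_tag_query("task"), ("a", 10)).fetchall()
assert rows == [("s1",)]
```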
Copilot AI, Feb 12, 2026:
ReplayBuffer.__init__ annotates storage_backend as StorageBackend but defaults it to None, which will fail type checking under mypy’s strict optional rules. Please change the annotation to StorageBackend | None (or Optional[StorageBackend]).
Suggested change:

```diff
- def __init__(self, storage_backend: StorageBackend = None):
+ def __init__(self, storage_backend: StorageBackend | None = None):
```
`tests/ray/test_replay_buffer.py` has several unused imports (`asyncio`, `StorageIndices`, `RolloutState`). Cleaning these up avoids confusion about what the tests actually exercise.