Refactor replay buffer to use KV buffer #147
Conversation
Overall LGTM.
But we can think more about how eviction should work in the replay buffer - which is, of course, outside the scope of this PR.
src/forge/actors/replay_buffer.py
Outdated
await self._add(episode)

async def _add(self, episode) -> None:
    key = f"rb_ep_{await self.store.numel()}_{uuid.uuid4().hex}"
What's the point of `await self.store.numel()`? Also, this may be expensive.
Re: recovery and determinism, maybe you could use uuid5 or something like highway hash. But I am not sure how important determinism is.
Maybe you can add a counter to the ReplayBuffer class, then derive the key using uuid5 and/or highway hash and/or your favorite hash from the following three pieces of information:
- the counter
- the rank of the current worker
- the content of the value

This will generally avoid duplicate keys even if you have episodes with the same content coming in (see the sketch below).
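A minimal sketch of that scheme, assuming the buffer keeps a monotonically increasing counter and knows its worker rank (all names below are illustrative, not code from this PR):

```python
import uuid

# Fixed namespace for uuid5; any constant UUID works.
_KEY_NAMESPACE = uuid.UUID("12345678-1234-5678-1234-567812345678")

def derive_key(counter: int, rank: int, value_bytes: bytes) -> str:
    # Deterministic: the same (counter, rank, content) always yields the same
    # key, while counter + rank keep identical contents from colliding.
    name = f"{counter}:{rank}:{value_bytes.hex()}"
    return f"rb_ep_{uuid.uuid5(_KEY_NAMESPACE, name).hex}"

# Example: key for the 7th episode produced by worker rank 0.
key = derive_key(counter=7, rank=0, value_bytes=b"serialized episode")
```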
I don't think we need determinism at this stage. Let's keep things simple. I have dropped the `await self.store.numel()` to make things more efficient. Thanks for pointing this out.
src/forge/actors/replay_buffer.py
Outdated
keys = await self.store.keys()

# TODO: Make this more efficient
@joecummings Do we still need this TODO?
pass


class StoreInterface(ABC):
For now, pare this down to the exact APIs we will be using in the ReplayBuffer - no more, no less. We can always update the interface later.
I’ve already pared the interface down to only the APIs we need. My concern is that methods like `numel` and `delete` are essential for the replay buffer’s functionality (e.g., eviction, checking buffer size) but aren’t yet implemented in TorchStore. If we remove these from the interface, the buffer implementation won’t be able to operate consistently.
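For reference, a sketch of what a pared-down interface limited to those operations could look like (illustrative only; the exact signatures in the PR may differ):

```python
from abc import ABC, abstractmethod
from typing import Any

class StoreInterface(ABC):
    """Minimal async KV-store contract used by the replay buffer."""

    @abstractmethod
    async def put(self, key: str, value: Any) -> None: ...

    @abstractmethod
    async def get(self, key: str) -> Any: ...

    @abstractmethod
    async def keys(self) -> list[str]: ...

    @abstractmethod
    async def delete(self, key: str) -> None: ...

    @abstractmethod
    async def numel(self) -> int: ...
```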
src/forge/actors/replay_buffer.py
Outdated
async def _evict(self, curr_policy_version: int) -> None:
    keys = await self.store.keys()
    for key in keys:
        episode = await self.store.get(key)
This is fine for now, but leave a comment that we could store each key as a uuid + the policy version and make this more efficient.
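A sketch of that idea, with a hypothetical key format and helper names (not code from this PR):

```python
import uuid

def make_key(policy_version: int) -> str:
    # Encode the policy version in the key itself, e.g. "rb_ep_v3_<hex>".
    return f"rb_ep_v{policy_version}_{uuid.uuid4().hex}"

def key_policy_version(key: str) -> int:
    # "rb_ep_v<version>_<hex>" -> <version>
    return int(key.split("_")[2].lstrip("v"))

async def evict(store, curr_policy_version: int, max_policy_age: int) -> None:
    # Eviction can now filter on key strings alone, without a store.get() per episode.
    for key in await store.keys():
        if curr_policy_version - key_policy_version(key) > max_policy_age:
            await store.delete(key)
```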
Yes. Added. Once torchstore supports fetching by prefix, this will be much easier.
Coming soon!
    for trajectory in self.buffer
    if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age
]

async def _evict(self, curr_policy_version: int) -> None:
We control this internal method so you can pass in the keys from above, which we already calculated.
Actually, we have to re-fetch the keys after eviction because `_evict` may delete some entries. We fetch once before `_evict` to know what to check for eviction, then fetch again after to ensure we only sample from the remaining keys. This prevents trying to access keys that no longer exist.
I think we need to reconsider the whole eviction logic. See my previous point re: concurrency.
@endpoint
async def setup(self) -> None:
    self.buffer: list = []
def __post_init__(self):
@joecummings I changed `setup` to `__post_init__` because I found that `setup` is not called in many of the scripts where we use ReplayBuffer, and if it is not called, things may not be initialized correctly (e.g., the sampler). Let me know if this causes any concerns.
LGTM, Approval 🚀
Note: This is by no means criticism, but I want to point out that the logic here is strictly incorrect when run concurrently (see comments); we can leave that to a future PR.
src/forge/actors/replay_buffer.py
Outdated
total_samples = self.dp_size * bsz

# Evict old episodes
self._evict(curr_policy_version)
await self._evict(curr_policy_version)

if total_samples > len(self.buffer):
total_available = await self.store.numel()
if total_samples > total_available:
    return None

keys = await self.store.keys()

# TODO: Make this more efficient
As a general comment: calling `_evict()` before getting keys is not a reliable way to ensure we don't get outdated policies, since we have several await points between `_evict()` and `keys()`. Unless you want to put an async lock on self.store, which you probably don't. This is beyond the scope of this PR, though; please add a TODO here. cc @joecummings
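For illustration, a minimal sketch of the async-lock option mentioned above (names are assumptions; this is not the PR's implementation):

```python
import asyncio

# One lock guarding the store: holding it across eviction and key listing
# means no other task can add or delete entries between the two awaits.
_store_lock = asyncio.Lock()

async def evict_then_list_keys(store, evict, curr_policy_version: int) -> list[str]:
    async with _store_lock:
        await evict(curr_policy_version)
        return await store.keys()
```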
Even if nothing is concurrent at all at this point, I think we should at least keep in mind that we will need to support concurrency in the very near future. Also, it's not necessarily harder to write a concurrently correct program, albeit we do need to be more careful.
Good point. Will add this TODO before landing.
Added
src/forge/actors/replay_buffer.py
Outdated
class ReplayBuffer(ForgeActor):
    """Simple in-memory replay buffer implementation."""

    store: StoreInterface
nit: maybe call this the backend?
wdyt @LucasLLC ?
Changed `store` -> `backend`.
src/forge/actors/replay_buffer.py
Outdated
@endpoint
async def state_dict(self) -> dict[str, Any]:
    keys = await self.store.keys()
    episodes = [(k, await self.store.get(k)) for k in keys]
This is not ideal IMO - is there a way we could dump / serialize the contents of the store?
This is really cool! Transparently, my concern is that this adds a lot of complexity, which is risky when we don't have multi-node e2e running yet.
Timing-wise, it'll be better to shelve this for now and revisit it once we have something running (maybe even in the next two weeks). We'll inevitably hit bottlenecks, but that'll motivate our longer-term design and implementation.
A simple single-node key-value (KV) store implementation of StoreInterface.
This acts as a temporary backend for the replay buffer until torchstore
supports the full set of operations we need (delete, pop, keys, numel, etc.).
Can you explain the consistency semantics and the thread safety of the interface? For example: once a `put` finishes, will any future `get` always see the effect of that `put`?
The current `KVStore` implementation is a simple in-memory dictionary with async methods, intended as a temporary backend until `torchstore` is ready. It does not provide thread safety or strong consistency guarantees in the presence of concurrent access. Specifically:
- If multiple coroutines access or modify the store concurrently, race conditions may occur (e.g., a `get` may see stale or missing data if a `delete` or `put` happens at the same time).
- In a single-threaded `asyncio` event loop, as long as each operation is awaited, the store behaves as expected: once a `put` completes, subsequent `get`s will see the new value.
- However, if the store is accessed from multiple threads, or if multiple async tasks interleave operations without awaiting, consistency is not guaranteed.

The plan is to switch to torchstore once the key APIs like `delete` and `numel` are ready, which should provide proper concurrency and consistency guarantees.
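A minimal sketch of the kind of in-memory backend described above (illustrative; not necessarily the exact `KVStore` implementation in the PR):

```python
from typing import Any

class KVStore:
    """Single-process, in-memory KV store; no locking, no cross-thread safety."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    async def put(self, key: str, value: Any) -> None:
        self._data[key] = value

    async def get(self, key: str) -> Any:
        return self._data[key]

    async def delete(self, key: str) -> None:
        del self._data[key]

    async def keys(self) -> list[str]:
        return list(self._data)

    async def numel(self) -> int:
        return len(self._data)
```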
This PR refactors the replay buffer storage to a KV-based data structure. We can’t integrate torchstore yet since the necessary APIs aren’t implemented (numel, delete, etc.), but this refactor will let us switch the backend easily once torchstore is ready.
- `StoreInterface` added for KV-store abstraction and further integration of torchstore
- `KVStore` added as a temporary KV-store backend
- Tests for `KVStore` added in `test_kv_store.py`
- `ReplayBuffer` refactored to use `StoreInterface` instead of a local list
- `test_replay_buffer.py` and `test_toy_rl.py` updated accordingly
- `ReplayBuffer` usage updated in `apps.grpo.main`, `apps.rl.main`, `apps.toy_rl.main`
Test