
Conversation

@DNXie (Member) commented Sep 11, 2025

This PR refactors the replay buffer storage into a KV-based data structure. We can't integrate torchstore yet because the necessary APIs (numel, delete, etc.) aren't implemented, so this refactor makes it easy to switch the backend once torchstore is ready.

  • Added StoreInterface as a KV store abstraction to ease future integration of torchstore
  • Implemented KVStore as a temporary KV store backend
  • Added unit tests for KVStore in test_kv_store.py
  • Refactored ReplayBuffer to use StoreInterface instead of a local list
  • Updated test_replay_buffer.py and test_toy_rl.py accordingly
  • Updated all callsites of ReplayBuffer in apps.grpo.main, apps.rl.main, and apps.toy_rl.main

Test

pytest tests/unit_tests/test_kv_store.py
pytest tests/unit_tests/test_replay_buffer.py
pytest tests/unit_tests/rl/test_toy_rl.py
python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
python -m apps.rl.main --config apps/rl/llama3_8b.yaml
python -m apps.toy_rl.main

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) Sep 11, 2025
@casteryh (Contributor) left a comment:

Overall LGTM.
But we can think more about how eviction should work in the replay buffer, which is of course outside the scope of this PR.

await self._add(episode)

async def _add(self, episode) -> None:
    key = f"rb_ep_{await self.store.numel()}_{uuid.uuid4().hex}"
Contributor:

What's the point of await self.store.numel()?
Also this may be expensive.

Contributor:

Re: recovery and determinism, maybe you could use uuid5 or something like HighwayHash. But I am not sure how important determinism is.

Contributor:

Maybe you can add a counter in the ReplayBuffer class.

Then derive the key using uuid5 and/or HighwayHash and/or your favorite hash, from the following 3 pieces of information:

  • the counter
  • the rank of the current worker
  • the content of the value.

This will generally avoid duplicate keys even if you have episodes with the same content coming in.
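For illustration, a minimal sketch of that key-derivation idea (the helper name, the sha256 choice, and the rb_ep_ prefix are assumptions, not part of the PR):

import hashlib
import uuid

def make_key(counter: int, rank: int, value: bytes) -> str:
    # Hash the episode content so identical payloads from different (counter, rank)
    # pairs still yield distinct, deterministic keys.
    content_digest = hashlib.sha256(value).hexdigest()
    name = f"{counter}:{rank}:{content_digest}"
    return f"rb_ep_{uuid.uuid5(uuid.NAMESPACE_OID, name).hex}"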

@DNXie (Member Author):

I don't think we need determinism at this stage. Let's keep things simple. I have dropped the await self.store.numel() call to make things more efficient. Thanks for pointing this out.
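Roughly, the simplified key generation might look like this (a sketch; the helper name is hypothetical):

import uuid

def make_episode_key() -> str:
    # A random UUID alone avoids collisions without a round-trip to the store.
    return f"rb_ep_{uuid.uuid4().hex}"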


keys = await self.store.keys()

# TODO: Make this more efficient
@DNXie (Member Author) commented Sep 11, 2025:

@joecummings Do we still need this TODO?

@DNXie DNXie changed the title [WIP] Use torchstore as backend for replay buffer [WIP] refactor replay buffer to use KV buffer Sep 11, 2025
@DNXie DNXie requested a review from LucasLLC September 11, 2025 18:45
@DNXie DNXie changed the title [WIP] refactor replay buffer to use KV buffer Refactor replay buffer to use KV buffer Sep 11, 2025
@DNXie DNXie changed the title Refactor replay buffer to use KV buffer [WIP] Refactor replay buffer to use KV buffer Sep 11, 2025
@DNXie DNXie changed the title [WIP] Refactor replay buffer to use KV buffer [RFC] Refactor replay buffer to use KV buffer Sep 12, 2025
pass


class StoreInterface(ABC):
Member:

For now, pare this down to the exact APIs we will be using in the ReplayBuffer - no more, no less. We can always update the interface later.

@DNXie (Member Author):

I’ve already pared the interface down to only the APIs we need. My concern is that methods like numel and delete are essential for the replay buffer’s functionality (e.g., eviction, checking buffer size) but aren’t yet implemented in TorchStore. If we remove these from the interface, the buffer implementation won’t be able to operate consistently.
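For context, a minimal sketch of such a pared-down interface, limited to the operations the discussion mentions (put/get/delete/keys/numel); the exact signatures are assumptions:

from abc import ABC, abstractmethod
from typing import Any

class StoreInterface(ABC):
    """Sketch of a pared-down async KV store interface (method set is an assumption)."""

    @abstractmethod
    async def put(self, key: str, value: Any) -> None: ...

    @abstractmethod
    async def get(self, key: str) -> Any: ...

    @abstractmethod
    async def delete(self, key: str) -> None: ...

    @abstractmethod
    async def keys(self) -> list[str]: ...

    @abstractmethod
    async def numel(self) -> int: ...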

async def _evict(self, curr_policy_version: int) -> None:
    keys = await self.store.keys()
    for key in keys:
        episode = await self.store.get(key)
Member:

This is fine for now, but leave a comment that we could store each key as a uuid + the policy version and make this more efficient.

@DNXie (Member Author):

Yes. Added.
Once torchstore supports fetching by prefix, this will be much easier.

Contributor:

Yes. Added. Once torchstore supports fetching by prefix, this will be much easier.

Coming soon!
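A hypothetical sketch of the key scheme discussed above (uuid plus policy version embedded in the key so eviction can filter without fetching each episode); the key format and helper names are assumptions:

import uuid

def make_key(policy_version: int) -> str:
    # The policy version is embedded in the key itself.
    return f"rb_v{policy_version}_{uuid.uuid4().hex}"

async def evict_by_key(store, curr_policy_version: int, max_policy_age: int) -> None:
    # Eviction can filter on the key string alone, without a per-episode get.
    for key in await store.keys():
        version = int(key.split("_")[1].lstrip("v"))
        if curr_policy_version - version > max_policy_age:
            await store.delete(key)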

for trajectory in self.buffer
if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age
]
async def _evict(self, curr_policy_version: int) -> None:
Member:

We control this internal method so you can pass in the keys from above, which we already calculated.

@DNXie (Member Author) commented Sep 12, 2025:

Actually, we have to re-fetch the keys after eviction because _evict may delete some entries. We fetch once before _evict to know what to check for eviction, then fetch again after to ensure we only sample from the remaining keys. This prevents trying to access keys that no longer exist.
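For illustration, a rough sketch of the evict-then-refetch flow being described (the PR's actual sample() also consults numel(); names and details here are approximate):

import random

async def sample(self, curr_policy_version: int, total_samples: int):
    await self._evict(curr_policy_version)   # may delete stale episodes
    keys = await self.store.keys()           # re-fetch so we only see surviving keys
    if total_samples > len(keys):
        return None
    chosen = random.sample(keys, total_samples)
    return [await self.store.get(k) for k in chosen]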

Contributor:

I think we need to reconsider the whole eviction logic. See my previous point re: concurrency.

@DNXie DNXie marked this pull request as ready for review September 12, 2025 22:15
@endpoint
async def setup(self) -> None:
    self.buffer: list = []

def __post_init__(self):
@DNXie (Member Author):

@joecummings I changed setup to __post_init__ because I found that setup is not called in many of the scripts where we use ReplayBuffer. And if it is not called, things may not be initialized correctly (e.g., the sampler). Let me know if this causes any concerns.
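Roughly what that change amounts to (a hedged sketch; the real ReplayBuffer is a ForgeActor and its fields differ):

import random
from dataclasses import dataclass
from typing import Any

@dataclass
class ReplayBufferSketch:
    backend: Any
    seed: int = 0

    def __post_init__(self) -> None:
        # Runs automatically at construction, so scripts that never call an explicit
        # setup() endpoint still get an initialized sampler.
        self.sampler = random.Random(self.seed)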

@DNXie DNXie requested a review from joecummings September 12, 2025 22:52
@DNXie DNXie changed the title [RFC] Refactor replay buffer to use KV buffer Refactor replay buffer to use KV buffer Sep 12, 2025
@casteryh (Contributor) left a comment:

LGTM, approved 🚀
Note: This is by no means criticism, but I want to point out that the logic here is strictly incorrect when run concurrently (see comments); we can leave that to a future PR.

Comment on lines 56 to 67

         total_samples = self.dp_size * bsz

         # Evict old episodes
-        self._evict(curr_policy_version)
+        await self._evict(curr_policy_version)

-        if total_samples > len(self.buffer):
+        total_available = await self.store.numel()
+        if total_samples > total_available:
             return None

         keys = await self.store.keys()

         # TODO: Make this more efficient
Contributor:

As a general comment: calling _evict() before getting keys is not a reliable way to ensure we don't sample outdated policies, since we have several await points between _evict() and keys() - unless you want to put an async lock on self.store, which you probably don't. This is beyond the scope of this PR though; please add a TODO here. cc @joecummings
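For reference, a minimal sketch of the async-lock option mentioned above (class and method names are hypothetical, and the comment itself suggests this is probably not the direction for this PR):

import asyncio

class LockedBufferOps:
    """Illustration only: serialize eviction and key listing behind one asyncio.Lock."""

    def __init__(self, store, max_policy_age: int):
        self.store = store
        self.max_policy_age = max_policy_age
        self._lock = asyncio.Lock()

    async def evict_and_list(self, curr_policy_version: int) -> list[str]:
        async with self._lock:
            # No other coroutine can interleave between eviction and the final listing.
            for key in await self.store.keys():
                episode = await self.store.get(key)
                if curr_policy_version - episode.policy_version > self.max_policy_age:
                    await self.store.delete(key)
            return await self.store.keys()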

Contributor:

Even if nothing is concurrent at all at this point, I think we should at least keep in mind that we will need to support concurrency in the very near future. Also, it's not necessarily harder to write a concurrently correct program, although we do need to be more careful.

@DNXie (Member Author):

Good point. Will add this TODO before landing.

@DNXie (Member Author):

Added


class ReplayBuffer(ForgeActor):
    """Simple in-memory replay buffer implementation."""

    store: StoreInterface
Member:

nit: maybe call this the backend?

wdyt @LucasLLC ?

@DNXie (Member Author):

Changed store -> backend

@endpoint
async def state_dict(self) -> dict[str, Any]:
    keys = await self.store.keys()
    episodes = [(k, await self.store.get(k)) for k in keys]
Member:

This is not ideal IMO - is there a way we could dump / serialize the contents of the store?

@LucasLLC ?

@allenwang28 (Contributor) left a comment:

This is really cool! Transparently, my concern is that this adds a lot of complexity, which is risky when we don't have multi-node e2e running yet.

Timing-wise, it'll be better to shelve this for now and revisit it once we have something running (maybe even in the next 2 weeks). We'll inevitably hit bottlenecks, but that'll motivate our longer-term design and implementation.

A simple single-node key-value (KV) store implementation of StoreInterface.
This acts as a temporary backend for the replay buffer until torchstore
supports the full set of operations we need (delete, pop, keys, numel, etc.).
Contributor:

Can you explain the consistency semantics and the thread safety of the interface?

For example, once put finishes, any future gets will always see the effect of the put.

@DNXie (Member Author):

The current KVStore implementation is a simple in-memory dictionary with async methods, intended as a temporary backend until torchstore is ready. It does not provide thread safety or strong consistency guarantees in the presence of concurrent access. Specifically:

  • If multiple coroutines access or modify the store concurrently, race conditions may occur (e.g., a get may see stale or missing data if a delete or put happens at the same time).
  • In a single-threaded asyncio event loop, as long as each operation is awaited, the store behaves as expected: once a put completes, subsequent gets will see the new value.
  • However, if the store is accessed from multiple threads or if multiple async tasks interleave operations without awaiting, consistency is not guaranteed.

The plan is to switch to torchstore once the key APIs like delete and numel are ready, which should provide proper concurrency and consistency guarantees.
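To make the single-event-loop argument concrete, here is a minimal sketch of a dict-backed KVStore (an approximation, not the PR's exact code). Because each method mutates the dict without awaiting first, every operation is effectively atomic on a single-threaded asyncio event loop, which is what gives the put-then-get behavior described above:

from typing import Any

class KVStore:
    """Sketch of a dict-backed in-memory store; the PR's actual KVStore may differ."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    async def put(self, key: str, value: Any) -> None:
        self._data[key] = value

    async def get(self, key: str) -> Any:
        return self._data[key]

    async def delete(self, key: str) -> None:
        del self._data[key]

    async def keys(self) -> list[str]:
        return list(self._data)

    async def numel(self) -> int:
        return len(self._data)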

@DNXie DNXie closed this by deleting the head repository Oct 3, 2025