Skip to content
Closed
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
da21e1d
Add reward interface, math reward, unit tests
DNXie Aug 21, 2025
5c72908
Merge branch 'meta-pytorch:main' into main
DNXie Aug 22, 2025
b4d7a61
Merge branch 'meta-pytorch:main' into main
DNXie Aug 25, 2025
02d77c6
Merge branch 'meta-pytorch:main' into main
DNXie Aug 27, 2025
fd1d38b
Merge branch 'meta-pytorch:main' into main
DNXie Aug 28, 2025
f79beee
Merge branch 'meta-pytorch:main' into main
DNXie Aug 28, 2025
d8d775a
Merge branch 'meta-pytorch:main' into main
DNXie Sep 2, 2025
e423c44
Merge branch 'meta-pytorch:main' into main
DNXie Sep 4, 2025
4815c05
Merge branch 'meta-pytorch:main' into main
DNXie Sep 8, 2025
77d41e4
Merge branch 'meta-pytorch:main' into main
DNXie Sep 9, 2025
a3feb1e
Merge branch 'meta-pytorch:main' into main
DNXie Sep 10, 2025
ff6f5c7
add ts interface
DNXie Sep 11, 2025
d9411b9
add unit test
DNXie Sep 11, 2025
dcc8e00
fix lint
DNXie Sep 11, 2025
d812b9c
rename file
DNXie Sep 11, 2025
05dd33b
add delete all function
DNXie Sep 11, 2025
23d7e02
Merge branch 'meta-pytorch:main' into main
DNXie Sep 11, 2025
9953c91
add ts interface
DNXie Sep 11, 2025
726be1c
add unit test
DNXie Sep 11, 2025
e25a239
fix lint
DNXie Sep 11, 2025
ddde20f
rename file
DNXie Sep 11, 2025
0016889
add delete all function
DNXie Sep 11, 2025
1aec7aa
fix tests
DNXie Sep 11, 2025
0d7f4ac
Resolve merge conflict: keep local stores.py
DNXie Sep 11, 2025
0234808
resolve comments. add more test cases
DNXie Sep 11, 2025
ee5bb0c
fix lint
DNXie Sep 11, 2025
b32d840
fix failed test
DNXie Sep 11, 2025
80113df
resolve comments
DNXie Sep 12, 2025
f23285c
apply changes globally
DNXie Sep 12, 2025
895858c
simplify deleteall
DNXie Sep 12, 2025
d8ba98d
fix bug with keys
DNXie Sep 12, 2025
af84852
fix lint
DNXie Sep 12, 2025
8bef515
fix test
DNXie Sep 12, 2025
a58fd2d
add todo
DNXie Sep 15, 2025
780239a
store->backend
DNXie Sep 15, 2025
2d2503b
correct name
DNXie Sep 15, 2025
c66cfd9
Merge branch 'main' into replay_buffer_ts
DNXie Sep 15, 2025
393bcca
fix lint
DNXie Sep 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
from forge.data.stores import KVStore
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from omegaconf import DictConfig
Expand Down Expand Up @@ -377,6 +378,7 @@ async def main(cfg: DictConfig):
spawn_service(
ServiceConfig(**cfg.replay_buffer.service),
ReplayBuffer,
store=KVStore(),
**exclude_service(cfg.replay_buffer),
),
spawn_service(
Expand Down
2 changes: 2 additions & 0 deletions apps/rl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from forge.actors import ReplayBuffer, RLTrainer
from forge.cli.config import parse
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.stores import KVStore
from omegaconf import DictConfig

logger = logging.getLogger(__name__)
Expand All @@ -33,6 +34,7 @@ async def run(cfg: DictConfig):
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
ReplayBuffer,
store=KVStore(),
**cfg.replay_buffer,
),
)
Expand Down
8 changes: 6 additions & 2 deletions apps/toy_rl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from forge.actors.collector import Collector

from forge.actors.replay_buffer import ReplayBuffer

from forge.data.stores import KVStore
from forge.interfaces import Environment, Policy
from forge.types import Action, Observation, State
from monarch.actor import endpoint, proc_mesh
Expand Down Expand Up @@ -141,8 +143,10 @@ async def main():
replay_buffer = await replay_procs.spawn(
"replay_buffer",
ReplayBuffer,
SAMPLES_PER_BATCH, # batch_size
float("inf"), # max_policy_age
store=KVStore(),
batch_size=SAMPLES_PER_BATCH,
max_policy_age=float("inf"),
dp_size=1,
)

# TODO - add in an example of a "vLLM executor" and "vLLM controller"
Expand Down
63 changes: 37 additions & 26 deletions src/forge/actors/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,36 @@
# LICENSE file in the root directory of this source tree.

import random
import uuid
from dataclasses import dataclass
from typing import Any

from monarch.actor import endpoint

from forge.controller import ForgeActor
from forge.interfaces import StoreInterface


@dataclass
class ReplayBuffer(ForgeActor):
"""Simple in-memory replay buffer implementation."""

store: StoreInterface
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe call this the backend?

wdyt @LucasLLC ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed store -> backend

batch_size: int
max_policy_age: int
dp_size: int = 1
seed: int | None = None

@endpoint
async def setup(self) -> None:
self.buffer: list = []
def __post_init__(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joecummings I changed the setup to post_init because I found that the setup is not called in many of the scripts when we are using ReplayBuffer. And if it is not called, things may not be initialized correctly (e.g., sampler). Let me know if this cause any concerns.

if self.seed is None:
self.seed = random.randint(0, 2**32)
random.seed(self.seed)
self.sampler = random.sample

@endpoint
async def add(self, episode) -> None:
self.buffer.append(episode)
key = f"rb_{uuid.uuid4().hex}"
await self.store.put(key, episode)

@endpoint
async def sample(self, curr_policy_version: int, batch_size: int | None = None):
Expand All @@ -50,20 +52,20 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
total_samples = self.dp_size * bsz

# Evict old episodes
self._evict(curr_policy_version)

if total_samples > len(self.buffer):
await self._evict(curr_policy_version)

keys = await self.store.keys()

total_available = await self.store.numel()
if total_samples > total_available:
return None

# TODO: Make this more efficient
Copy link
Member Author

@DNXie DNXie Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joecummings Do we still need this TODO?

idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples)
# Pop episodes in descending order to avoid shifting issues
popped = [self.buffer.pop(i) for i in sorted(idx_to_sample, reverse=True)]
idx_to_sample = self.sampler(range(len(keys)), k=total_samples)

# Reorder popped episodes to match the original random sample order
sorted_idxs = sorted(idx_to_sample, reverse=True)
idx_to_popped = dict(zip(sorted_idxs, popped))
sampled_episodes = [idx_to_popped[i] for i in idx_to_sample]
# Fetch and remove the sampled episodes
sampled_episodes = [await self.store.pop(keys[i]) for i in idx_to_sample]

# Reshape into (dp_size, bsz, ...)
reshaped_episodes = [
Expand All @@ -81,38 +83,47 @@ async def evict(self, curr_policy_version: int) -> None:
Args:
curr_policy_version (int): The current policy version.
"""
self._evict(curr_policy_version)
await self._evict(curr_policy_version)

def _evict(self, curr_policy_version: int) -> None:
self.buffer = [
trajectory
for trajectory in self.buffer
if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age
]
async def _evict(self, curr_policy_version: int) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We control this internal method so you can pass in the keys from above, which we already calculated.

Copy link
Member Author

@DNXie DNXie Sep 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, we have to re-fetch the keys after eviction because _evict may delete some entries. We fetch once before _evict to know what to check for eviction, then fetch again after to ensure we only sample from the remaining keys. This prevents trying to access keys that no longer exist.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to reconsider the whole eviction logic. See my previous point re: concurerncy.

keys = await self.store.keys()
for key in keys:
episode = await self.store.get(key)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine for now, but leave a comment that we could store each key as a uuid + the policy version and make this more efficient.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Added.
Once torchstore support fetching with "prefix", this would be much easier.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Added. Once torchstore support fetching with "prefix", this would be much easier.

Coming soon!

# TODO: Could store keys as policy_version+uuid to evict without fetching each episode
if (curr_policy_version - episode.policy_version) > self.max_policy_age:
await self.store.delete(key)

@endpoint
async def _getitem(self, idx: int):
return self.buffer[idx]
async def _getitem(self, key: str):
return await self.store.get(key)

@endpoint
async def _numel(self) -> int:
"""Number of elements (episodes) in the replay buffer."""
return len(self.buffer)
return await self.store.numel()

@endpoint
async def clear(self) -> None:
"""Clear the replay buffer immediately - dropping all episodes."""
self.buffer.clear()
await self._clear()

async def _clear(self) -> None:
await self.store.delete_all()

@endpoint
async def state_dict(self) -> dict[str, Any]:
keys = await self.store.keys()
episodes = [(k, await self.store.get(k)) for k in keys]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not ideal IMO - is there a way we could dump / serialize the contents of the store?

@LucasLLC ?

return {
"buffer": self.buffer,
"buffer": episodes,
"rng_state": random.getstate(),
"seed": self.seed,
}

@endpoint
async def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self.buffer = state_dict["buffer"]
await self._clear()
for k, ep in state_dict["buffer"]:
await self.store.put(k, ep)
random.setstate(state_dict["rng_state"])
self.seed = state_dict["seed"]
61 changes: 61 additions & 0 deletions src/forge/data/stores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any

from src.forge.interfaces import StoreInterface


class KVStore(StoreInterface):
"""
A simple single-node key-value (KV) store implementation of StoreInterface.
This acts as a temporary backend for the replay buffer until torchstore
supports the full set of operations we need (delete, pop, keys, numel, etc.).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain the consistency semantics and the thread safety of the interface?

For example, once put finishes, any future gets will always see the effect of the put.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current KVStore implementation is a simple in-memory dictionary with async methods, intended as a temporary backend until torchstore is ready. It does not provide thread safety or strong consistency guarantees in the presence of concurrent access. Specifically:

  • If multiple coroutines access or modify the store concurrently, race conditions may occur (e.g., a get may see stale or missing data if a delete or put happens at the same time).
  • In a single-threaded asyncio event loop, as long as each operation is awaited, the store behaves as expected: once a put completes, subsequent gets will see the new value.
  • However, if the store is accessed from multiple threads or if multiple async tasks interleave operations without awaiting, consistency is not guaranteed.

The plan is to switch to torchstore once the key APIs like delete and numel are ready, which should provide proper concurrency and consistency guarantees.

"""

def __init__(self):
self._store = {}

async def put(self, key: str, value: Any) -> None:
self._store[key] = value

async def get(self, key: str) -> Any:
return self._store[key]

async def exists(self, key: str) -> bool:
# Check if a key exists in the KV store
return key in self._store

async def keys(self, prefix: str | None = None) -> list[str]:
# Return all keys, optionally filtered by prefix
if prefix is None:
return list(self._store.keys())
return [k for k in self._store if k.startswith(prefix)]

async def numel(self, prefix: str | None = None) -> int:
# Return the number of key-value pairs, optionally filtered by prefix
return len(await self.keys(prefix))

async def delete(self, key: str) -> None:
# Delete a key-value pair from the store
del self._store[key]

async def pop(self, key: str) -> Any:
# Remove and return a key-value pair (get + delete)
return self._store.pop(key)

async def delete_all(self, prefix: str | None = None) -> None:
# Delete all key-value pairs matching the given prefix
if prefix is None:
# Optimize for deleting all keys
self._store = {}
else:
# Delete only keys matching the prefix
keys_to_delete = await self.keys(prefix)
for key in keys_to_delete:
del self._store[key]
129 changes: 128 additions & 1 deletion src/forge/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Mapping
from typing import Any, List, Mapping

from monarch.actor import endpoint

Expand Down Expand Up @@ -208,6 +208,133 @@ def __call__(self, observation: Observation) -> float:
pass


class StoreInterface(ABC):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, pair this down to the exact APIs we will be using in the ReplayBuffer - no more, no less. We can always update the interface later.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ve already pared the interface down to only the APIs we need. My concern is that methods like numel and delete are essential for the replay buffer’s functionality (e.g., eviction, checking buffer size) but aren’t yet implemented in TorchStore. If we remove these from the interface, the buffer implementation won’t be able to operate consistently.

"""
Abstract base class for a KV store. This closely follows the interface of
torchstore.
"""

# TODO: support this in torchstore.
@abstractmethod
async def numel(self, prefix=None) -> int:
"""Return the number of keys starting with the given prefix.
The prefix matching follows reverse domain name notation convention.

Args:
prefix (str): The prefix to match against stored keys.
For example, "xyz" matches "xyz.abc.def" but "xy" does not.
Note: None is the prefix of all keys, while "" is the prefix of keys
starting with "." and "" itself.

Returns:
int: The number of keys matching the prefix in the store.
"""
pass

@abstractmethod
async def keys(self, prefix=None) -> List[str]:
"""Return an iterable of all keys in the store matching the given prefix.
The prefix matching follows reverse domain name notation convention.

Args:
prefix (str): The prefix to match against stored keys.
For example, "xyz" matches "xyz.abc.def" but "xy" does not.
Note: None is the prefix of all keys, while "" is the prefix of keys
starting with "." and "" itself.

Returns:
Iterable[K]: An iterable containing all keys in the buffer.
"""
pass

@abstractmethod
async def put(self, key: str, value: Any) -> None:
"""
Add a key-value pair to the buffer.

Args:
key (K): The key to store the value under
val (V): The value to store in the buffer

Returns:
None
"""
pass

@abstractmethod
async def get(self, key: str) -> Any:
"""
Get a key-value pair from the store.

Args:
key (K): The key to get

Returns:
V: The value stored under the key

Raises:
KeyError: If the key does not exist in the store
"""
pass

@abstractmethod
async def exists(self, key: str) -> bool:
"""
Check if a key exists in the store.
"""
pass

# TODO: support this in torchstore.
@abstractmethod
async def pop(self, key: str) -> Any:
"""
Get a key-value pair from the store, and delete it from the store.

Args:
key (K): The key to get

Returns:
V: The value stored under the key

Raises:
KeyError: If the key does not exist in the store
"""

# TODO: support this in torchstore.
@abstractmethod
async def delete(self, key: str) -> None:
"""
Delete a key-value pair from the store.

Args:
key (K): The key to delete

Returns:
None

Raises:
KeyError: If the key does not exist in the store
"""
pass

# TODO: support this in torchstore.
@abstractmethod
async def delete_all(self, prefix=None) -> None:
"""
Delete all key-value pairs from the store matching the given prefix.
The prefix matching follows reverse domain name notation convention.

Args:
prefix (str): The prefix to match against stored keys.
For example, "xyz" matches "xyz.abc.def" but "xy" does not.
Note: None is the prefix of all keys, while "" is the prefix of keys
starting with "." and "" itself.

Returns: None
"""
pass


# TODO
# class RLLoss(ABC):

Expand Down
Loading
Loading