Skip to content
Closed
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
da21e1d
Add reward interface, math reward, unit tests
DNXie Aug 21, 2025
5c72908
Merge branch 'meta-pytorch:main' into main
DNXie Aug 22, 2025
b4d7a61
Merge branch 'meta-pytorch:main' into main
DNXie Aug 25, 2025
02d77c6
Merge branch 'meta-pytorch:main' into main
DNXie Aug 27, 2025
fd1d38b
Merge branch 'meta-pytorch:main' into main
DNXie Aug 28, 2025
f79beee
Merge branch 'meta-pytorch:main' into main
DNXie Aug 28, 2025
d8d775a
Merge branch 'meta-pytorch:main' into main
DNXie Sep 2, 2025
e423c44
Merge branch 'meta-pytorch:main' into main
DNXie Sep 4, 2025
4815c05
Merge branch 'meta-pytorch:main' into main
DNXie Sep 8, 2025
77d41e4
Merge branch 'meta-pytorch:main' into main
DNXie Sep 9, 2025
a3feb1e
Merge branch 'meta-pytorch:main' into main
DNXie Sep 10, 2025
ff6f5c7
add ts interface
DNXie Sep 11, 2025
d9411b9
add unit test
DNXie Sep 11, 2025
dcc8e00
fix lint
DNXie Sep 11, 2025
d812b9c
rename file
DNXie Sep 11, 2025
05dd33b
add delete all function
DNXie Sep 11, 2025
23d7e02
Merge branch 'meta-pytorch:main' into main
DNXie Sep 11, 2025
9953c91
add ts interface
DNXie Sep 11, 2025
726be1c
add unit test
DNXie Sep 11, 2025
e25a239
fix lint
DNXie Sep 11, 2025
ddde20f
rename file
DNXie Sep 11, 2025
0016889
add delete all function
DNXie Sep 11, 2025
1aec7aa
fix tests
DNXie Sep 11, 2025
0d7f4ac
Resolve merge conflict: keep local stores.py
DNXie Sep 11, 2025
0234808
resolve comments. add more test cases
DNXie Sep 11, 2025
ee5bb0c
fix lint
DNXie Sep 11, 2025
b32d840
fix failed test
DNXie Sep 11, 2025
80113df
resolve comments
DNXie Sep 12, 2025
f23285c
apply changes globally
DNXie Sep 12, 2025
895858c
simplify deleteall
DNXie Sep 12, 2025
d8ba98d
fix bug with keys
DNXie Sep 12, 2025
af84852
fix lint
DNXie Sep 12, 2025
8bef515
fix test
DNXie Sep 12, 2025
a58fd2d
add todo
DNXie Sep 15, 2025
780239a
store->backend
DNXie Sep 15, 2025
2d2503b
correct name
DNXie Sep 15, 2025
c66cfd9
Merge branch 'main' into replay_buffer_ts
DNXie Sep 15, 2025
393bcca
fix lint
DNXie Sep 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions src/forge/actors/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,40 @@
# LICENSE file in the root directory of this source tree.

import random
import uuid
from dataclasses import dataclass
from typing import Any

from monarch.actor import endpoint

from forge.controller import ForgeActor
from forge.interfaces import StoreInterface

from monarch.actor import endpoint


@dataclass
class ReplayBuffer(ForgeActor):
"""Simple in-memory replay buffer implementation."""

store: StoreInterface
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe call this the backend?

wdyt @LucasLLC ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed store -> backend

batch_size: int
max_policy_age: int
dp_size: int = 1
seed: int | None = None

@endpoint
async def setup(self) -> None:
self.buffer: list = []
if self.seed is None:
self.seed = random.randint(0, 2**32)
random.seed(self.seed)
self.sampler = random.sample

@endpoint
async def add(self, episode) -> None:
self.buffer.append(episode)
await self._add(episode)

async def _add(self, episode) -> None:
key = f"rb_ep_{await self.store.numel()}_{uuid.uuid4().hex}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the point of await self.store.numel()?
Also this may be expensive.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re: recovery and determinism, maybe you could use uuid5 or something like highway hash. But I am not sure how important is determinism.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you can add a counter in the ReplayBuffer class

Then derive the key using uuid5 and/or highway hash and/or your favorite hash, with the following 3 pieces of information

  • the counter
  • the rank of the current worker
  • the content of the value.

This will generally avoid duplicate keys even if you have episodes with the same content coming in.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need deterministic at this stage. Let's keep things simple. I have dropped the await self.store.numel() to make things efficient. Thanks for pointing this out.

await self.store.put(key, episode)

@endpoint
async def sample(self, curr_policy_version: int, batch_size: int | None = None):
Expand All @@ -50,18 +56,21 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
total_samples = self.dp_size * bsz

# Evict old episodes
self._evict(curr_policy_version)
await self._evict(curr_policy_version)

if total_samples > len(self.buffer):
total_available = await self.store.numel()
if total_samples > total_available:
return None

keys = await self.store.keys()

# TODO: Make this more efficient
Copy link
Member Author

@DNXie DNXie Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joecummings Do we still need this TODO?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a general comment: _evict() before getting keys is not a reliable way to ensure we don't get outdated policies.
Since we have several await points between _evict() and keys(). Unless you want to put an async lock on self.store, which you probably don't. This is beyond the scope of this PR though, please and an TODO here. cc @joecummings

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even nothing is concurrent at all at this point. I think we should at least keep in mind we will need to support concurrency in the very near future. Also it's not necessarily harder to write concurrently correct program. Albeit we do need to be more careful.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Will add this TODO before landing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples)
idx_to_sample = self.sampler(range(len(keys)), k=total_samples)
# Pop episodes in descending order to avoid shifting issues
popped = [self.buffer.pop(i) for i in sorted(idx_to_sample, reverse=True)]
sorted_idxs = sorted(idx_to_sample, reverse=True)
popped = [await self.store.pop(keys[i]) for i in sorted_idxs]

# Reorder popped episodes to match the original random sample order
sorted_idxs = sorted(idx_to_sample, reverse=True)
idx_to_popped = dict(zip(sorted_idxs, popped))
sampled_episodes = [idx_to_popped[i] for i in idx_to_sample]

Expand All @@ -81,38 +90,46 @@ async def evict(self, curr_policy_version: int) -> None:
Args:
curr_policy_version (int): The current policy version.
"""
self._evict(curr_policy_version)
await self._evict(curr_policy_version)

def _evict(self, curr_policy_version: int) -> None:
self.buffer = [
trajectory
for trajectory in self.buffer
if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age
]
async def _evict(self, curr_policy_version: int) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We control this internal method so you can pass in the keys from above, which we already calculated.

Copy link
Member Author

@DNXie DNXie Sep 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, we have to re-fetch the keys after eviction because _evict may delete some entries. We fetch once before _evict to know what to check for eviction, then fetch again after to ensure we only sample from the remaining keys. This prevents trying to access keys that no longer exist.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to reconsider the whole eviction logic. See my previous point re: concurerncy.

keys = await self.store.keys()
for key in keys:
episode = await self.store.get(key)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine for now, but leave a comment that we could store each key as a uuid + the policy version and make this more efficient.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Added.
Once torchstore support fetching with "prefix", this would be much easier.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Added. Once torchstore support fetching with "prefix", this would be much easier.

Coming soon!

if (curr_policy_version - episode.policy_version) > self.max_policy_age:
await self.store.delete(key)

@endpoint
async def _getitem(self, idx: int):
return self.buffer[idx]
async def _getitem(self, key: str):
return await self.store.get(key)

@endpoint
async def _numel(self) -> int:
"""Number of elements (episodes) in the replay buffer."""
return len(self.buffer)
return await self.store.numel()

@endpoint
async def clear(self) -> None:
"""Clear the replay buffer immediately - dropping all episodes."""
self.buffer.clear()
await self._clear()

async def _clear(self) -> None:
await self.store.delete_all()

@endpoint
async def state_dict(self) -> dict[str, Any]:
keys = await self.store.keys()
episodes = [await self.store.get(k) for k in keys]
return {
"buffer": self.buffer,
"buffer": episodes,
"rng_state": random.getstate(),
"seed": self.seed,
}

@endpoint
async def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self.buffer = state_dict["buffer"]
await self._clear()
for ep in state_dict["buffer"]:
await self._add(ep)
random.setstate(state_dict["rng_state"])
self.seed = state_dict["seed"]
68 changes: 68 additions & 0 deletions src/forge/data/stores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any

<<<<<<< HEAD
from forge.interfaces import StoreInterface
=======
from src.forge.interfaces import StoreInterface
>>>>>>> 05dd33b7eb5574b9db6f8ec0ee2417919946b0fe


class KVStore(StoreInterface):
"""
A simple single-node key-value (KV) store implementation of StoreInterface.
This acts as a temporary backend for the replay buffer until torchstore
supports the full set of operations we need (delete, pop, keys, numel, etc.).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain the consistency semantics and the thread safety of the interface?

For example, once put finishes, any future gets will always see the effect of the put.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current KVStore implementation is a simple in-memory dictionary with async methods, intended as a temporary backend until torchstore is ready. It does not provide thread safety or strong consistency guarantees in the presence of concurrent access. Specifically:

  • If multiple coroutines access or modify the store concurrently, race conditions may occur (e.g., a get may see stale or missing data if a delete or put happens at the same time).
  • In a single-threaded asyncio event loop, as long as each operation is awaited, the store behaves as expected: once a put completes, subsequent gets will see the new value.
  • However, if the store is accessed from multiple threads or if multiple async tasks interleave operations without awaiting, consistency is not guaranteed.

The plan is to switch to torchstore once the key APIs like delete and numel are ready, which should provide proper concurrency and consistency guarantees.

"""

def __init__(self):
self._store = {}

async def put(self, key: str, value: Any) -> None:
self._store[key] = value

async def get(self, key: str) -> Any:
return self._store[key]

async def exists(self, key: str) -> bool:
# Check if a key exists in the KV store
return key in self._store

async def keys(self, prefix: str | None = None) -> list[str]:
# Return all keys, optionally filtered by prefix
if prefix is None:
return list(self._store.keys())
return [k for k in self._store if k.startswith(prefix)]

async def numel(self, prefix: str | None = None) -> int:
# Return the number of key-value pairs, optionally filtered by prefix
return len(await self.keys(prefix))

async def delete(self, key: str) -> None:
# Delete a key-value pair from the store
del self._store[key]

async def pop(self, key: str) -> Any:
# Remove and return a key-value pair (get + delete)
return self._store.pop(key)

async def delete_all(self, prefix: str | None = None) -> int:
# Delete all key-value pairs matching the given prefix
if prefix is None:
# Optimize for deleting all keys
count = len(self._store)
self._store = {}
return count
else:
# Delete only keys matching the prefix
keys_to_delete = await self.keys(prefix)
for key in keys_to_delete:
del self._store[key]
return len(keys_to_delete)
134 changes: 131 additions & 3 deletions src/forge/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Mapping

from monarch.actor import endpoint
from typing import Any, List, Mapping

from forge.controller import ForgeActor

from forge.types import Action, Message, Observation, Scalar, State

from monarch.actor import endpoint


class Transform(ABC):
"""Abstract base class for observation transforms.
Expand Down Expand Up @@ -208,6 +208,134 @@ def __call__(self, observation: Observation) -> float:
pass


class StoreInterface(ABC):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, pair this down to the exact APIs we will be using in the ReplayBuffer - no more, no less. We can always update the interface later.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ve already pared the interface down to only the APIs we need. My concern is that methods like numel and delete are essential for the replay buffer’s functionality (e.g., eviction, checking buffer size) but aren’t yet implemented in TorchStore. If we remove these from the interface, the buffer implementation won’t be able to operate consistently.

"""
Abstract base class for a KV store. This closely follows the interface of
torchstore.
"""

# TODO: support this in torchstore.
@abstractmethod
async def numel(self, prefix=None) -> int:
"""Return the number of keys starting with the given prefix.
The prefix matching follows reverse domain name notation convention.

Args:
prefix (str): The prefix to match against stored keys.
For example, "xyz" matches "xyz.abc.def" but "xy" does not.
Note: None is the prefix of all keys, while "" is the prefix of keys
starting with "." and "" itself.

Returns:
int: The number of keys matching the prefix in the store.
"""
pass

@abstractmethod
async def keys(self, prefix=None) -> List[str]:
"""Return an iterable of all keys in the store matching the given prefix.
The prefix matching follows reverse domain name notation convention.

Args:
prefix (str): The prefix to match against stored keys.
For example, "xyz" matches "xyz.abc.def" but "xy" does not.
Note: None is the prefix of all keys, while "" is the prefix of keys
starting with "." and "" itself.

Returns:
Iterable[K]: An iterable containing all keys in the buffer.
"""
pass

@abstractmethod
async def put(self, key: str, value: Any) -> None:
"""
Add a key-value pair to the buffer.

Args:
key (K): The key to store the value under
val (V): The value to store in the buffer

Returns:
None
"""
pass

@abstractmethod
async def get(self, key: str) -> Any:
"""
Get a key-value pair from the store.

Args:
key (K): The key to get

Returns:
V: The value stored under the key

Raises:
KeyError: If the key does not exist in the store
"""
pass

@abstractmethod
async def exists(self, key: str) -> bool:
"""
Check if a key exists in the store.
"""
pass

# TODO: support this in torchstore.
@abstractmethod
async def pop(self, key: str) -> Any:
"""
Get a key-value pair from the store, and delete it from the store.

Args:
key (K): The key to get

Returns:
V: The value stored under the key

Raises:
KeyError: If the key does not exist in the store
"""

# TODO: support this in torchstore.
@abstractmethod
async def delete(self, key: str) -> None:
"""
Delete a key-value pair from the store.

Args:
key (K): The key to delete

Returns:
None

Raises:
KeyError: If the key does not exist in the store
"""
pass

# TODO: support this in torchstore.
@abstractmethod
async def delete_all(self, prefix=None) -> int:
"""
Delete all key-value pairs from the store matching the given prefix.
The prefix matching follows reverse domain name notation convention.

Args:
prefix (str): The prefix to match against stored keys.
For example, "xyz" matches "xyz.abc.def" but "xy" does not.
Note: None is the prefix of all keys, while "" is the prefix of keys
starting with "." and "" itself.

Returns:
int: The number of keys deleted from the store.
"""
pass


# TODO
# class RLLoss(ABC):

Expand Down
Loading
Loading