Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 35 additions & 29 deletions src/forge/actors/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@
# LICENSE file in the root directory of this source tree.

import random
import uuid
from dataclasses import dataclass
from typing import Any

from monarch.actor import endpoint

from forge.controller import ForgeActor
from forge.data.raw_buffer import SimpleRawBuffer
from forge.data.stateful_sampler import RandomStatefulSampler

from forge.interfaces import StatefulSampler

from monarch.actor import endpoint


@dataclass
Expand All @@ -22,16 +27,23 @@ class ReplayBuffer(ForgeActor):
seed: int | None = None

@endpoint
async def setup(self) -> None:
self.buffer: list = []
async def setup(self, *, sampler: StatefulSampler | None = None) -> None:
self._buffer = SimpleRawBuffer[int, Any]()
if self.seed is None:
self.seed = random.randint(0, 2**32)
random.seed(self.seed)
self.sampler = random.sample
if sampler is None:
sampler = RandomStatefulSampler(seed=self.seed)

self._sampler = sampler

@endpoint
async def add(self, episode) -> None:
self.buffer.append(episode)
# I think key should be provided by the caller, but let's just generate a random one for now
# Note that this means add() is not deterministic, however the original implementation using list
# isn't actually deterministic either because it depends on the order of add() being called.
# Alternatively, add a field in Trajectory as the id of the trajectory.
key = uuid.uuid4().int
self._buffer.add(key, episode)

@endpoint
async def sample(self, curr_policy_version: int, batch_size: int | None = None):
Expand All @@ -50,15 +62,11 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
# Evict old episodes
self._evict(curr_policy_version)

if bsz > len(self.buffer):
if bsz > len(self._buffer):
return None

# TODO: Make this more efficient
idx_to_sample = self.sampler(range(len(self.buffer)), k=bsz)
sorted_idxs = sorted(
idx_to_sample, reverse=True
) # Sort in desc order to avoid shifting idxs
sampled_episodes = [self.buffer.pop(i) for i in sorted_idxs]
keys_to_sample = self._sampler.sample_keys(self._buffer, num=bsz)
sampled_episodes = [self._buffer.pop(k) for k in keys_to_sample]
return sampled_episodes

@endpoint
Expand All @@ -72,35 +80,33 @@ async def evict(self, curr_policy_version: int) -> None:
self._evict(curr_policy_version)

def _evict(self, curr_policy_version: int) -> None:
self.buffer = [
trajectory
for trajectory in self.buffer
if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age
]

@endpoint
async def _getitem(self, idx: int):
return self.buffer[idx]
keys_to_delete = []
for key, episode in self._buffer:
if curr_policy_version - episode.policy_version > self.max_policy_age:
keys_to_delete.append(key)
for key in keys_to_delete:
self._buffer.pop(key)

@endpoint
async def _numel(self) -> int:
"""Number of elements (episodes) in the replay buffer."""
return len(self.buffer)
return len(self._buffer)

@endpoint
async def clear(self) -> None:
"""Clear the replay buffer immediately - dropping all episodes."""
self.buffer.clear()
self._buffer.clear()

@endpoint
async def state_dict(self) -> dict[str, Any]:
return {
"buffer": self.buffer,
"rng_state": random.getstate(),
"buffer": self._buffer,
"sampler_state": self._sampler.state_dict(),
"seed": self.seed,
}

@endpoint
async def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self.buffer = state_dict["buffer"]
random.setstate(state_dict["rng_state"])
self._buffer = state_dict["buffer"]
self._sampler.set_state_dict(state_dict["sampler_state"])
self.seed = state_dict["seed"]
55 changes: 55 additions & 0 deletions src/forge/data/raw_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterator, TypeVar

from forge.interfaces import RawBuffer

K = TypeVar("K")
V = TypeVar("V")


class SimpleRawBuffer(RawBuffer[K, V]):
    """In-memory ``RawBuffer`` that stores its entries in a plain ``dict``."""

    def __init__(self) -> None:
        # Single backing store; every other method is a thin wrapper over it.
        self._buffer: dict[K, V] = {}

    def __len__(self) -> int:
        """Number of entries currently held."""
        return len(self._buffer)

    def __getitem__(self, key: K) -> V:
        """Look up the value stored under ``key`` without removing it."""
        return self._buffer[key]

    def __iter__(self) -> Iterator[tuple[K, V]]:
        """Yield ``(key, value)`` pairs for every entry."""
        yield from self._buffer.items()

    def keys(self) -> Iterator[K]:
        """Yield every key held by the buffer."""
        yield from self._buffer.keys()

    def add(self, key: K, val: V) -> None:
        """Store ``val`` under ``key``.

        Raises:
            KeyError: If ``key`` is already present; existing entries are
                never overwritten.
        """
        if key not in self._buffer:
            self._buffer[key] = val
            return
        raise KeyError(f"Key {key} already exists in the buffer.")

    def pop(self, key: K) -> V:
        """Remove the entry stored under ``key`` and return its value.

        Raises:
            KeyError: If ``key`` is not present.
        """
        if key in self._buffer:
            return self._buffer.pop(key)
        raise KeyError(f"Key {key} does not exist in the buffer.")

    def clear(self) -> None:
        """Drop every entry."""
        self._buffer.clear()
73 changes: 73 additions & 0 deletions src/forge/data/stateful_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import random
from typing import Any, Generic, List, Mapping, TypeVar

from forge.interfaces import BufferView, StatefulSampler

K = TypeVar("K")
V = TypeVar("V")


class RandomStatefulSampler(StatefulSampler[K, V], Generic[K, V]):
    """Stateful sampler that draws keys without replacement via ``random.sample``.

    The sampler owns a private ``random.Random`` instance, so its state can be
    saved with ``state_dict`` and restored with ``set_state_dict`` to reproduce
    a sampling sequence. It neither reads nor modifies the module-level
    ``random`` generator.
    """

    def __init__(self, seed: int | None = None):
        """Create the sampler's private random number generator.

        Args:
            seed: Seed for the generator. When ``None``, ``random.Random``
                picks its own seed, so results are not reproducible across
                runs unless the state is saved and restored.
        """
        # random.Random(None) is equivalent to random.Random(), so one
        # construction covers both the seeded and the unseeded case.
        self._random = random.Random(seed)

    def sample_keys(self, buffer: BufferView[K, V], num: int) -> List[K]:
        """Sample ``num`` distinct keys from ``buffer``.

        Args:
            buffer: The buffer to sample from; only its ``keys()`` are read.
            num: Number of keys to sample.

        Returns:
            A list of sampled keys. If ``num`` is greater than or equal to the
            number of keys in the buffer, all keys are returned in the buffer's
            iteration order and the random state is left untouched.
        """
        # Materialise the keys: random.sample needs a sequence.
        all_keys = list(buffer.keys())

        # Nothing to choose between: hand back every key without consuming
        # any randomness.
        if num >= len(all_keys):
            return all_keys

        return self._random.sample(all_keys, num)

    def state_dict(self) -> dict[str, Any]:
        """Return the sampler state.

        Returns:
            A dictionary whose ``"random_state"`` entry holds the generator
            state as returned by ``random.Random.getstate``.
        """
        return {"random_state": self._random.getstate()}

    def set_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """Restore the sampler state produced by ``state_dict``.

        Args:
            state_dict: Mapping containing the ``"random_state"`` entry to
                restore.

        Raises:
            ValueError: If ``state_dict`` has no ``"random_state"`` entry.
        """
        if "random_state" not in state_dict:
            raise ValueError("Missing 'random_state' in state dict")
        self._random.setstate(state_dict["random_state"])
Loading
Loading