Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/forge/data/raw_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterator, TypeVar

from forge.interfaces import RawBuffer

K = TypeVar("K")
V = TypeVar("V")


class SimpleRawBuffer(RawBuffer[K, V]):
"""Simple in-memory RawBuffer backed by a Python dictionary."""

def __init__(self) -> None:
self._buffer: dict[K, V] = {}

def __len__(self) -> int:
"""Return the number of key-value pairs in the buffer."""
return len(self._buffer)

def __getitem__(self, key: K) -> V:
"""Get a value from the buffer using the specified key."""
return self._buffer[key]

def __iter__(self) -> Iterator[tuple[K, V]]:
"""Iterate over the key-value pairs in the buffer."""
for k, v in self._buffer.items():
yield k, v

def keys(self) -> Iterator[K]:
"""Iterate over the keys in the buffer."""
for k in self._buffer.keys():
yield k

def add(self, key: K, val: V) -> None:
"""Add a key-value pair to the buffer."""
if key in self._buffer:
raise KeyError(f"Key {key} already exists in the buffer.")
self._buffer[key] = val

def pop(self, key: K) -> V:
"""Remove and return a value from the buffer using the specified key."""
if key not in self._buffer:
raise KeyError(f"Key {key} does not exist in the buffer.")
val = self._buffer[key]
del self._buffer[key]
return val

def clear(self) -> None:
"""Clear the buffer."""
self._buffer.clear()
73 changes: 73 additions & 0 deletions src/forge/data/stateful_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import random
from typing import Any, Generic, List, Mapping, TypeVar

from forge.interfaces import BufferView, StatefulSampler

K = TypeVar("K")
V = TypeVar("V")


class RandomStatefulSampler(StatefulSampler[K, V], Generic[K, V]):
"""A simple stateful sampler that uses Python's random.sample for deterministic sampling.

This sampler maintains an internal random state that can be saved and restored,
allowing for reproducible sampling behavior. It uses random.sample to select
keys from the buffer without replacement.
"""

def __init__(self, seed: int | None = None):
"""Initialize the sampler with an optional random seed.

Args:
seed: Optional seed for the random number generator. If None,
the sampler will use Python's default random state.
"""
if seed is None:
self._random = random.Random()
self._random = random.Random(seed)

def sample_keys(self, buffer: BufferView[K, V], num: int) -> List[K]:
"""Sample keys from the buffer using random.sample.

Args:
buffer: The buffer to sample from
num: Number of keys to sample

Returns:
A list of sampled keys. If num is greater than the buffer size,
returns all available keys.
"""
# Get all keys from the buffer
all_keys = list(buffer.keys())

# If requesting more samples than available, return all keys
if num >= len(all_keys):
return all_keys

# Use random.sample for sampling without replacement
return self._random.sample(all_keys, num)

def state_dict(self):
"""Return the state dict of the sampler.

Returns:
A dictionary containing the random number generator state.
"""
return {"random_state": self._random.getstate()}

def set_state_dict(self, state_dict: Mapping[str, Any]):
"""Set the state dict of the sampler.

Args:
state_dict: Dictionary containing the random state to restore.
"""
if "random_state" in state_dict:
self._random.setstate(state_dict["random_state"])
else:
raise ValueError("Missing 'random_state' in state dict")
152 changes: 150 additions & 2 deletions src/forge/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Mapping
from typing import Any, Generic, Iterable, Mapping, TypeVar

from forge.types import Action, Message, Observation, Scalar, State

from monarch.actor import Actor, endpoint

from forge.types import Action, Message, Observation, Scalar, State
K = TypeVar("K")
V = TypeVar("V")


class Transform(ABC):
Expand Down Expand Up @@ -88,6 +91,151 @@ async def update_weights(self):
pass


class BufferView(ABC, Generic[K, V]):
"""Abstract base class for a view into a buffer with key-value pairs.

This class defines the interface for accessing elements in a buffer
through dictionary-like operations. It supports generic key and value types.
"""

@abstractmethod
def __len__(self) -> int:
"""Return the number of key-value pairs in the buffer.

Returns:
int: The number of items in the buffer.
"""
pass

@abstractmethod
def __getitem__(self, key: K) -> V:
"""Retrieve a value from the buffer using the specified key.

Args:
key (K): The key to look up in the buffer.

Returns:
V: The value associated with the key.

Raises:
KeyError: If the key is not found in the buffer.
"""
pass

@abstractmethod
def __iter__(self) -> Iterable[tuple[K, V]]:
"""Return an iterator over the key-value pairs in the buffer.

Returns:
Iterable[tuple[K, V]]: An iterator yielding (key, value) tuples.
"""
pass

@abstractmethod
def keys(self) -> Iterable[K]:
"""Return an iterable of all keys in the buffer.

Returns:
Iterable[K]: An iterable containing all keys in the buffer.
"""
pass


class RawBuffer(BufferView[K, V], ABC):
"""Abstract interface for the underlying storage backend (raw buffer) of a ReplayBuffer."""

@abstractmethod
def add(self, key: K, val: V) -> None:
"""
Add a key-value pair to the buffer.

Args:
key (K): The key to store the value under
val (V): The value to store in the buffer

Returns:
None
"""
pass

@abstractmethod
def pop(self, key: K) -> V:
"""
Remove and return a value from the buffer using the specified key.

Args:
key (K): The key to look up and remove from the buffer

Returns:
V: The value associated with the key before removal
"""
pass

@abstractmethod
def clear(self) -> None:
"""
Remove all key-value pairs from the buffer, effectively emptying it.

This method should reset the buffer to its initial empty state.

Returns:
None
"""
pass


class StatefulSampler(ABC, Generic[K, V]):
"""Abstract interface for stateful samplers with deterministic behavior given a state.

This class defines the interface for samplers that maintain internal state and provide
deterministic sampling behavior when the state is fixed.
"""

@abstractmethod
def sample_keys(self, buffer: BufferView[K, V], num: int) -> list[K]:
"""Return the keys of selected samples from the buffer.

This method samples a specified number of keys from the provided buffer
according to the sampler's internal sampling strategy. The sampling
behavior is deterministic for a given internal state of the sampler.

Args:
buffer (BufferView[K, V]): The buffer to sample from, containing key-value pairs.
num (int): Desired number of samples to retrieve from the buffer.
If num is greater than the buffer size, implementation may
return fewer samples or handle it according to the specific
sampling strategy.

Returns:
list[K]: A list of keys corresponding to the selected samples.
The length of this list will typically be equal to num,
unless the buffer contains fewer items.
"""
pass

@abstractmethod
def state_dict(self) -> Mapping[str, Any]:
"""Return the state dict of the sampler.

This method should capture all the internal state necessary to reproduce
the sampler's behavior, such as random number generator states.

Returns:
dict: A dictionary containing the internal state of the sampler.
"""
pass

@abstractmethod
def set_state_dict(self, state_dict):
"""Set the state dict of the sampler.

Args:
state_dict (dict): A dictionary containing the internal state to restore
the sampler to a specific configuration.
"""
pass


class BaseTokenizer(ABC):
"""
Abstract token encoding model that implements ``encode`` and ``decode`` methods.
Expand Down
Loading
Loading