From 6547207b07eea22d090a6bb3d3c35e1a6cf63226 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Wed, 27 Aug 2025 19:03:14 -0700 Subject: [PATCH 1/3] [RFC] Add BufferView and RawBuffer interfaces Summary: Added `BufferView` and `RawBuffer` interface. A sampler will operate on a BufferView and return the sampled keys. A ReplayBuffer will own a RawBuffer and operate on that. Test Plan: n/a --- src/forge/interfaces.py | 100 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 2 deletions(-) diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index 4bd2d4bbe..736c3356d 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -5,11 +5,14 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, Mapping +from typing import Any, Generic, Iterable, Mapping, TypeVar + +from forge.types import Action, Message, Observation, Scalar, State from monarch.actor import Actor, endpoint -from forge.types import Action, Message, Observation, Scalar, State +K = TypeVar("K") +V = TypeVar("V") class Transform(ABC): @@ -88,6 +91,99 @@ async def update_weights(self): pass +class BufferView(ABC, Generic[K, V]): + """Abstract base class for a view into a buffer with key-value pairs. + + This class defines the interface for accessing elements in a buffer + through dictionary-like operations. It supports generic key and value types. + """ + + @abstractmethod + def __len__(self) -> int: + """Return the number of key-value pairs in the buffer. + + Returns: + int: The number of items in the buffer. + """ + pass + + @abstractmethod + def __getitem__(self, key: K) -> V: + """Retrieve a value from the buffer using the specified key. + + Args: + key (K): The key to look up in the buffer. + + Returns: + V: The value associated with the key. + + Raises: + KeyError: If the key is not found in the buffer. + """ + pass + + @abstractmethod + def __iter__(self) -> Iterable[tuple[K, V]]: + """Return an iterator over the key-value pairs in the buffer. + + Returns: + Iterable[tuple[K, V]]: An iterator yielding (key, value) tuples. + """ + pass + + @abstractmethod + def keys(self) -> Iterable[K]: + """Return an iterable of all keys in the buffer. + + Returns: + Iterable[K]: An iterable containing all keys in the buffer. + """ + pass + + +class RawBuffer(BufferView[K, V], ABC): + """Abstract interface for the underlying storage backend (raw buffer) of a ReplayBuffer.""" + + @abstractmethod + def add(self, key: K, val: V) -> None: + """ + Add a key-value pair to the buffer. + + Args: + key (K): The key to store the value under + val (V): The value to store in the buffer + + Returns: + None + """ + pass + + @abstractmethod + def pop(self, key: K) -> V: + """ + Remove and return a value from the buffer using the specified key. + + Args: + key (K): The key to look up and remove from the buffer + + Returns: + V: The value associated with the key before removal + """ + pass + + @abstractmethod + def clear(self) -> None: + """ + Remove all key-value pairs from the buffer, effectively emptying it. + + This method should reset the buffer to its initial empty state. + + Returns: + None + """ + pass + + class BaseTokenizer(ABC): """ Abstract token encoding model that implements ``encode`` and ``decode`` methods. From 2ab5d1852b0d8c949c6d16ae2af91f64038271a2 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Wed, 27 Aug 2025 19:03:14 -0700 Subject: [PATCH 2/3] Implement SimpleRawBuffer, a RawBuffer backed by a python dict. Summary: Implement SimpleRawBuffer, a RawBuffer backed by a python dict. Test Plan: unit tests --- src/forge/data/raw_buffer.py | 55 +++++++ tests/unit_tests/test_raw_buffer.py | 231 ++++++++++++++++++++++++++++ 2 files changed, 286 insertions(+) create mode 100644 src/forge/data/raw_buffer.py create mode 100644 tests/unit_tests/test_raw_buffer.py diff --git a/src/forge/data/raw_buffer.py b/src/forge/data/raw_buffer.py new file mode 100644 index 000000000..ed5a14918 --- /dev/null +++ b/src/forge/data/raw_buffer.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Iterator, TypeVar + +from forge.interfaces import RawBuffer + +K = TypeVar("K") +V = TypeVar("V") + + +class SimpleRawBuffer(RawBuffer[K, V]): + """Simple in-memory RawBuffer backed by a Python dictionary.""" + + def __init__(self) -> None: + self._buffer: dict[K, V] = {} + + def __len__(self) -> int: + """Return the number of key-value pairs in the buffer.""" + return len(self._buffer) + + def __getitem__(self, key: K) -> V: + """Get a value from the buffer using the specified key.""" + return self._buffer[key] + + def __iter__(self) -> Iterator[tuple[K, V]]: + """Iterate over the key-value pairs in the buffer.""" + for k, v in self._buffer.items(): + yield k, v + + def keys(self) -> Iterator[K]: + """Iterate over the keys in the buffer.""" + for k in self._buffer.keys(): + yield k + + def add(self, key: K, val: V) -> None: + """Add a key-value pair to the buffer.""" + if key in self._buffer: + raise KeyError(f"Key {key} already exists in the buffer.") + self._buffer[key] = val + + def pop(self, key: K) -> V: + """Remove and return a value from the buffer using the specified key.""" + if key not in self._buffer: + raise KeyError(f"Key {key} does not exist in the buffer.") + val = self._buffer[key] + del self._buffer[key] + return val + + def clear(self) -> None: + """Clear the buffer.""" + self._buffer.clear() diff --git a/tests/unit_tests/test_raw_buffer.py b/tests/unit_tests/test_raw_buffer.py new file mode 100644 index 000000000..a6a8e5242 --- /dev/null +++ b/tests/unit_tests/test_raw_buffer.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Test for data/raw_buffer.py""" + +import pytest + +from forge.data.raw_buffer import SimpleRawBuffer + + +class TestSimpleRawBuffer: + """Test suite for SimpleRawBuffer class.""" + + def test_init_empty_buffer(self): + """Test that a new buffer is initialized empty.""" + buffer = SimpleRawBuffer[str, int]() + assert len(buffer) == 0 + + def test_add_single_item(self): + """Test adding a single key-value pair.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + + assert len(buffer) == 1 + assert buffer["key1"] == 100 + + def test_add_multiple_items(self): + """Test adding multiple key-value pairs.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + buffer.add("key2", 200) + buffer.add("key3", 300) + + assert len(buffer) == 3 + assert buffer["key1"] == 100 + assert buffer["key2"] == 200 + assert buffer["key3"] == 300 + + def test_add_duplicate_key_raises_error(self): + """Test that adding a duplicate key raises KeyError.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + + with pytest.raises(KeyError, match="Key key1 already exists in the buffer"): + buffer.add("key1", 200) + + def test_getitem_existing_key(self): + """Test retrieving an existing key.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("test_key", 42) + + assert buffer["test_key"] == 42 + + def test_getitem_missing_key_raises_error(self): + """Test that accessing a non-existent key raises KeyError.""" + buffer = SimpleRawBuffer[str, int]() + + with pytest.raises(KeyError): + _ = buffer["missing_key"] + + def test_pop_existing_key(self): + """Test removing and returning a value for an existing key.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + buffer.add("key2", 200) + + value = buffer.pop("key1") + + assert value == 100 + assert len(buffer) == 1 + assert "key1" not in buffer.keys() + assert buffer["key2"] == 200 + + def test_pop_missing_key_raises_error(self): + """Test that popping a non-existent key raises KeyError.""" + buffer = SimpleRawBuffer[str, int]() + + with pytest.raises( + KeyError, match="Key missing_key does not exist in the buffer" + ): + buffer.pop("missing_key") + + def test_keys_iteration(self): + """Test iterating over keys.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + buffer.add("key2", 200) + buffer.add("key3", 300) + + keys = list(buffer.keys()) + + assert len(keys) == 3 + assert "key1" in keys + assert "key2" in keys + assert "key3" in keys + + def test_keys_empty_buffer(self): + """Test iterating over keys in an empty buffer.""" + buffer = SimpleRawBuffer[str, int]() + + keys = list(buffer.keys()) + + assert keys == [] + + def test_iter_key_value_pairs(self): + """Test iterating over key-value pairs.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + buffer.add("key2", 200) + + items = list(buffer) + + assert len(items) == 2 + assert ("key1", 100) in items + assert ("key2", 200) in items + + def test_iter_empty_buffer(self): + """Test iterating over an empty buffer.""" + buffer = SimpleRawBuffer[str, int]() + + items = list(buffer) + + assert items == [] + + def test_clear_buffer(self): + """Test clearing the buffer.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("key1", 100) + buffer.add("key2", 200) + buffer.add("key3", 300) + + assert len(buffer) == 3 + + buffer.clear() + + assert len(buffer) == 0 + assert list(buffer.keys()) == [] + assert list(buffer) == [] + + def test_clear_empty_buffer(self): + """Test clearing an already empty buffer.""" + buffer = SimpleRawBuffer[str, int]() + + assert len(buffer) == 0 + + buffer.clear() + + assert len(buffer) == 0 + + def test_different_value_types(self): + """Test buffer with different value types.""" + buffer = SimpleRawBuffer[str, list[int]]() + buffer.add("list1", [1, 2, 3]) + buffer.add("list2", [4, 5, 6]) + + assert buffer["list1"] == [1, 2, 3] + assert buffer["list2"] == [4, 5, 6] + + def test_different_key_types(self): + """Test buffer with different key types.""" + buffer = SimpleRawBuffer[int, str]() + buffer.add(1, "value1") + buffer.add(2, "value2") + + assert buffer[1] == "value1" + assert buffer[2] == "value2" + + def test_complex_workflow(self): + """Test a complex workflow with multiple operations.""" + buffer = SimpleRawBuffer[str, int]() + + # Add some items + buffer.add("a", 1) + buffer.add("b", 2) + buffer.add("c", 3) + assert len(buffer) == 3 + + # Pop one item + value = buffer.pop("b") + assert value == 2 + assert len(buffer) == 2 + + # Add another item + buffer.add("d", 4) + assert len(buffer) == 3 + + # Verify remaining items + assert buffer["a"] == 1 + assert buffer["c"] == 3 + assert buffer["d"] == 4 + + # Clear and verify empty + buffer.clear() + assert len(buffer) == 0 + + def test_len_consistency(self): + """Test that len() remains consistent with add/pop operations.""" + buffer = SimpleRawBuffer[str, int]() + + # Initially empty + assert len(buffer) == 0 + + # Add items and check length + for i in range(5): + buffer.add(f"key{i}", i) + assert len(buffer) == i + 1 + + # Remove items and check length + for i in range(5): + buffer.pop(f"key{i}") + assert len(buffer) == 4 - i + + def test_none_values(self): + """Test storing None values.""" + buffer = SimpleRawBuffer[str, int | None]() + buffer.add("none_value", None) + buffer.add("int_value", 42) + + assert buffer["none_value"] is None + assert buffer["int_value"] == 42 + + def test_empty_string_key(self): + """Test using empty string as key.""" + buffer = SimpleRawBuffer[str, int]() + buffer.add("", 42) + + assert buffer[""] == 42 + assert "" in list(buffer.keys()) From 1aeb2b37f531271c79ccdb5a9e10a47a3c16c270 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Wed, 27 Aug 2025 19:03:14 -0700 Subject: [PATCH 3/3] Add StatefulSampler interface and implement RandomStatefulSampler Summary: This diff adds a new interface called `StatefulSampler` and implements a new class called `RandomStatefulSampler`. The `RandomStatefulSampler` class is a stateful sampler that uses Python's `random.sample` function for deterministic sampling. Test Plan: unit tests --- src/forge/data/stateful_sampler.py | 73 +++++++++++++++++++++++ src/forge/interfaces.py | 52 ++++++++++++++++ tests/unit_tests/test_stateful_sampler.py | 59 ++++++++++++++++++ 3 files changed, 184 insertions(+) create mode 100644 src/forge/data/stateful_sampler.py create mode 100644 tests/unit_tests/test_stateful_sampler.py diff --git a/src/forge/data/stateful_sampler.py b/src/forge/data/stateful_sampler.py new file mode 100644 index 000000000..3279a658c --- /dev/null +++ b/src/forge/data/stateful_sampler.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import random +from typing import Any, Generic, List, Mapping, TypeVar + +from forge.interfaces import BufferView, StatefulSampler + +K = TypeVar("K") +V = TypeVar("V") + + +class RandomStatefulSampler(StatefulSampler[K, V], Generic[K, V]): + """A simple stateful sampler that uses Python's random.sample for deterministic sampling. + + This sampler maintains an internal random state that can be saved and restored, + allowing for reproducible sampling behavior. It uses random.sample to select + keys from the buffer without replacement. + """ + + def __init__(self, seed: int | None = None): + """Initialize the sampler with an optional random seed. + + Args: + seed: Optional seed for the random number generator. If None, + the sampler will use Python's default random state. + """ + if seed is None: + self._random = random.Random() + self._random = random.Random(seed) + + def sample_keys(self, buffer: BufferView[K, V], num: int) -> List[K]: + """Sample keys from the buffer using random.sample. + + Args: + buffer: The buffer to sample from + num: Number of keys to sample + + Returns: + A list of sampled keys. If num is greater than the buffer size, + returns all available keys. + """ + # Get all keys from the buffer + all_keys = list(buffer.keys()) + + # If requesting more samples than available, return all keys + if num >= len(all_keys): + return all_keys + + # Use random.sample for sampling without replacement + return self._random.sample(all_keys, num) + + def state_dict(self): + """Return the state dict of the sampler. + + Returns: + A dictionary containing the random number generator state. + """ + return {"random_state": self._random.getstate()} + + def set_state_dict(self, state_dict: Mapping[str, Any]): + """Set the state dict of the sampler. + + Args: + state_dict: Dictionary containing the random state to restore. + """ + if "random_state" in state_dict: + self._random.setstate(state_dict["random_state"]) + else: + raise ValueError("Missing 'random_state' in state dict") diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index 736c3356d..c0579f09f 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -184,6 +184,58 @@ def clear(self) -> None: pass +class StatefulSampler(ABC, Generic[K, V]): + """Abstract interface for stateful samplers with deterministic behavior given a state. + + This class defines the interface for samplers that maintain internal state and provide + deterministic sampling behavior when the state is fixed. + """ + + @abstractmethod + def sample_keys(self, buffer: BufferView[K, V], num: int) -> list[K]: + """Return the keys of selected samples from the buffer. + + This method samples a specified number of keys from the provided buffer + according to the sampler's internal sampling strategy. The sampling + behavior is deterministic for a given internal state of the sampler. + + Args: + buffer (BufferView[K, V]): The buffer to sample from, containing key-value pairs. + num (int): Desired number of samples to retrieve from the buffer. + If num is greater than the buffer size, implementation may + return fewer samples or handle it according to the specific + sampling strategy. + + Returns: + list[K]: A list of keys corresponding to the selected samples. + The length of this list will typically be equal to num, + unless the buffer contains fewer items. + """ + pass + + @abstractmethod + def state_dict(self) -> Mapping[str, Any]: + """Return the state dict of the sampler. + + This method should capture all the internal state necessary to reproduce + the sampler's behavior, such as random number generator states. + + Returns: + dict: A dictionary containing the internal state of the sampler. + """ + pass + + @abstractmethod + def set_state_dict(self, state_dict): + """Set the state dict of the sampler. + + Args: + state_dict (dict): A dictionary containing the internal state to restore + the sampler to a specific configuration. + """ + pass + + class BaseTokenizer(ABC): """ Abstract token encoding model that implements ``encode`` and ``decode`` methods. diff --git a/tests/unit_tests/test_stateful_sampler.py b/tests/unit_tests/test_stateful_sampler.py new file mode 100644 index 000000000..287b69aa5 --- /dev/null +++ b/tests/unit_tests/test_stateful_sampler.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +from forge.data.raw_buffer import SimpleRawBuffer +from forge.data.stateful_sampler import RandomStatefulSampler + +from forge.interfaces import RawBuffer + + +class TestRandomStatefulSampler: + @pytest.fixture + def raw_buffer(self) -> RawBuffer[int, int]: + buffer = SimpleRawBuffer[int, int]() + for n in range(1000): + buffer.add(n, n) + return buffer + + def test_init(self): + sampler = RandomStatefulSampler() + assert True + + def test_init_with_seed(self): + sampler1 = RandomStatefulSampler(seed=42) + sampler2 = RandomStatefulSampler(seed=41) + assert str(sampler1.state_dict()) != str(sampler2.state_dict()) + + def test_state_dict(self): + sampler = RandomStatefulSampler() + state_dict = sampler.state_dict() + assert "random_state" in state_dict + assert state_dict["random_state"] is not None + + def test_set_state_dict_no_random_state(self): + sampler = RandomStatefulSampler() + state_dict = {} + with pytest.raises(ValueError, match="Missing 'random_state'"): + sampler.set_state_dict(state_dict) + + def test_deterministic(self, raw_buffer): + sampler1 = RandomStatefulSampler(seed=42) + sampler2 = RandomStatefulSampler() + sampler2.set_state_dict(sampler1.state_dict()) + for _ in range(10): + batch1 = sampler1.sample_keys(raw_buffer, 5) + batch2 = sampler2.sample_keys(raw_buffer, 5) + assert batch1 == batch2 + + def test_deterministic_resume(self, raw_buffer): + sampler1 = RandomStatefulSampler(seed=42) + sampler2 = RandomStatefulSampler() + for _ in range(10): + sampler2.set_state_dict(sampler1.state_dict()) + batch1 = sampler1.sample_keys(raw_buffer, 5) + batch2 = sampler2.sample_keys(raw_buffer, 5) + assert batch1 == batch2