diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 85b9a7d7c..4de8658c6 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -25,6 +25,7 @@
 from forge.controller.provisioner import shutdown
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from forge.data.rewards import MathReward, ThinkingReward
+from forge.data.stores import KVStore
 from forge.data.utils import exclude_service
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
@@ -373,6 +374,7 @@ async def main(cfg: DictConfig):
         spawn_service(
             ServiceConfig(**cfg.replay_buffer.service),
             ReplayBuffer,
+            backend=KVStore(),
             **exclude_service(cfg.replay_buffer),
         ),
         spawn_service(
diff --git a/apps/rl/main.py b/apps/rl/main.py
index 7d00eb09e..0036ce946 100644
--- a/apps/rl/main.py
+++ b/apps/rl/main.py
@@ -21,6 +21,7 @@
 from forge.actors import ReplayBuffer, RLTrainer
 from forge.cli.config import parse
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
+from forge.data.stores import KVStore
 from omegaconf import DictConfig
 from torch import Tensor
@@ -145,6 +146,7 @@ async def run(cfg: DictConfig):
         spawn_service(
             ServiceConfig(procs_per_replica=1, num_replicas=1),
             ReplayBuffer,
+            backend=KVStore(),
             collate=collate,
             **cfg.replay_buffer,
         ),
diff --git a/apps/toy_rl/main.py b/apps/toy_rl/main.py
index 5a961caba..9c2657815 100644
--- a/apps/toy_rl/main.py
+++ b/apps/toy_rl/main.py
@@ -17,6 +17,8 @@
 from forge.actors.collector import Collector
 from forge.actors.replay_buffer import ReplayBuffer
+
+from forge.data.stores import KVStore
 from forge.interfaces import Environment, Policy
 from forge.types import Action, Observation, State
 from monarch.actor import endpoint, proc_mesh
@@ -141,8 +143,10 @@ async def main():
     replay_buffer = await replay_procs.spawn(
         "replay_buffer",
         ReplayBuffer,
-        SAMPLES_PER_BATCH,  # batch_size
-        float("inf"),  # max_policy_age
+        backend=KVStore(),
+        batch_size=SAMPLES_PER_BATCH,
+        max_policy_age=float("inf"),
+        dp_size=1,
     )

     # TODO - add in an example of a "vLLM executor" and "vLLM controller"
diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py
index 985fc4052..7972e7e37 100644
--- a/src/forge/actors/replay_buffer.py
+++ b/src/forge/actors/replay_buffer.py
@@ -5,27 +5,28 @@
 # LICENSE file in the root directory of this source tree.

 import random
+import uuid
 from dataclasses import dataclass
 from typing import Any, Callable

 from monarch.actor import endpoint

 from forge.controller import ForgeActor
+from forge.interfaces import StoreInterface


 @dataclass
 class ReplayBuffer(ForgeActor):
     """Simple in-memory replay buffer implementation."""

+    backend: StoreInterface
     batch_size: int
     max_policy_age: int
     dp_size: int = 1
     seed: int | None = None
     collate: Callable = lambda batch: batch

-    @endpoint
-    async def setup(self) -> None:
-        self.buffer: list = []
+    def __post_init__(self):
         if self.seed is None:
             self.seed = random.randint(0, 2**32)
         random.seed(self.seed)
@@ -33,7 +34,8 @@ async def setup(self) -> None:
     @endpoint
     async def add(self, episode: "Episode") -> None:
-        self.buffer.append(episode)
+        key = f"rb_{uuid.uuid4().hex}"
+        await self.backend.put(key, episode)
@@ -51,20 +53,20 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
         total_samples = self.dp_size * bsz

         # Evict old episodes
-        self._evict(curr_policy_version)
+        # TODO: _evict() before keys() isn't concurrency-safe; may need async lock or refactor. See PR #147.
+        await self._evict(curr_policy_version)
+
+        keys = await self.backend.keys()

-        if total_samples > len(self.buffer):
+        total_available = await self.backend.numel()
+        if total_samples > total_available:
             return None

         # TODO: prefetch samples in advance
-        idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples)
-        # Pop episodes in descending order to avoid shifting issues
-        popped = [self.buffer.pop(i) for i in sorted(idx_to_sample, reverse=True)]
+        idx_to_sample = self.sampler(range(len(keys)), k=total_samples)

-        # Reorder popped episodes to match the original random sample order
-        sorted_idxs = sorted(idx_to_sample, reverse=True)
-        idx_to_popped = dict(zip(sorted_idxs, popped))
-        sampled_episodes = [idx_to_popped[i] for i in idx_to_sample]
+        # Fetch and remove the sampled episodes
+        sampled_episodes = [await self.backend.pop(keys[i]) for i in idx_to_sample]

         # Reshape into (dp_size, bsz, ...)
         reshaped_episodes = [
@@ -82,38 +84,47 @@ async def evict(self, curr_policy_version: int) -> None:
         Args:
             curr_policy_version (int): The current policy version.
""" - self._evict(curr_policy_version) + await self._evict(curr_policy_version) - def _evict(self, curr_policy_version: int) -> None: - self.buffer = [ - trajectory - for trajectory in self.buffer - if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age - ] + async def _evict(self, curr_policy_version: int) -> None: + keys = await self.backend.keys() + for key in keys: + episode = await self.backend.get(key) + # TODO: Could store keys as policy_version+uuid to evict without fetching each episode + if (curr_policy_version - episode.policy_version) > self.max_policy_age: + await self.backend.delete(key) @endpoint - async def _getitem(self, idx: int): - return self.buffer[idx] + async def _getitem(self, key: str): + return await self.backend.get(key) @endpoint async def _numel(self) -> int: """Number of elements (episodes) in the replay buffer.""" - return len(self.buffer) + return await self.backend.numel() @endpoint async def clear(self) -> None: """Clear the replay buffer immediately - dropping all episodes.""" - self.buffer.clear() + await self._clear() + + async def _clear(self) -> None: + await self.backend.delete_all() @endpoint async def state_dict(self) -> dict[str, Any]: + keys = await self.backend.keys() + episodes = [(k, await self.backend.get(k)) for k in keys] return { - "buffer": self.buffer, + "buffer": episodes, "rng_state": random.getstate(), "seed": self.seed, } @endpoint async def load_state_dict(self, state_dict: dict[str, Any]) -> None: - self.buffer = state_dict["buffer"] + await self._clear() + for k, ep in state_dict["buffer"]: + await self.backend.put(k, ep) random.setstate(state_dict["rng_state"]) + self.seed = state_dict["seed"] diff --git a/src/forge/data/stores.py b/src/forge/data/stores.py new file mode 100644 index 000000000..3505cb55e --- /dev/null +++ b/src/forge/data/stores.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Any + +from src.forge.interfaces import StoreInterface + + +class KVStore(StoreInterface): + """ + A simple single-node key-value (KV) store implementation of StoreInterface. + + This acts as a temporary backend for the replay buffer until torchstore + supports the full set of operations we need (delete, pop, keys, numel, etc.). 
+ """ + + def __init__(self): + self._store = {} + + async def put(self, key: str, value: Any) -> None: + self._store[key] = value + + async def get(self, key: str) -> Any: + return self._store[key] + + async def exists(self, key: str) -> bool: + # Check if a key exists in the KV store + return key in self._store + + async def keys(self, prefix: str | None = None) -> list[str]: + # Return all keys, optionally filtered by prefix + if prefix is None: + return list(self._store.keys()) + return [k for k in self._store if k.startswith(prefix)] + + async def numel(self, prefix: str | None = None) -> int: + # Return the number of key-value pairs, optionally filtered by prefix + return len(await self.keys(prefix)) + + async def delete(self, key: str) -> None: + # Delete a key-value pair from the store + del self._store[key] + + async def pop(self, key: str) -> Any: + # Remove and return a key-value pair (get + delete) + return self._store.pop(key) + + async def delete_all(self, prefix: str | None = None) -> None: + # Delete all key-value pairs matching the given prefix + if prefix is None: + # Optimize for deleting all keys + self._store = {} + else: + # Delete only keys matching the prefix + keys_to_delete = await self.keys(prefix) + for key in keys_to_delete: + del self._store[key] diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py index 3dbbd560e..b10120043 100644 --- a/src/forge/interfaces.py +++ b/src/forge/interfaces.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, Mapping +from typing import Any, List, Mapping from monarch.actor import endpoint @@ -208,6 +208,133 @@ def __call__(self, observation: Observation) -> float: pass +class StoreInterface(ABC): + """ + Abstract base class for a KV store. This closely follows the interface of + torchstore. + """ + + # TODO: support this in torchstore. + @abstractmethod + async def numel(self, prefix=None) -> int: + """Return the number of keys starting with the given prefix. + The prefix matching follows reverse domain name notation convention. + + Args: + prefix (str): The prefix to match against stored keys. + For example, "xyz" matches "xyz.abc.def" but "xy" does not. + Note: None is the prefix of all keys, while "" is the prefix of keys + starting with "." and "" itself. + + Returns: + int: The number of keys matching the prefix in the store. + """ + pass + + @abstractmethod + async def keys(self, prefix=None) -> List[str]: + """Return an iterable of all keys in the store matching the given prefix. + The prefix matching follows reverse domain name notation convention. + + Args: + prefix (str): The prefix to match against stored keys. + For example, "xyz" matches "xyz.abc.def" but "xy" does not. + Note: None is the prefix of all keys, while "" is the prefix of keys + starting with "." and "" itself. + + Returns: + Iterable[K]: An iterable containing all keys in the buffer. + """ + pass + + @abstractmethod + async def put(self, key: str, value: Any) -> None: + """ + Add a key-value pair to the buffer. + + Args: + key (K): The key to store the value under + val (V): The value to store in the buffer + + Returns: + None + """ + pass + + @abstractmethod + async def get(self, key: str) -> Any: + """ + Get a key-value pair from the store. 
+
+        Args:
+            key (str): The key to look up
+
+        Returns:
+            Any: The value stored under the key
+
+        Raises:
+            KeyError: If the key does not exist in the store
+        """
+        pass
+
+    @abstractmethod
+    async def exists(self, key: str) -> bool:
+        """
+        Check if a key exists in the store.
+        """
+        pass
+
+    # TODO: support this in torchstore.
+    @abstractmethod
+    async def pop(self, key: str) -> Any:
+        """
+        Get the value stored under a key, and delete the key from the store.
+
+        Args:
+            key (str): The key to pop
+
+        Returns:
+            Any: The value stored under the key
+
+        Raises:
+            KeyError: If the key does not exist in the store
+        """
+
+    # TODO: support this in torchstore.
+    @abstractmethod
+    async def delete(self, key: str) -> None:
+        """
+        Delete a key-value pair from the store.
+
+        Args:
+            key (str): The key to delete
+
+        Returns:
+            None
+
+        Raises:
+            KeyError: If the key does not exist in the store
+        """
+        pass
+
+    # TODO: support this in torchstore.
+    @abstractmethod
+    async def delete_all(self, prefix=None) -> None:
+        """
+        Delete all key-value pairs from the store matching the given prefix.
+        The prefix matching follows reverse domain name notation convention.
+
+        Args:
+            prefix (str): The prefix to match against stored keys.
+                For example, "xyz" matches "xyz.abc.def" but "xy" does not.
+                Note: None is the prefix of all keys, while "" is the prefix of keys
+                starting with "." and "" itself.
+
+        Returns: None
+        """
+        pass
+
+
 # TODO
 # class RLLoss(ABC):
diff --git a/tests/unit_tests/rl/test_toy_rl.py b/tests/unit_tests/rl/test_toy_rl.py
index 1ee79ed47..27cd23881 100644
--- a/tests/unit_tests/rl/test_toy_rl.py
+++ b/tests/unit_tests/rl/test_toy_rl.py
@@ -26,6 +26,7 @@
 # testing purposes. It lacks some features of the real proc_mesh
 # but spawns much quicker
 from monarch.actor import Actor, endpoint, local_proc_mesh
+from forge.data.stores import KVStore


 class TestToyEnvironment:
@@ -211,10 +212,11 @@ async def test_full_rl_pipeline_simulation(self):
         replay_buffer = await proc.spawn(
             "replay_buffer",
             ReplayBuffer,
-            1,  # batch_size
-            1,  # max_policy_age
+            backend=KVStore(),
+            batch_size=1,
+            max_policy_age=1,
+            dp_size=1,
         )
-        await replay_buffer.setup.call()
         collector = await proc.spawn(
             "collector",
             Collector,
diff --git a/tests/unit_tests/test_kv_store.py b/tests/unit_tests/test_kv_store.py
new file mode 100644
index 000000000..6d227ae9a
--- /dev/null
+++ b/tests/unit_tests/test_kv_store.py
@@ -0,0 +1,136 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +"""Test for forge/data/stores.py""" + +import pytest +import pytest_asyncio +from forge.data.stores import KVStore + + +class TestKVStore: + @pytest_asyncio.fixture + async def store(self) -> KVStore: + return KVStore() + + @pytest.mark.asyncio + async def test_put_different_types(self, store: KVStore) -> None: + """Test put and get with different value types.""" + await store.put("string_key", "string_value") + await store.put("int_key", 42) + await store.put("dict_key", {"nested": "dict"}) + await store.put("list_key", [1, 2, 3]) + + assert await store.get("string_key") == "string_value" + assert await store.get("int_key") == 42 + assert await store.get("dict_key") == {"nested": "dict"} + assert await store.get("list_key") == [1, 2, 3] + + @pytest.mark.asyncio + async def test_get_nonexistent_key(self, store: KVStore) -> None: + """Test getting a key that doesn't exist raises KeyError.""" + with pytest.raises(KeyError): + await store.get("nonexistent_key") + + @pytest.mark.asyncio + async def test_exists(self, store: KVStore) -> None: + """Test exists method.""" + assert not await store.exists("key1") + await store.put("key1", "value1") + assert await store.exists("key1") + + @pytest.mark.asyncio + async def test_keys_with_prefix(self, store: KVStore) -> None: + """Test keys method with prefix.""" + await store.put("user.001", "user1") + await store.put("user.002", "user2") + await store.put("post.001", "post1") + await store.put("comment.001", "comment1") + + user_keys = await store.keys("user") + assert set(user_keys) == {"user.001", "user.002"} + + post_keys = await store.keys("post") + assert set(post_keys) == {"post.001"} + + empty_keys = await store.keys("nonexistent") + assert empty_keys == [] + + # Test delete_all with prefix + await store.delete_all("user") + assert await store.numel("user") == 0 + assert await store.numel() == 2 # post and comment remain + + # Test delete_all with non-existent prefix + await store.delete_all("nonexistent") + assert await store.numel() == 2 # no change + + @pytest.mark.asyncio + async def test_keys_empty_store(self, store: KVStore) -> None: + """Test keys method on empty store.""" + keys = await store.keys() + assert keys == [] + + keys_with_prefix = await store.keys("prefix") + assert keys_with_prefix == [] + + @pytest.mark.asyncio + async def test_numel_with_prefix(self, store: KVStore) -> None: + """Test numel method with prefix.""" + await store.put("user.001", "user1") + await store.put("user.002", "user2") + await store.put("post.001", "post1") + + assert await store.numel("user") == 2 + assert await store.numel("post") == 1 + assert await store.numel("nonexistent") == 0 + assert await store.numel() == 3 + + @pytest.mark.asyncio + async def test_delete(self, store: KVStore) -> None: + """Test delete method.""" + await store.put("key1", "value1") + assert await store.exists("key1") + + await store.delete("key1") + assert not await store.exists("key1") + + @pytest.mark.asyncio + async def test_delete_nonexistent_key(self, store: KVStore) -> None: + """Test deleting a key that doesn't exist raises KeyError.""" + with pytest.raises(KeyError): + await store.delete("nonexistent_key") + + @pytest.mark.asyncio + async def test_pop(self, store: KVStore) -> None: + """Test pop method.""" + await store.put("key1", "value1") + + result = await store.pop("key1") + assert result == "value1" + assert not await store.exists("key1") + + @pytest.mark.asyncio + async def test_pop_nonexistent_key(self, store: KVStore) -> None: + """Test popping a key that 
doesn't exist raises KeyError.""" + with pytest.raises(KeyError): + await store.pop("nonexistent_key") + + @pytest.mark.asyncio + async def test_none_values(self, store: KVStore) -> None: + """Test storing and retrieving None values.""" + await store.put("none_key", None) + + assert await store.exists("none_key") + result = await store.get("none_key") + assert result is None + + popped_result = await store.pop("none_key") + assert popped_result is None + + # Test delete_all on empty store + await store.delete_all() + assert await store.numel() == 0 diff --git a/tests/unit_tests/test_replay_buffer.py b/tests/unit_tests/test_replay_buffer.py index 4463c3f2c..b245a27fc 100644 --- a/tests/unit_tests/test_replay_buffer.py +++ b/tests/unit_tests/test_replay_buffer.py @@ -9,6 +9,7 @@ import pytest import pytest_asyncio from forge.actors.replay_buffer import ReplayBuffer +from forge.data.stores import KVStore from forge.types import Trajectory from monarch.actor import proc_mesh @@ -18,10 +19,15 @@ class TestReplayBuffer: @pytest_asyncio.fixture async def replay_buffer(self) -> ReplayBuffer: mesh = await proc_mesh(gpus=1) + backend = KVStore() replay_buffer = await mesh.spawn( - "replay_buffer", ReplayBuffer, batch_size=2, max_policy_age=1 + "replay_buffer", + ReplayBuffer, + backend=backend, + batch_size=2, + max_policy_age=1, + dp_size=1, ) - await replay_buffer.setup.call() return replay_buffer @pytest.mark.asyncio @@ -29,7 +35,12 @@ async def test_add(self, replay_buffer: ReplayBuffer) -> None: trajectory = Trajectory(policy_version=0) await replay_buffer.add.call_one(trajectory) assert replay_buffer._numel.call_one().get() == 1 - assert replay_buffer._getitem.call_one(0).get() == trajectory + assert ( + replay_buffer.sample.call_one(curr_policy_version=1, batch_size=1).get()[0][ + 0 + ] + == trajectory + ) replay_buffer.clear.call_one().get() @pytest.mark.asyncio @@ -39,19 +50,42 @@ async def test_add_multiple(self, replay_buffer) -> None: await replay_buffer.add.call_one(trajectory_0) await replay_buffer.add.call_one(trajectory_1) assert replay_buffer._numel.call_one().get() == 2 - assert replay_buffer._getitem.call_one(0).get() == trajectory_0 - assert replay_buffer._getitem.call_one(1).get() == trajectory_1 + sampled = replay_buffer.sample.call_one( + curr_policy_version=1, batch_size=2 + ).get() + flat_sampled = [ep for dp in sampled for ep in dp] + assert all(ep in [trajectory_0, trajectory_1] for ep in flat_sampled) + + # By curr_policy_version = 2, t0 should be evicted (age > 1) + result = replay_buffer.sample.call_one(curr_policy_version=2).get() + assert result is None # not enough episodes left (only t1 remains) + replay_buffer.clear.call_one().get() @pytest.mark.asyncio async def test_state_dict_save_load(self, replay_buffer) -> None: - trajectory = Trajectory(policy_version=0) - await replay_buffer.add.call_one(trajectory) + trajectory_0 = Trajectory(policy_version=0) + trajectory_1 = Trajectory(policy_version=1) + await replay_buffer.add.call_one(trajectory_0) + await replay_buffer.add.call_one(trajectory_1) + + # Save state dict state_dict = replay_buffer.state_dict.call_one().get() + + # Clear the buffer replay_buffer.clear.call_one().get() assert replay_buffer._numel.call_one().get() == 0 + + # Load state dict await replay_buffer.load_state_dict.call_one(state_dict) - assert replay_buffer._numel.call_one().get() == 1 + assert replay_buffer._numel.call_one().get() == 2 + + # Save state again + restored_state_dict = replay_buffer.state_dict.call_one().get() + + # Check equality 
(buffer contents + rng_state + seed) + assert state_dict == restored_state_dict + replay_buffer.clear.call_one().get() @pytest.mark.asyncio @@ -113,11 +147,16 @@ async def test_sample_with_evictions(self, replay_buffer) -> None: async def test_sample_dp_size(self) -> None: """Test that len(samples) == dp_size when sampling.""" mesh = await proc_mesh(gpus=1) + backend = KVStore() # Create replay buffer with dp_size=3 replay_buffer = await mesh.spawn( - "replay_buffer", ReplayBuffer, batch_size=2, max_policy_age=1, dp_size=3 + "replay_buffer", + ReplayBuffer, + backend=backend, + batch_size=2, + max_policy_age=1, + dp_size=3, ) - await replay_buffer.setup.call() # Add enough trajectories to sample for i in range(10):
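
Usage sketch (not part of the patch): the snippet below exercises only the KVStore operations introduced in this diff (put, get, exists, keys, numel, pop, delete_all) the way ReplayBuffer now uses them. The keys and values are illustrative, and it assumes the package imports as forge.data.stores once installed.

    # Illustrative only; mirrors the calls made by ReplayBuffer and the new unit tests.
    import asyncio

    from forge.data.stores import KVStore


    async def demo() -> None:
        store = KVStore()

        # put/get/exists: basic KV operations the replay buffer relies on.
        await store.put("rb_0001", {"policy_version": 0})
        await store.put("rb_0002", {"policy_version": 1})
        assert await store.exists("rb_0001")
        assert await store.get("rb_0002") == {"policy_version": 1}

        # keys/numel: used by sample() and _numel().
        assert await store.numel() == 2
        assert set(await store.keys("rb_")) == {"rb_0001", "rb_0002"}

        # pop: get + delete, used when sampling removes episodes.
        episode = await store.pop("rb_0001")
        assert episode["policy_version"] == 0
        assert await store.numel() == 1

        # delete_all: used by ReplayBuffer.clear() via _clear().
        await store.delete_all()
        assert await store.keys() == []


    asyncio.run(demo())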