-
Notifications
You must be signed in to change notification settings - Fork 33
Refactor replay buffer to use KV buffer #147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 32 commits
da21e1d
5c72908
b4d7a61
02d77c6
fd1d38b
f79beee
d8d775a
e423c44
4815c05
77d41e4
a3feb1e
ff6f5c7
d9411b9
dcc8e00
d812b9c
05dd33b
23d7e02
9953c91
726be1c
e25a239
ddde20f
0016889
1aec7aa
0d7f4ac
0234808
ee5bb0c
b32d840
80113df
f23285c
895858c
d8ba98d
af84852
8bef515
a58fd2d
780239a
2d2503b
c66cfd9
393bcca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,34 +5,36 @@ | |
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import random | ||
| import uuid | ||
| from dataclasses import dataclass | ||
| from typing import Any | ||
|
|
||
| from monarch.actor import endpoint | ||
|
|
||
| from forge.controller import ForgeActor | ||
| from forge.interfaces import StoreInterface | ||
|
|
||
|
|
||
| @dataclass | ||
| class ReplayBuffer(ForgeActor): | ||
| """Simple in-memory replay buffer implementation.""" | ||
|
|
||
| store: StoreInterface | ||
| batch_size: int | ||
| max_policy_age: int | ||
| dp_size: int = 1 | ||
| seed: int | None = None | ||
|
|
||
| @endpoint | ||
| async def setup(self) -> None: | ||
| self.buffer: list = [] | ||
| def __post_init__(self): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @joecummings I changed the |
||
| if self.seed is None: | ||
| self.seed = random.randint(0, 2**32) | ||
| random.seed(self.seed) | ||
| self.sampler = random.sample | ||
|
|
||
| @endpoint | ||
| async def add(self, episode) -> None: | ||
| self.buffer.append(episode) | ||
| key = f"rb_{uuid.uuid4().hex}" | ||
| await self.store.put(key, episode) | ||
|
|
||
| @endpoint | ||
| async def sample(self, curr_policy_version: int, batch_size: int | None = None): | ||
|
|
@@ -50,20 +52,20 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None): | |
| total_samples = self.dp_size * bsz | ||
|
|
||
| # Evict old episodes | ||
| self._evict(curr_policy_version) | ||
|
|
||
| if total_samples > len(self.buffer): | ||
| await self._evict(curr_policy_version) | ||
|
|
||
| keys = await self.store.keys() | ||
|
|
||
| total_available = await self.store.numel() | ||
| if total_samples > total_available: | ||
| return None | ||
|
|
||
| # TODO: Make this more efficient | ||
|
||
| idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples) | ||
| # Pop episodes in descending order to avoid shifting issues | ||
| popped = [self.buffer.pop(i) for i in sorted(idx_to_sample, reverse=True)] | ||
| idx_to_sample = self.sampler(range(len(keys)), k=total_samples) | ||
|
|
||
| # Reorder popped episodes to match the original random sample order | ||
| sorted_idxs = sorted(idx_to_sample, reverse=True) | ||
| idx_to_popped = dict(zip(sorted_idxs, popped)) | ||
| sampled_episodes = [idx_to_popped[i] for i in idx_to_sample] | ||
| # Fetch and remove the sampled episodes | ||
| sampled_episodes = [await self.store.pop(keys[i]) for i in idx_to_sample] | ||
|
|
||
| # Reshape into (dp_size, bsz, ...) | ||
| reshaped_episodes = [ | ||
|
|
@@ -81,38 +83,47 @@ async def evict(self, curr_policy_version: int) -> None: | |
| Args: | ||
| curr_policy_version (int): The current policy version. | ||
| """ | ||
| self._evict(curr_policy_version) | ||
| await self._evict(curr_policy_version) | ||
|
|
||
| def _evict(self, curr_policy_version: int) -> None: | ||
| self.buffer = [ | ||
| trajectory | ||
| for trajectory in self.buffer | ||
| if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age | ||
| ] | ||
| async def _evict(self, curr_policy_version: int) -> None: | ||
DNXie marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We control this internal method so you can pass in the keys from above, which we already calculated. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, we have to re-fetch the keys after eviction because There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we need to reconsider the whole eviction logic. See my previous point re: concurerncy. |
||
| keys = await self.store.keys() | ||
| for key in keys: | ||
| episode = await self.store.get(key) | ||
|
||
| # TODO: Could store keys as policy_version+uuid to evict without fetching each episode | ||
| if (curr_policy_version - episode.policy_version) > self.max_policy_age: | ||
| await self.store.delete(key) | ||
|
|
||
| @endpoint | ||
| async def _getitem(self, idx: int): | ||
| return self.buffer[idx] | ||
| async def _getitem(self, key: str): | ||
| return await self.store.get(key) | ||
|
|
||
| @endpoint | ||
| async def _numel(self) -> int: | ||
| """Number of elements (episodes) in the replay buffer.""" | ||
| return len(self.buffer) | ||
| return await self.store.numel() | ||
|
|
||
| @endpoint | ||
| async def clear(self) -> None: | ||
| """Clear the replay buffer immediately - dropping all episodes.""" | ||
| self.buffer.clear() | ||
| await self._clear() | ||
|
|
||
| async def _clear(self) -> None: | ||
| await self.store.delete_all() | ||
|
|
||
| @endpoint | ||
| async def state_dict(self) -> dict[str, Any]: | ||
| keys = await self.store.keys() | ||
| episodes = [(k, await self.store.get(k)) for k in keys] | ||
|
||
| return { | ||
| "buffer": self.buffer, | ||
| "buffer": episodes, | ||
DNXie marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "rng_state": random.getstate(), | ||
| "seed": self.seed, | ||
| } | ||
|
|
||
| @endpoint | ||
| async def load_state_dict(self, state_dict: dict[str, Any]) -> None: | ||
| self.buffer = state_dict["buffer"] | ||
| await self._clear() | ||
| for k, ep in state_dict["buffer"]: | ||
| await self.store.put(k, ep) | ||
| random.setstate(state_dict["rng_state"]) | ||
| self.seed = state_dict["seed"] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
|
|
||
| from typing import Any | ||
|
|
||
| from src.forge.interfaces import StoreInterface | ||
|
|
||
|
|
||
| class KVStore(StoreInterface): | ||
DNXie marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| A simple single-node key-value (KV) store implementation of StoreInterface. | ||
| This acts as a temporary backend for the replay buffer until torchstore | ||
| supports the full set of operations we need (delete, pop, keys, numel, etc.). | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you explain the consistency semantics and the thread safety of the interface? For example, once put finishes, any future gets will always see the effect of the put. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current
The plan is to switch to torchstore once the key APIs like |
||
| """ | ||
|
|
||
| def __init__(self): | ||
| self._store = {} | ||
|
|
||
| async def put(self, key: str, value: Any) -> None: | ||
| self._store[key] = value | ||
|
|
||
| async def get(self, key: str) -> Any: | ||
| return self._store[key] | ||
|
|
||
| async def exists(self, key: str) -> bool: | ||
| # Check if a key exists in the KV store | ||
| return key in self._store | ||
|
|
||
| async def keys(self, prefix: str | None = None) -> list[str]: | ||
| # Return all keys, optionally filtered by prefix | ||
| if prefix is None: | ||
| return list(self._store.keys()) | ||
| return [k for k in self._store if k.startswith(prefix)] | ||
|
|
||
| async def numel(self, prefix: str | None = None) -> int: | ||
| # Return the number of key-value pairs, optionally filtered by prefix | ||
| return len(await self.keys(prefix)) | ||
|
|
||
| async def delete(self, key: str) -> None: | ||
| # Delete a key-value pair from the store | ||
| del self._store[key] | ||
|
|
||
| async def pop(self, key: str) -> Any: | ||
| # Remove and return a key-value pair (get + delete) | ||
| return self._store.pop(key) | ||
|
|
||
| async def delete_all(self, prefix: str | None = None) -> None: | ||
| # Delete all key-value pairs matching the given prefix | ||
| if prefix is None: | ||
| # Optimize for deleting all keys | ||
| self._store = {} | ||
| else: | ||
| # Delete only keys matching the prefix | ||
| keys_to_delete = await self.keys(prefix) | ||
| for key in keys_to_delete: | ||
| del self._store[key] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,7 @@ | |
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from typing import Any, Mapping | ||
| from typing import Any, List, Mapping | ||
|
|
||
| from monarch.actor import endpoint | ||
|
|
||
|
|
@@ -208,6 +208,133 @@ def __call__(self, observation: Observation) -> float: | |
| pass | ||
|
|
||
|
|
||
| class StoreInterface(ABC): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For now, pair this down to the exact APIs we will be using in the ReplayBuffer - no more, no less. We can always update the interface later. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I’ve already pared the interface down to only the APIs we need. My concern is that methods like |
||
| """ | ||
| Abstract base class for a KV store. This closely follows the interface of | ||
| torchstore. | ||
| """ | ||
|
|
||
| # TODO: support this in torchstore. | ||
| @abstractmethod | ||
| async def numel(self, prefix=None) -> int: | ||
| """Return the number of keys starting with the given prefix. | ||
| The prefix matching follows reverse domain name notation convention. | ||
|
|
||
| Args: | ||
| prefix (str): The prefix to match against stored keys. | ||
| For example, "xyz" matches "xyz.abc.def" but "xy" does not. | ||
| Note: None is the prefix of all keys, while "" is the prefix of keys | ||
| starting with "." and "" itself. | ||
|
|
||
| Returns: | ||
| int: The number of keys matching the prefix in the store. | ||
| """ | ||
| pass | ||
|
|
||
| @abstractmethod | ||
| async def keys(self, prefix=None) -> List[str]: | ||
| """Return an iterable of all keys in the store matching the given prefix. | ||
| The prefix matching follows reverse domain name notation convention. | ||
|
|
||
| Args: | ||
| prefix (str): The prefix to match against stored keys. | ||
| For example, "xyz" matches "xyz.abc.def" but "xy" does not. | ||
| Note: None is the prefix of all keys, while "" is the prefix of keys | ||
| starting with "." and "" itself. | ||
|
|
||
| Returns: | ||
| Iterable[K]: An iterable containing all keys in the buffer. | ||
| """ | ||
| pass | ||
|
|
||
| @abstractmethod | ||
| async def put(self, key: str, value: Any) -> None: | ||
| """ | ||
| Add a key-value pair to the buffer. | ||
|
|
||
| Args: | ||
| key (K): The key to store the value under | ||
| val (V): The value to store in the buffer | ||
|
|
||
| Returns: | ||
| None | ||
| """ | ||
| pass | ||
|
|
||
| @abstractmethod | ||
| async def get(self, key: str) -> Any: | ||
| """ | ||
| Get a key-value pair from the store. | ||
|
|
||
| Args: | ||
| key (K): The key to get | ||
|
|
||
| Returns: | ||
| V: The value stored under the key | ||
|
|
||
| Raises: | ||
| KeyError: If the key does not exist in the store | ||
| """ | ||
| pass | ||
|
|
||
| @abstractmethod | ||
| async def exists(self, key: str) -> bool: | ||
| """ | ||
| Check if a key exists in the store. | ||
| """ | ||
| pass | ||
|
|
||
| # TODO: support this in torchstore. | ||
| @abstractmethod | ||
| async def pop(self, key: str) -> Any: | ||
| """ | ||
| Get a key-value pair from the store, and delete it from the store. | ||
|
|
||
| Args: | ||
| key (K): The key to get | ||
|
|
||
| Returns: | ||
| V: The value stored under the key | ||
|
|
||
| Raises: | ||
| KeyError: If the key does not exist in the store | ||
| """ | ||
|
|
||
| # TODO: support this in torchstore. | ||
| @abstractmethod | ||
| async def delete(self, key: str) -> None: | ||
| """ | ||
| Delete a key-value pair from the store. | ||
|
|
||
| Args: | ||
| key (K): The key to delete | ||
|
|
||
| Returns: | ||
| None | ||
|
|
||
| Raises: | ||
| KeyError: If the key does not exist in the store | ||
| """ | ||
| pass | ||
|
|
||
| # TODO: support this in torchstore. | ||
| @abstractmethod | ||
| async def delete_all(self, prefix=None) -> None: | ||
| """ | ||
| Delete all key-value pairs from the store matching the given prefix. | ||
| The prefix matching follows reverse domain name notation convention. | ||
|
|
||
| Args: | ||
| prefix (str): The prefix to match against stored keys. | ||
| For example, "xyz" matches "xyz.abc.def" but "xy" does not. | ||
| Note: None is the prefix of all keys, while "" is the prefix of keys | ||
| starting with "." and "" itself. | ||
|
|
||
| Returns: None | ||
| """ | ||
| pass | ||
|
|
||
|
|
||
| # TODO | ||
| # class RLLoss(ABC): | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: maybe call this the backend?
wdyt @LucasLLC ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed
store->backend