From 9dba3e61aa8ddedfb9a4039eb5f7e8201d753cd5 Mon Sep 17 00:00:00 2001 From: pbontrager Date: Tue, 14 Oct 2025 19:56:07 +0000 Subject: [PATCH 1/9] updated replay buffer --- src/forge/actors/replay_buffer.py | 174 +++++++++++++++---------- tests/unit_tests/test_replay_buffer.py | 26 ++-- 2 files changed, 117 insertions(+), 83 deletions(-) diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index a92fd5501..d04e87006 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -6,7 +6,9 @@ import logging import random +from collections import deque from dataclasses import dataclass +from operator import itemgetter from typing import Any, Callable from monarch.actor import endpoint @@ -19,22 +21,48 @@ logger.setLevel(logging.INFO) +@dataclass +class BufferEntry: + data: "Episode" + sample_count: int = 0 + + +def default_evict(buffer, policy_version, max_samples=None, max_age=None): + """Default buffer eviction policy""" + indices = [] + for i, entry in enumerate(buffer): + if max_age and policy_version - entry.data.policy_version > max_age: + continue + if max_samples and entry.sample_count >= max_samples: + continue + indices.append(i) + return indices + + +def default_sample(buffer, sample_size, sampler, policy_version): + """Default buffer sampling policy""" + if sample_size > len(buffer): + return None + return sampler(range(len(buffer)), k=sample_size) + + @dataclass class ReplayBuffer(ForgeActor): """Simple in-memory replay buffer implementation.""" batch_size: int - max_policy_age: int dp_size: int = 1 + max_policy_age: int | None = None + max_buffer_size: int | None = None + max_resample_count: int | None = 0 seed: int | None = None collate: Callable = lambda batch: batch - - def __post_init__(self): - super().__init__() + eviction_policy: Callable = default_evict + sample_policy: Callable = default_sample @endpoint async def setup(self) -> None: - self.buffer: list = [] + self.buffer: deque = deque(maxlen=self.max_buffer_size) if self.seed is None: self.seed = random.randint(0, 2**32) random.seed(self.seed) @@ -42,20 +70,16 @@ async def setup(self) -> None: @endpoint async def add(self, episode: "Episode") -> None: - self.buffer.append(episode) + self.buffer.append(BufferEntry(episode)) record_metric("buffer/add/count_episodes_added", 1, Reduce.SUM) @endpoint @trace("buffer_perf/sample", track_memory=False) - async def sample( - self, curr_policy_version: int, batch_size: int | None = None - ) -> tuple[tuple[Any, ...], ...] | None: + async def sample(self, curr_policy_version: int) -> tuple[tuple[Any, ...], ...] | None: """Sample from the replay buffer. Args: curr_policy_version (int): The current policy version. - batch_size (int, optional): Number of episodes to sample. If none, defaults to batch size - passed in at initialization. Returns: A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer. @@ -63,45 +87,37 @@ async def sample( # Record sample request metric record_metric("buffer/sample/count_sample_requests", 1, Reduce.SUM) - bsz = batch_size if batch_size is not None else self.batch_size - total_samples = self.dp_size * bsz + total_samples = self.dp_size * self.batch_size - # Evict old episodes + # Evict episodes self._evict(curr_policy_version) - - if total_samples > len(self.buffer): - return None - - # Calculate buffer utilization - utilization_pct = ( - (total_samples / len(self.buffer)) * 100 if len(self.buffer) > 0 else 0 - ) - - record_metric( - "buffer/sample/avg_buffer_utilization", - len(self.buffer), - Reduce.MEAN, - ) - - record_metric( - "buffer/sample/avg_buffer_utilization_pct", - utilization_pct, - Reduce.MEAN, - ) + + # Calculate metrics + if len(self.buffer) > 0: + record_metric( + "buffer/sample/avg_data_utilization", + total_samples / len(self.buffer), + Reduce.MEAN, + ) + if self.max_buffer_size: + record_metric( + "buffer/sample/avg_buffer_utilization", + len(self.buffer) / self.max_buffer_size, + Reduce.MEAN, + ) # TODO: prefetch samples in advance - idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples) - # Pop episodes in descending order to avoid shifting issues - popped = [self.buffer.pop(i) for i in sorted(idx_to_sample, reverse=True)] - - # Reorder popped episodes to match the original random sample order - sorted_idxs = sorted(idx_to_sample, reverse=True) - idx_to_popped = dict(zip(sorted_idxs, popped)) - sampled_episodes = [idx_to_popped[i] for i in idx_to_sample] + sampled_indices = self.sample_policy(self.buffer, total_samples, self.sampler, curr_policy_version) + if sampled_indices is None: + return None + sampled_episodes = [] + for entry in self._collect(sampled_indices): + entry.sample_count += 1 + sampled_episodes.append(entry.data) # Reshape into (dp_size, bsz, ...) reshaped_episodes = [ - sampled_episodes[dp_idx * bsz : (dp_idx + 1) * bsz] + sampled_episodes[dp_idx * self.batch_size : (dp_idx + 1) * self.batch_size] for dp_idx in range(self.dp_size) ] @@ -117,46 +133,68 @@ async def evict(self, curr_policy_version: int) -> None: """ self._evict(curr_policy_version) - def _evict(self, curr_policy_version: int) -> None: + + def _evict(self, curr_policy_version): buffer_len_before_evict = len(self.buffer) - self.buffer = [ - trajectory - for trajectory in self.buffer - if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age - ] - buffer_len_after_evict = len(self.buffer) + indices = self.eviction_policy(self.buffer, curr_policy_version, self.max_resample_count + 1, self.max_policy_age) + self.buffer = deque(self._collect(indices)) # Record evict metrics - policy_staleness = [ - curr_policy_version - ep.policy_version for ep in self.buffer + policy_age = [ + curr_policy_version - ep.data.policy_version for ep in self.buffer ] - if policy_staleness: + if policy_age: record_metric( - "buffer/evict/avg_policy_staleness", - sum(policy_staleness) / len(policy_staleness), + "buffer/evict/avg_policy_age", + sum(policy_age) / len(policy_age), Reduce.MEAN, ) record_metric( - "buffer/evict/max_policy_staleness", - max(policy_staleness), + "buffer/evict/max_policy_age", + max(policy_age), Reduce.MAX, ) - - # Record eviction metrics - evicted_count = buffer_len_before_evict - buffer_len_after_evict - if evicted_count > 0: - record_metric( - "buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM - ) - + + evicted_count = buffer_len_before_evict - len(self.buffer) + record_metric( + "buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM + ) + logger.debug( f"maximum policy age: {self.max_policy_age}, current policy version: {curr_policy_version}, " - f"{evicted_count} episodes expired, {buffer_len_after_evict} episodes left" + f"{evicted_count} episodes expired, {len(self.buffer)} episodes left" ) + + def _collect(self, indices: list[int]): + """Efficiently traverse deque and collect elements at each requested index""" + n = len(self.buffer) + if n == 0 or len(indices) == 0: + return [] + + # Normalize indices and store with their original order + indexed = [(pos, idx % n) for pos, idx in enumerate(indices)] + indexed.sort(key=itemgetter(1)) + + result = [None] * len(indices) + rotations = 0 # logical current index + total_rotation = 0 # total net rotation applied + + for orig_pos, idx in indexed: + move = idx - rotations + self.buffer.rotate(-move) + total_rotation += move + rotations = idx + result[orig_pos] = self.buffer[0] + + # Restore original deque orientation + self.buffer.rotate(total_rotation) + + return result + @endpoint async def _getitem(self, idx: int): - return self.buffer[idx] + return self.buffer[idx].data @endpoint async def _numel(self) -> int: @@ -181,3 +219,5 @@ async def state_dict(self) -> dict[str, Any]: async def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.buffer = state_dict["buffer"] random.setstate(state_dict["rng_state"]) + + diff --git a/tests/unit_tests/test_replay_buffer.py b/tests/unit_tests/test_replay_buffer.py index 05d415398..c0ae4d016 100644 --- a/tests/unit_tests/test_replay_buffer.py +++ b/tests/unit_tests/test_replay_buffer.py @@ -70,24 +70,15 @@ async def test_sample(self, replay_buffer) -> None: await replay_buffer.add.call_one(trajectory_1) assert replay_buffer._numel.call_one().get() == 2 - # Test a simple sampling w/ no evictions + # Test a simple sampling samples = await replay_buffer.sample.call_one(curr_policy_version=1) assert samples is not None assert len(samples[0]) == 2 + assert replay_buffer._numel.call_one().get() == 2 - # Test sampling with overriding batch size - await replay_buffer.add.call_one(trajectory_0) - samples = await replay_buffer.sample.call_one( - curr_policy_version=1, batch_size=1 - ) - assert samples is not None - assert len(samples[0]) == 1 - - # Test sampling w/ overriding batch size (not enough samples in buffer, returns None) + # Test sampling (not enough samples in buffer, returns None) await replay_buffer.add.call_one(trajectory_0) - samples = await replay_buffer.sample.call_one( - curr_policy_version=1, batch_size=3 - ) + samples = await replay_buffer.sample.call_one(curr_policy_version=1) assert samples is None replay_buffer.clear.call_one().get() @@ -95,15 +86,18 @@ async def test_sample(self, replay_buffer) -> None: async def test_sample_with_evictions(self, replay_buffer) -> None: trajectory_0 = Trajectory(policy_version=0) trajectory_1 = Trajectory(policy_version=1) + trajectory_2 = Trajectory(policy_version=2) await replay_buffer.add.call_one(trajectory_0) await replay_buffer.add.call_one(trajectory_1) - assert replay_buffer._numel.call_one().get() == 2 + await replay_buffer.add.call_one(trajectory_2) + assert replay_buffer._numel.call_one().get() == 3 samples = await replay_buffer.sample.call_one( - curr_policy_version=2, batch_size=1 + curr_policy_version=2, ) assert samples is not None - assert len(samples[0]) == 1 + assert len(samples[0]) == 2 assert samples[0][0] == trajectory_1 + assert replay_buffer._numel.call_one().get() == 2 replay_buffer.clear.call_one().get() @pytest.mark.asyncio From 37a348cf5925be08c87ca3d5d9c9434bb070b5a1 Mon Sep 17 00:00:00 2001 From: pbontrager Date: Tue, 14 Oct 2025 20:24:09 +0000 Subject: [PATCH 2/9] lint --- src/forge/actors/replay_buffer.py | 39 +++++++++++++++++-------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index d04e87006..5f03664e2 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -29,7 +29,7 @@ class BufferEntry: def default_evict(buffer, policy_version, max_samples=None, max_age=None): """Default buffer eviction policy""" - indices = [] + indices = [] for i, entry in enumerate(buffer): if max_age and policy_version - entry.data.policy_version > max_age: continue @@ -75,7 +75,9 @@ async def add(self, episode: "Episode") -> None: @endpoint @trace("buffer_perf/sample", track_memory=False) - async def sample(self, curr_policy_version: int) -> tuple[tuple[Any, ...], ...] | None: + async def sample( + self, curr_policy_version: int + ) -> tuple[tuple[Any, ...], ...] | None: """Sample from the replay buffer. Args: @@ -91,7 +93,7 @@ async def sample(self, curr_policy_version: int) -> tuple[tuple[Any, ...], ...] # Evict episodes self._evict(curr_policy_version) - + # Calculate metrics if len(self.buffer) > 0: record_metric( @@ -107,7 +109,9 @@ async def sample(self, curr_policy_version: int) -> tuple[tuple[Any, ...], ...] ) # TODO: prefetch samples in advance - sampled_indices = self.sample_policy(self.buffer, total_samples, self.sampler, curr_policy_version) + sampled_indices = self.sample_policy( + self.buffer, total_samples, self.sampler, curr_policy_version + ) if sampled_indices is None: return None sampled_episodes = [] @@ -133,10 +137,14 @@ async def evict(self, curr_policy_version: int) -> None: """ self._evict(curr_policy_version) - def _evict(self, curr_policy_version): buffer_len_before_evict = len(self.buffer) - indices = self.eviction_policy(self.buffer, curr_policy_version, self.max_resample_count + 1, self.max_policy_age) + indices = self.eviction_policy( + self.buffer, + curr_policy_version, + self.max_resample_count + 1, + self.max_policy_age, + ) self.buffer = deque(self._collect(indices)) # Record evict metrics @@ -154,18 +162,15 @@ def _evict(self, curr_policy_version): max(policy_age), Reduce.MAX, ) - + evicted_count = buffer_len_before_evict - len(self.buffer) - record_metric( - "buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM - ) - + record_metric("buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM) + logger.debug( f"maximum policy age: {self.max_policy_age}, current policy version: {curr_policy_version}, " f"{evicted_count} episodes expired, {len(self.buffer)} episodes left" ) - def _collect(self, indices: list[int]): """Efficiently traverse deque and collect elements at each requested index""" n = len(self.buffer) @@ -175,21 +180,21 @@ def _collect(self, indices: list[int]): # Normalize indices and store with their original order indexed = [(pos, idx % n) for pos, idx in enumerate(indices)] indexed.sort(key=itemgetter(1)) - + result = [None] * len(indices) rotations = 0 # logical current index total_rotation = 0 # total net rotation applied - + for orig_pos, idx in indexed: move = idx - rotations self.buffer.rotate(-move) total_rotation += move rotations = idx result[orig_pos] = self.buffer[0] - + # Restore original deque orientation self.buffer.rotate(total_rotation) - + return result @endpoint @@ -219,5 +224,3 @@ async def state_dict(self) -> dict[str, Any]: async def load_state_dict(self, state_dict: dict[str, Any]) -> None: self.buffer = state_dict["buffer"] random.setstate(state_dict["rng_state"]) - - From 4d2a1cea01f6fa7979d1dd032db45db771965f22 Mon Sep 17 00:00:00 2001 From: pbontrager Date: Tue, 14 Oct 2025 21:05:21 +0000 Subject: [PATCH 3/9] fixed flakey test --- tests/unit_tests/test_replay_buffer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/test_replay_buffer.py b/tests/unit_tests/test_replay_buffer.py index c0ae4d016..c1a5c47e4 100644 --- a/tests/unit_tests/test_replay_buffer.py +++ b/tests/unit_tests/test_replay_buffer.py @@ -96,7 +96,8 @@ async def test_sample_with_evictions(self, replay_buffer) -> None: ) assert samples is not None assert len(samples[0]) == 2 - assert samples[0][0] == trajectory_1 + assert samples[0][0].policy_version > 0 + assert samples[0][1].policy_version > 0 assert replay_buffer._numel.call_one().get() == 2 replay_buffer.clear.call_one().get() From 09e19f2beea700fe8b6ae1757b268a5d17199b2f Mon Sep 17 00:00:00 2001 From: pbontrager Date: Wed, 15 Oct 2025 16:09:40 +0000 Subject: [PATCH 4/9] added type hints --- src/forge/actors/replay_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index 5f03664e2..d011b668c 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -27,7 +27,7 @@ class BufferEntry: sample_count: int = 0 -def default_evict(buffer, policy_version, max_samples=None, max_age=None): +def default_evict(buffer: deque, policy_version: int, max_samples: int = None, max_ag: inte = None): """Default buffer eviction policy""" indices = [] for i, entry in enumerate(buffer): @@ -39,7 +39,7 @@ def default_evict(buffer, policy_version, max_samples=None, max_age=None): return indices -def default_sample(buffer, sample_size, sampler, policy_version): +def default_sample(buffer: deque, sample_size: int, sampler: Callable, policy_version: int): """Default buffer sampling policy""" if sample_size > len(buffer): return None From 165f57bfc310203e7f25f9569af4fbaf4a6844bf Mon Sep 17 00:00:00 2001 From: pbontrager Date: Wed, 15 Oct 2025 16:14:48 +0000 Subject: [PATCH 5/9] ran linting --- src/forge/actors/replay_buffer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index d011b668c..f61396c03 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -27,7 +27,9 @@ class BufferEntry: sample_count: int = 0 -def default_evict(buffer: deque, policy_version: int, max_samples: int = None, max_ag: inte = None): +def default_evict( + buffer: deque, policy_version: int, max_samples: int = None, max_ag: inte = None +): """Default buffer eviction policy""" indices = [] for i, entry in enumerate(buffer): @@ -39,7 +41,9 @@ def default_evict(buffer: deque, policy_version: int, max_samples: int = None, m return indices -def default_sample(buffer: deque, sample_size: int, sampler: Callable, policy_version: int): +def default_sample( + buffer: deque, sample_size: int, sampler: Callable, policy_version: int +): """Default buffer sampling policy""" if sample_size > len(buffer): return None From 8f7a77dc478e5773b3b2eab02b917d1073550e4a Mon Sep 17 00:00:00 2001 From: pbontrager Date: Wed, 15 Oct 2025 16:32:06 +0000 Subject: [PATCH 6/9] bug fix --- src/forge/actors/replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index f61396c03..68facbe9a 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -28,7 +28,7 @@ class BufferEntry: def default_evict( - buffer: deque, policy_version: int, max_samples: int = None, max_ag: inte = None + buffer: deque, policy_version: int, max_samples: int = None, max_ag: int = None ): """Default buffer eviction policy""" indices = [] From 03fe93c80587e33e417795e2a8d4d83ce17e2590 Mon Sep 17 00:00:00 2001 From: pbontrager Date: Wed, 15 Oct 2025 17:20:14 +0000 Subject: [PATCH 7/9] another bug fix --- src/forge/actors/replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index 68facbe9a..f24bbe07d 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -28,7 +28,7 @@ class BufferEntry: def default_evict( - buffer: deque, policy_version: int, max_samples: int = None, max_ag: int = None + buffer: deque, policy_version: int, max_samples: int = None, max_age: int = None ): """Default buffer eviction policy""" indices = [] From 98d7ecca3dd31964c7e5599ece66d31dccff21e1 Mon Sep 17 00:00:00 2001 From: pbontrager Date: Wed, 15 Oct 2025 18:07:02 +0000 Subject: [PATCH 8/9] responsed to review --- src/forge/actors/replay_buffer.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index f24bbe07d..529ff3025 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -27,7 +27,7 @@ class BufferEntry: sample_count: int = 0 -def default_evict( +def age_evict( buffer: deque, policy_version: int, max_samples: int = None, max_age: int = None ): """Default buffer eviction policy""" @@ -41,13 +41,11 @@ def default_evict( return indices -def default_sample( - buffer: deque, sample_size: int, sampler: Callable, policy_version: int -): +def random_sample(buffer: deque, sample_size: int, policy_version: int): """Default buffer sampling policy""" if sample_size > len(buffer): return None - return sampler(range(len(buffer)), k=sample_size) + return random.sample(range(len(buffer)), k=sample_size) @dataclass @@ -61,8 +59,8 @@ class ReplayBuffer(ForgeActor): max_resample_count: int | None = 0 seed: int | None = None collate: Callable = lambda batch: batch - eviction_policy: Callable = default_evict - sample_policy: Callable = default_sample + eviction_policy: Callable = age_evict + sample_policy: Callable = random_sample @endpoint async def setup(self) -> None: @@ -70,7 +68,6 @@ async def setup(self) -> None: if self.seed is None: self.seed = random.randint(0, 2**32) random.seed(self.seed) - self.sampler = random.sample @endpoint async def add(self, episode: "Episode") -> None: @@ -114,7 +111,7 @@ async def sample( # TODO: prefetch samples in advance sampled_indices = self.sample_policy( - self.buffer, total_samples, self.sampler, curr_policy_version + self.buffer, total_samples, curr_policy_version ) if sampled_indices is None: return None From 6aa8b5a864f139f3b9dcbd3b38eb79cfdc0cdf21 Mon Sep 17 00:00:00 2001 From: pbontrager Date: Wed, 15 Oct 2025 19:11:46 +0000 Subject: [PATCH 9/9] added collect test --- src/forge/actors/replay_buffer.py | 8 ++++---- tests/unit_tests/test_replay_buffer.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index 529ff3025..b40329d5f 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -29,8 +29,8 @@ class BufferEntry: def age_evict( buffer: deque, policy_version: int, max_samples: int = None, max_age: int = None -): - """Default buffer eviction policy""" +) -> list[int]: + """Buffer eviction policy, remove old or over-sampled entries""" indices = [] for i, entry in enumerate(buffer): if max_age and policy_version - entry.data.policy_version > max_age: @@ -41,8 +41,8 @@ def age_evict( return indices -def random_sample(buffer: deque, sample_size: int, policy_version: int): - """Default buffer sampling policy""" +def random_sample(buffer: deque, sample_size: int, policy_version: int) -> list[int]: + """Buffer random sampling policy""" if sample_size > len(buffer): return None return random.sample(range(len(buffer)), k=sample_size) diff --git a/tests/unit_tests/test_replay_buffer.py b/tests/unit_tests/test_replay_buffer.py index c1a5c47e4..e6c6876c3 100644 --- a/tests/unit_tests/test_replay_buffer.py +++ b/tests/unit_tests/test_replay_buffer.py @@ -124,3 +124,16 @@ async def test_sample_dp_size(self) -> None: assert len(dp_samples) == 2 # batch_size replay_buffer.clear.call_one().get() + + @pytest.mark.asyncio + async def test_collect(self) -> None: + """Test _collect method""" + local_rb = ReplayBuffer(batch_size=1) + await local_rb.setup._method(local_rb) + for i in range(1, 6): + local_rb.buffer.append(i) + values = local_rb._collect([2, 0, -1]) + assert values == [3, 1, 5] + values = local_rb._collect([1, 3]) + assert values == [2, 4] + assert local_rb.buffer[0] == 1