diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index a92fd5501..b40329d5f 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -6,7 +6,9 @@ import logging import random +from collections import deque from dataclasses import dataclass +from operator import itemgetter from typing import Any, Callable from monarch.actor import endpoint @@ -19,43 +21,68 @@ logger.setLevel(logging.INFO) +@dataclass +class BufferEntry: + data: "Episode" + sample_count: int = 0 + + +def age_evict( + buffer: deque, policy_version: int, max_samples: int = None, max_age: int = None +) -> list[int]: + """Buffer eviction policy, remove old or over-sampled entries""" + indices = [] + for i, entry in enumerate(buffer): + if max_age and policy_version - entry.data.policy_version > max_age: + continue + if max_samples and entry.sample_count >= max_samples: + continue + indices.append(i) + return indices + + +def random_sample(buffer: deque, sample_size: int, policy_version: int) -> list[int]: + """Buffer random sampling policy""" + if sample_size > len(buffer): + return None + return random.sample(range(len(buffer)), k=sample_size) + + @dataclass class ReplayBuffer(ForgeActor): """Simple in-memory replay buffer implementation.""" batch_size: int - max_policy_age: int dp_size: int = 1 + max_policy_age: int | None = None + max_buffer_size: int | None = None + max_resample_count: int | None = 0 seed: int | None = None collate: Callable = lambda batch: batch - - def __post_init__(self): - super().__init__() + eviction_policy: Callable = age_evict + sample_policy: Callable = random_sample @endpoint async def setup(self) -> None: - self.buffer: list = [] + self.buffer: deque = deque(maxlen=self.max_buffer_size) if self.seed is None: self.seed = random.randint(0, 2**32) random.seed(self.seed) - self.sampler = random.sample @endpoint async def add(self, episode: "Episode") -> None: - self.buffer.append(episode) + self.buffer.append(BufferEntry(episode)) record_metric("buffer/add/count_episodes_added", 1, Reduce.SUM) @endpoint @trace("buffer_perf/sample", track_memory=False) async def sample( - self, curr_policy_version: int, batch_size: int | None = None + self, curr_policy_version: int ) -> tuple[tuple[Any, ...], ...] | None: """Sample from the replay buffer. Args: curr_policy_version (int): The current policy version. - batch_size (int, optional): Number of episodes to sample. If none, defaults to batch size - passed in at initialization. Returns: A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer. @@ -63,45 +90,39 @@ async def sample( # Record sample request metric record_metric("buffer/sample/count_sample_requests", 1, Reduce.SUM) - bsz = batch_size if batch_size is not None else self.batch_size - total_samples = self.dp_size * bsz + total_samples = self.dp_size * self.batch_size - # Evict old episodes + # Evict episodes self._evict(curr_policy_version) - if total_samples > len(self.buffer): - return None - - # Calculate buffer utilization - utilization_pct = ( - (total_samples / len(self.buffer)) * 100 if len(self.buffer) > 0 else 0 - ) - - record_metric( - "buffer/sample/avg_buffer_utilization", - len(self.buffer), - Reduce.MEAN, - ) - - record_metric( - "buffer/sample/avg_buffer_utilization_pct", - utilization_pct, - Reduce.MEAN, - ) + # Calculate metrics + if len(self.buffer) > 0: + record_metric( + "buffer/sample/avg_data_utilization", + total_samples / len(self.buffer), + Reduce.MEAN, + ) + if self.max_buffer_size: + record_metric( + "buffer/sample/avg_buffer_utilization", + len(self.buffer) / self.max_buffer_size, + Reduce.MEAN, + ) # TODO: prefetch samples in advance - idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples) - # Pop episodes in descending order to avoid shifting issues - popped = [self.buffer.pop(i) for i in sorted(idx_to_sample, reverse=True)] - - # Reorder popped episodes to match the original random sample order - sorted_idxs = sorted(idx_to_sample, reverse=True) - idx_to_popped = dict(zip(sorted_idxs, popped)) - sampled_episodes = [idx_to_popped[i] for i in idx_to_sample] + sampled_indices = self.sample_policy( + self.buffer, total_samples, curr_policy_version + ) + if sampled_indices is None: + return None + sampled_episodes = [] + for entry in self._collect(sampled_indices): + entry.sample_count += 1 + sampled_episodes.append(entry.data) # Reshape into (dp_size, bsz, ...) reshaped_episodes = [ - sampled_episodes[dp_idx * bsz : (dp_idx + 1) * bsz] + sampled_episodes[dp_idx * self.batch_size : (dp_idx + 1) * self.batch_size] for dp_idx in range(self.dp_size) ] @@ -117,46 +138,69 @@ async def evict(self, curr_policy_version: int) -> None: """ self._evict(curr_policy_version) - def _evict(self, curr_policy_version: int) -> None: + def _evict(self, curr_policy_version): buffer_len_before_evict = len(self.buffer) - self.buffer = [ - trajectory - for trajectory in self.buffer - if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age - ] - buffer_len_after_evict = len(self.buffer) + indices = self.eviction_policy( + self.buffer, + curr_policy_version, + self.max_resample_count + 1, + self.max_policy_age, + ) + self.buffer = deque(self._collect(indices)) # Record evict metrics - policy_staleness = [ - curr_policy_version - ep.policy_version for ep in self.buffer + policy_age = [ + curr_policy_version - ep.data.policy_version for ep in self.buffer ] - if policy_staleness: + if policy_age: record_metric( - "buffer/evict/avg_policy_staleness", - sum(policy_staleness) / len(policy_staleness), + "buffer/evict/avg_policy_age", + sum(policy_age) / len(policy_age), Reduce.MEAN, ) record_metric( - "buffer/evict/max_policy_staleness", - max(policy_staleness), + "buffer/evict/max_policy_age", + max(policy_age), Reduce.MAX, ) - # Record eviction metrics - evicted_count = buffer_len_before_evict - buffer_len_after_evict - if evicted_count > 0: - record_metric( - "buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM - ) + evicted_count = buffer_len_before_evict - len(self.buffer) + record_metric("buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM) logger.debug( f"maximum policy age: {self.max_policy_age}, current policy version: {curr_policy_version}, " - f"{evicted_count} episodes expired, {buffer_len_after_evict} episodes left" + f"{evicted_count} episodes expired, {len(self.buffer)} episodes left" ) + def _collect(self, indices: list[int]): + """Efficiently traverse deque and collect elements at each requested index""" + n = len(self.buffer) + if n == 0 or len(indices) == 0: + return [] + + # Normalize indices and store with their original order + indexed = [(pos, idx % n) for pos, idx in enumerate(indices)] + indexed.sort(key=itemgetter(1)) + + result = [None] * len(indices) + rotations = 0 # logical current index + total_rotation = 0 # total net rotation applied + + for orig_pos, idx in indexed: + move = idx - rotations + self.buffer.rotate(-move) + total_rotation += move + rotations = idx + result[orig_pos] = self.buffer[0] + + # Restore original deque orientation + self.buffer.rotate(total_rotation) + + return result + @endpoint async def _getitem(self, idx: int): - return self.buffer[idx] + return self.buffer[idx].data @endpoint async def _numel(self) -> int: diff --git a/tests/unit_tests/test_replay_buffer.py b/tests/unit_tests/test_replay_buffer.py index 05d415398..e6c6876c3 100644 --- a/tests/unit_tests/test_replay_buffer.py +++ b/tests/unit_tests/test_replay_buffer.py @@ -70,24 +70,15 @@ async def test_sample(self, replay_buffer) -> None: await replay_buffer.add.call_one(trajectory_1) assert replay_buffer._numel.call_one().get() == 2 - # Test a simple sampling w/ no evictions + # Test a simple sampling samples = await replay_buffer.sample.call_one(curr_policy_version=1) assert samples is not None assert len(samples[0]) == 2 + assert replay_buffer._numel.call_one().get() == 2 - # Test sampling with overriding batch size - await replay_buffer.add.call_one(trajectory_0) - samples = await replay_buffer.sample.call_one( - curr_policy_version=1, batch_size=1 - ) - assert samples is not None - assert len(samples[0]) == 1 - - # Test sampling w/ overriding batch size (not enough samples in buffer, returns None) + # Test sampling (not enough samples in buffer, returns None) await replay_buffer.add.call_one(trajectory_0) - samples = await replay_buffer.sample.call_one( - curr_policy_version=1, batch_size=3 - ) + samples = await replay_buffer.sample.call_one(curr_policy_version=1) assert samples is None replay_buffer.clear.call_one().get() @@ -95,15 +86,19 @@ async def test_sample(self, replay_buffer) -> None: async def test_sample_with_evictions(self, replay_buffer) -> None: trajectory_0 = Trajectory(policy_version=0) trajectory_1 = Trajectory(policy_version=1) + trajectory_2 = Trajectory(policy_version=2) await replay_buffer.add.call_one(trajectory_0) await replay_buffer.add.call_one(trajectory_1) - assert replay_buffer._numel.call_one().get() == 2 + await replay_buffer.add.call_one(trajectory_2) + assert replay_buffer._numel.call_one().get() == 3 samples = await replay_buffer.sample.call_one( - curr_policy_version=2, batch_size=1 + curr_policy_version=2, ) assert samples is not None - assert len(samples[0]) == 1 - assert samples[0][0] == trajectory_1 + assert len(samples[0]) == 2 + assert samples[0][0].policy_version > 0 + assert samples[0][1].policy_version > 0 + assert replay_buffer._numel.call_one().get() == 2 replay_buffer.clear.call_one().get() @pytest.mark.asyncio @@ -129,3 +124,16 @@ async def test_sample_dp_size(self) -> None: assert len(dp_samples) == 2 # batch_size replay_buffer.clear.call_one().get() + + @pytest.mark.asyncio + async def test_collect(self) -> None: + """Test _collect method""" + local_rb = ReplayBuffer(batch_size=1) + await local_rb.setup._method(local_rb) + for i in range(1, 6): + local_rb.buffer.append(i) + values = local_rb._collect([2, 0, -1]) + assert values == [3, 1, 5] + values = local_rb._collect([1, 3]) + assert values == [2, 4] + assert local_rb.buffer[0] == 1