
Commit 7f35b26

pbontrager authored and allenwang28 committed
Configurable ReplayBuffer (#410)
1 parent be486ef commit 7f35b26

File tree

src/forge/actors/replay_buffer.py
tests/unit_tests/test_replay_buffer.py

2 files changed: +132 -80 lines changed

src/forge/actors/replay_buffer.py

Lines changed: 107 additions & 63 deletions
@@ -6,7 +6,9 @@
 
 import logging
 import random
+from collections import deque
 from dataclasses import dataclass
+from operator import itemgetter
 from typing import Any, Callable
 
 from monarch.actor import endpoint
@@ -19,89 +21,108 @@
 logger.setLevel(logging.INFO)
 
 
+@dataclass
+class BufferEntry:
+    data: "Episode"
+    sample_count: int = 0
+
+
+def age_evict(
+    buffer: deque, policy_version: int, max_samples: int = None, max_age: int = None
+) -> list[int]:
+    """Buffer eviction policy, remove old or over-sampled entries"""
+    indices = []
+    for i, entry in enumerate(buffer):
+        if max_age and policy_version - entry.data.policy_version > max_age:
+            continue
+        if max_samples and entry.sample_count >= max_samples:
+            continue
+        indices.append(i)
+    return indices
+
+
+def random_sample(buffer: deque, sample_size: int, policy_version: int) -> list[int]:
+    """Buffer random sampling policy"""
+    if sample_size > len(buffer):
+        return None
+    return random.sample(range(len(buffer)), k=sample_size)
+
+
 @dataclass
 class ReplayBuffer(ForgeActor):
     """Simple in-memory replay buffer implementation."""
 
     batch_size: int
-    max_policy_age: int
     dp_size: int = 1
+    max_policy_age: int | None = None
+    max_buffer_size: int | None = None
+    max_resample_count: int | None = 0
     seed: int | None = None
     collate: Callable = lambda batch: batch
-
-    def __post_init__(self):
-        super().__init__()
+    eviction_policy: Callable = age_evict
+    sample_policy: Callable = random_sample
 
     @endpoint
     async def setup(self) -> None:
-        self.buffer: list = []
+        self.buffer: deque = deque(maxlen=self.max_buffer_size)
         if self.seed is None:
             self.seed = random.randint(0, 2**32)
         random.seed(self.seed)
-        self.sampler = random.sample
 
     @endpoint
     async def add(self, episode: "Episode") -> None:
-        self.buffer.append(episode)
+        self.buffer.append(BufferEntry(episode))
         record_metric("buffer/add/count_episodes_added", 1, Reduce.SUM)
 
     @endpoint
     @trace("buffer_perf/sample", track_memory=False)
     async def sample(
-        self, curr_policy_version: int, batch_size: int | None = None
+        self, curr_policy_version: int
     ) -> tuple[tuple[Any, ...], ...] | None:
         """Sample from the replay buffer.
 
         Args:
             curr_policy_version (int): The current policy version.
-            batch_size (int, optional): Number of episodes to sample. If none, defaults to batch size
-                passed in at initialization.
 
         Returns:
             A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer.
         """
         # Record sample request metric
         record_metric("buffer/sample/count_sample_requests", 1, Reduce.SUM)
 
-        bsz = batch_size if batch_size is not None else self.batch_size
-        total_samples = self.dp_size * bsz
+        total_samples = self.dp_size * self.batch_size
 
-        # Evict old episodes
+        # Evict episodes
         self._evict(curr_policy_version)
 
-        if total_samples > len(self.buffer):
-            return None
-
-        # Calculate buffer utilization
-        utilization_pct = (
-            (total_samples / len(self.buffer)) * 100 if len(self.buffer) > 0 else 0
-        )
-
-        record_metric(
-            "buffer/sample/avg_buffer_utilization",
-            len(self.buffer),
-            Reduce.MEAN,
-        )
-
-        record_metric(
-            "buffer/sample/avg_buffer_utilization_pct",
-            utilization_pct,
-            Reduce.MEAN,
-        )
+        # Calculate metrics
+        if len(self.buffer) > 0:
+            record_metric(
+                "buffer/sample/avg_data_utilization",
+                total_samples / len(self.buffer),
+                Reduce.MEAN,
+            )
+        if self.max_buffer_size:
+            record_metric(
+                "buffer/sample/avg_buffer_utilization",
+                len(self.buffer) / self.max_buffer_size,
+                Reduce.MEAN,
+            )
 
         # TODO: prefetch samples in advance
-        idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples)
-        # Pop episodes in descending order to avoid shifting issues
-        popped = [self.buffer.pop(i) for i in sorted(idx_to_sample, reverse=True)]
-
-        # Reorder popped episodes to match the original random sample order
-        sorted_idxs = sorted(idx_to_sample, reverse=True)
-        idx_to_popped = dict(zip(sorted_idxs, popped))
-        sampled_episodes = [idx_to_popped[i] for i in idx_to_sample]
+        sampled_indices = self.sample_policy(
+            self.buffer, total_samples, curr_policy_version
+        )
+        if sampled_indices is None:
+            return None
+        sampled_episodes = []
+        for entry in self._collect(sampled_indices):
+            entry.sample_count += 1
+            sampled_episodes.append(entry.data)
 
         # Reshape into (dp_size, bsz, ...)
         reshaped_episodes = [
-            sampled_episodes[dp_idx * bsz : (dp_idx + 1) * bsz]
+            sampled_episodes[dp_idx * self.batch_size : (dp_idx + 1) * self.batch_size]
            for dp_idx in range(self.dp_size)
         ]
 
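The hunk above replaces the hard-coded pop-and-sample logic with two pluggable callables, `eviction_policy` and `sample_policy`, defaulting to `age_evict` and `random_sample`. As a minimal sketch of what that buys (the `newest_first` policy below is hypothetical, not part of this commit), any callable matching the `(buffer, sample_size, policy_version) -> list[int] | None` contract can be swapped in:

```python
from collections import deque


def newest_first(buffer: deque, sample_size: int, policy_version: int) -> list[int] | None:
    """Hypothetical sampling policy: prefer the most recently added entries.

    Same contract as random_sample: return indices into the buffer, or
    None when the buffer cannot fill the request.
    """
    if sample_size > len(buffer):
        return None
    return list(range(len(buffer) - sample_size, len(buffer)))


# Assumed wiring via the new dataclass field:
#   ReplayBuffer(batch_size=4, sample_policy=newest_first)
```

Policies only return indices; `sample()` resolves them through `_collect` (added further down) and leaves mutation of the deque to the buffer itself.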

@@ -118,46 +139,69 @@ async def evict(self, curr_policy_version: int) -> None:
         """
         self._evict(curr_policy_version)
 
-    def _evict(self, curr_policy_version: int) -> None:
+    def _evict(self, curr_policy_version):
         buffer_len_before_evict = len(self.buffer)
-        self.buffer = [
-            trajectory
-            for trajectory in self.buffer
-            if (curr_policy_version - trajectory.policy_version) <= self.max_policy_age
-        ]
-        buffer_len_after_evict = len(self.buffer)
+        indices = self.eviction_policy(
+            self.buffer,
+            curr_policy_version,
+            self.max_resample_count + 1,
+            self.max_policy_age,
+        )
+        self.buffer = deque(self._collect(indices))
 
         # Record evict metrics
-        policy_staleness = [
-            curr_policy_version - ep.policy_version for ep in self.buffer
+        policy_age = [
+            curr_policy_version - ep.data.policy_version for ep in self.buffer
         ]
-        if policy_staleness:
+        if policy_age:
             record_metric(
-                "buffer/evict/avg_policy_staleness",
-                sum(policy_staleness) / len(policy_staleness),
+                "buffer/evict/avg_policy_age",
+                sum(policy_age) / len(policy_age),
                 Reduce.MEAN,
             )
             record_metric(
-                "buffer/evict/max_policy_staleness",
-                max(policy_staleness),
+                "buffer/evict/max_policy_age",
+                max(policy_age),
                 Reduce.MAX,
             )
 
-        # Record eviction metrics
-        evicted_count = buffer_len_before_evict - buffer_len_after_evict
-        if evicted_count > 0:
-            record_metric(
-                "buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM
-            )
+        evicted_count = buffer_len_before_evict - len(self.buffer)
+        record_metric("buffer/evict/sum_episodes_evicted", evicted_count, Reduce.SUM)
 
         logger.debug(
             f"maximum policy age: {self.max_policy_age}, current policy version: {curr_policy_version}, "
-            f"{evicted_count} episodes expired, {buffer_len_after_evict} episodes left"
+            f"{evicted_count} episodes expired, {len(self.buffer)} episodes left"
         )
 
+    def _collect(self, indices: list[int]):
+        """Efficiently traverse deque and collect elements at each requested index"""
+        n = len(self.buffer)
+        if n == 0 or len(indices) == 0:
+            return []
+
+        # Normalize indices and store with their original order
+        indexed = [(pos, idx % n) for pos, idx in enumerate(indices)]
+        indexed.sort(key=itemgetter(1))
+
+        result = [None] * len(indices)
+        rotations = 0  # logical current index
+        total_rotation = 0  # total net rotation applied
+
+        for orig_pos, idx in indexed:
+            move = idx - rotations
+            self.buffer.rotate(-move)
+            total_rotation += move
+            rotations = idx
+            result[orig_pos] = self.buffer[0]
+
+        # Restore original deque orientation
+        self.buffer.rotate(total_rotation)
+
+        return result
+
     @endpoint
     async def _getitem(self, idx: int):
-        return self.buffer[idx]
+        return self.buffer[idx].data
 
     @endpoint
     async def _numel(self) -> int:
tests/unit_tests/test_replay_buffer.py

Lines changed: 25 additions & 17 deletions
@@ -70,40 +70,35 @@ async def test_sample(self, replay_buffer) -> None:
         await replay_buffer.add.call_one(trajectory_1)
         assert replay_buffer._numel.call_one().get() == 2
 
-        # Test a simple sampling w/ no evictions
+        # Test a simple sampling
         samples = await replay_buffer.sample.call_one(curr_policy_version=1)
         assert samples is not None
         assert len(samples[0]) == 2
+        assert replay_buffer._numel.call_one().get() == 2
 
-        # Test sampling with overriding batch size
-        await replay_buffer.add.call_one(trajectory_0)
-        samples = await replay_buffer.sample.call_one(
-            curr_policy_version=1, batch_size=1
-        )
-        assert samples is not None
-        assert len(samples[0]) == 1
-
-        # Test sampling w/ overriding batch size (not enough samples in buffer, returns None)
+        # Test sampling (not enough samples in buffer, returns None)
         await replay_buffer.add.call_one(trajectory_0)
-        samples = await replay_buffer.sample.call_one(
-            curr_policy_version=1, batch_size=3
-        )
+        samples = await replay_buffer.sample.call_one(curr_policy_version=1)
         assert samples is None
         replay_buffer.clear.call_one().get()
 
     @pytest.mark.asyncio
     async def test_sample_with_evictions(self, replay_buffer) -> None:
         trajectory_0 = Trajectory(policy_version=0)
         trajectory_1 = Trajectory(policy_version=1)
+        trajectory_2 = Trajectory(policy_version=2)
         await replay_buffer.add.call_one(trajectory_0)
         await replay_buffer.add.call_one(trajectory_1)
-        assert replay_buffer._numel.call_one().get() == 2
+        await replay_buffer.add.call_one(trajectory_2)
+        assert replay_buffer._numel.call_one().get() == 3
         samples = await replay_buffer.sample.call_one(
-            curr_policy_version=2, batch_size=1
+            curr_policy_version=2,
         )
         assert samples is not None
-        assert len(samples[0]) == 1
-        assert samples[0][0] == trajectory_1
+        assert len(samples[0]) == 2
+        assert samples[0][0].policy_version > 0
+        assert samples[0][1].policy_version > 0
+        assert replay_buffer._numel.call_one().get() == 2
         replay_buffer.clear.call_one().get()
 
     @pytest.mark.asyncio
@@ -129,3 +124,16 @@ async def test_sample_dp_size(self) -> None:
         assert len(dp_samples) == 2  # batch_size
 
         replay_buffer.clear.call_one().get()
+
+    @pytest.mark.asyncio
+    async def test_collect(self) -> None:
+        """Test _collect method"""
+        local_rb = ReplayBuffer(batch_size=1)
+        await local_rb.setup._method(local_rb)
+        for i in range(1, 6):
+            local_rb.buffer.append(i)
+        values = local_rb._collect([2, 0, -1])
+        assert values == [3, 1, 5]
+        values = local_rb._collect([1, 3])
+        assert values == [2, 4]
+        assert local_rb.buffer[0] == 1
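The new `test_collect` pins down `_collect`'s semantics: indices may repeat or be negative, results come back in request order, and the deque ends up unrotated. A self-contained sketch of the same rotation technique on a plain deque (a standalone `collect`, no forge dependencies) that reproduces the test's expectations:

```python
from collections import deque
from operator import itemgetter


def collect(buffer: deque, indices: list[int]) -> list:
    """Visit indices in sorted order by rotating the deque, then undo the rotation."""
    n = len(buffer)
    if n == 0 or not indices:
        return []
    # Normalize negative indices via modulo and remember each index's request position.
    indexed = sorted(((pos, idx % n) for pos, idx in enumerate(indices)), key=itemgetter(1))
    result = [None] * len(indices)
    current = 0  # index currently at the head of the deque
    total = 0    # net rotation applied so far
    for orig_pos, idx in indexed:
        move = idx - current
        buffer.rotate(-move)  # bring buffer[idx] to the front
        total += move
        current = idx
        result[orig_pos] = buffer[0]
    buffer.rotate(total)  # restore the original orientation
    return result


d = deque(range(1, 6))                      # deque([1, 2, 3, 4, 5])
assert collect(d, [2, 0, -1]) == [3, 1, 5]  # matches test_collect
assert collect(d, [1, 3]) == [2, 4]
assert d[0] == 1                            # deque restored
```

Visiting the requested indices in sorted order keeps each rotation short, rather than jumping back and forth across the deque.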

0 commit comments
