Skip to content

Commit 494654b

Browse files
committed
add dp size in replay buffer
1 parent ddd0794 commit 494654b

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

apps/grpo/main.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,10 +377,7 @@ async def main():
377377
)
378378

379379
replay_buffer = await spawn_service(
380-
default_service_cfg,
381-
ReplayBuffer,
382-
batch_size=4,
383-
max_policy_age=1,
380+
default_service_cfg, ReplayBuffer, batch_size=4, max_policy_age=1, dp_size=1
384381
)
385382

386383
dataloader = await spawn_service(
@@ -469,6 +466,7 @@ async def continuous_training():
469466
if batch is None:
470467
await asyncio.sleep(0.1)
471468
else:
469+
batch = batch[0] # Hard coded because we are not doing data parallel
472470
training_result = await trainer.train_step.choose(batch)
473471
training_step += 1
474472
if training_step % 10 == 0:

src/forge/actors/replay_buffer.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,18 @@
88
from dataclasses import dataclass
99
from typing import Any
1010

11-
from monarch.actor import endpoint
12-
1311
from forge.controller import ForgeActor
1412

13+
from monarch.actor import endpoint
14+
1515

1616
@dataclass
1717
class ReplayBuffer(ForgeActor):
1818
"""Simple in-memory replay buffer implementation."""
1919

2020
batch_size: int
2121
max_policy_age: int
22+
dp_size: int = 1
2223
seed: int | None = None
2324

2425
@endpoint
@@ -43,23 +44,32 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
4344
passed in at initialization.
4445
4546
Returns:
46-
A list of sampled episodes or None if there are not enough episodes in the buffer.
47+
A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer.
4748
"""
4849
bsz = batch_size if batch_size is not None else self.batch_size
50+
total_samples = self.dp_size * bsz
4951

5052
# Evict old episodes
5153
self._evict(curr_policy_version)
5254

53-
if bsz > len(self.buffer):
55+
if total_samples > len(self.buffer):
5456
return None
5557

5658
# TODO: Make this more efficient
57-
idx_to_sample = self.sampler(range(len(self.buffer)), k=bsz)
59+
idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples)
5860
sorted_idxs = sorted(
5961
idx_to_sample, reverse=True
6062
) # Sort in desc order to avoid shifting idxs
6163
sampled_episodes = [self.buffer.pop(i) for i in sorted_idxs]
62-
return sampled_episodes
64+
65+
# Reshape to (dp_size, bsz, ...)
66+
reshaped_episodes = []
67+
for dp_idx in range(self.dp_size):
68+
start_idx = dp_idx * bsz
69+
end_idx = start_idx + bsz
70+
reshaped_episodes.append(sampled_episodes[start_idx:end_idx])
71+
72+
return reshaped_episodes
6373

6474
@endpoint
6575
async def evict(self, curr_policy_version: int) -> None:

0 commit comments

Comments
 (0)