4 changes: 3 additions & 1 deletion apps/grpo/main.py
@@ -150,6 +150,7 @@ class Trainer(ForgeActor):
     beta: float = 0.1
     epsilon: float = 0.1
     device: torch.device | None = None
+    dp_rank: int = 0  # TODO: support data parallelism, hard code it for now

     @endpoint
     def setup(self):
@@ -178,6 +179,7 @@ def setup(self):

     @endpoint
     async def train_step(self, batch: list[Episode]):
+        batch = batch[self.dp_rank]
         pad_id = batch[0].pad_id

         # prepare batch
@@ -438,7 +440,7 @@ async def continuous_rollouts():
             print(
                 f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
             )
-            logger.log("reward/rollout", avg_reward, rollout_count)
+            logger.log("reward_per_rollout", avg_reward, rollout_count)

     async def continuous_training():
         training_step = 0
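The trainer-side change pairs with the new buffer output: `train_step` now receives the whole (dp_size, bsz) nested batch and keeps only the slice for its own data-parallel rank, which is hard-coded to 0 for now. A minimal sketch of that contract, with ints standing in for `Episode` objects (names here are illustrative, not repo code):

# Sketch of the consumer side: the batch arrives pre-sharded by dp rank.
dp_size, bsz = 2, 4
batch = [[rank * bsz + i for i in range(bsz)] for rank in range(dp_size)]  # (dp_size, bsz)

def train_step(batch, dp_rank=0):
    shard = batch[dp_rank]   # each rank trains only on its own slice
    assert len(shard) == bsz
    return sum(shard)        # placeholder for the real loss computation

train_step(batch, dp_rank=0)  # with dp_rank hard-coded to 0, only the first shard is used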
1 change: 1 addition & 0 deletions apps/grpo/qwen3_1_7b.yaml
@@ -47,6 +47,7 @@ trainer:
 replay_buffer:
   batch_size: ${batch_size}
   max_policy_age: 0
+  dp_size: 1
   service:
     procs_per_replica: 1
     num_replicas: 1
2 changes: 1 addition & 1 deletion apps/toy_rl/main.py
@@ -208,7 +208,7 @@ async def replay_buffer_sampler_task():
                 )  # Update with true policy version when available
             )
             if trajectory is not None:
-                trajectories += trajectory
+                trajectories += trajectory[0]

             # Most of the rest of this is just boilerplate for pretty printing.
             if not trajectories:
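Because the buffer now always nests its output by data-parallel group, single-group consumers such as this toy app (and the matching unit test below) unwrap the first and only group. A minimal sketch of that shape, with strings standing in for trajectories (illustrative only, not repo code):

# With dp_size == 1 (the default), sample() returns [[t_0, ..., t_{bsz-1}]],
# so callers take group 0 before extending their local list.
sampled = [["traj_a", "traj_b"]]   # what a dp_size=1 buffer would hand back
trajectories = []
trajectories += sampled[0]         # mirrors `trajectories += trajectory[0]` above
assert trajectories == ["traj_a", "traj_b"]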
32 changes: 22 additions & 10 deletions src/forge/actors/replay_buffer.py
@@ -17,8 +17,9 @@
 class ReplayBuffer(ForgeActor):
     """Simple in-memory replay buffer implementation."""

-    batch_size: int = 4
-    max_policy_age: int = 0
+    batch_size: int
+    max_policy_age: int
+    dp_size: int = 1
     seed: int | None = None

     @endpoint
@@ -43,23 +44,34 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
                 passed in at initialization.

         Returns:
-            A list of sampled episodes or None if there are not enough episodes in the buffer.
+            A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer.
         """
         bsz = batch_size if batch_size is not None else self.batch_size
+        total_samples = self.dp_size * bsz

         # Evict old episodes
         self._evict(curr_policy_version)

-        if bsz > len(self.buffer):
+        if total_samples > len(self.buffer):
             return None

         # TODO: Make this more efficient
-        idx_to_sample = self.sampler(range(len(self.buffer)), k=bsz)
-        sorted_idxs = sorted(
-            idx_to_sample, reverse=True
-        )  # Sort in desc order to avoid shifting idxs
-        sampled_episodes = [self.buffer.pop(i) for i in sorted_idxs]
-        return sampled_episodes
+        idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples)
+        # Pop episodes in descending order to avoid shifting issues
+        popped = [self.buffer.pop(i) for i in sorted(idx_to_sample, reverse=True)]
+
+        # Reorder popped episodes to match the original random sample order
+        sorted_idxs = sorted(idx_to_sample, reverse=True)
+        idx_to_popped = dict(zip(sorted_idxs, popped))
+        sampled_episodes = [idx_to_popped[i] for i in idx_to_sample]
+
+        # Reshape into (dp_size, bsz, ...)
+        reshaped_episodes = [
+            sampled_episodes[dp_idx * bsz : (dp_idx + 1) * bsz]
+            for dp_idx in range(self.dp_size)
+        ]
+
+        return reshaped_episodes

     @endpoint
     async def evict(self, curr_policy_version: int) -> None:
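The new sampling path does three things: it pops the chosen indices in descending order so earlier pops do not shift later indices, restores the original random sample order, and then groups the result into dp_size chunks of bsz. A standalone sketch of that logic, with ints standing in for episodes and `random.sample` standing in for `self.sampler` (illustrative only, not repo code):

import random

def sample_grouped(buffer, bsz, dp_size, sampler=random.sample):
    """Mimics the new ReplayBuffer.sample return shape: (dp_size, bsz)."""
    total = dp_size * bsz
    if total > len(buffer):
        return None
    idx_to_sample = sampler(range(len(buffer)), k=total)
    # Pop from the highest index down so earlier pops don't shift later ones.
    popped = [buffer.pop(i) for i in sorted(idx_to_sample, reverse=True)]
    # Restore the original (random) sample order.
    idx_to_popped = dict(zip(sorted(idx_to_sample, reverse=True), popped))
    ordered = [idx_to_popped[i] for i in idx_to_sample]
    # Group into dp_size chunks of bsz.
    return [ordered[r * bsz : (r + 1) * bsz] for r in range(dp_size)]

buf = list(range(10))
groups = sample_grouped(buf, bsz=2, dp_size=3)
assert len(groups) == 3 and all(len(g) == 2 for g in groups)
assert len(buf) == 10 - 6  # sampled episodes are removed from the buffer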
1 change: 1 addition & 0 deletions tests/unit_tests/rl/test_toy_rl.py
@@ -255,6 +255,7 @@ async def replay_buffer_sampler_task():
                     42  # curr_policy_version
                 )
             )
             if sampled_trajectory is not None:
+                sampled_trajectory = sampled_trajectory[0]
                 sampled_trajectories.append(sampled_trajectory[0])
                 samples_collected += 1

33 changes: 29 additions & 4 deletions tests/unit_tests/test_replay_buffer.py
@@ -76,15 +76,15 @@ async def test_sample(self, replay_buffer) -> None:
         # Test a simple sampling w/ no evictions
         samples = await replay_buffer.sample.call_one(curr_policy_version=1)
         assert samples is not None
-        assert len(samples) == 2
+        assert len(samples[0]) == 2

         # Test sampling with overriding batch size
         await replay_buffer.add.call_one(trajectory_0)
         samples = await replay_buffer.sample.call_one(
             curr_policy_version=1, batch_size=1
         )
         assert samples is not None
-        assert len(samples) == 1
+        assert len(samples[0]) == 1

         # Test sampling w/ overriding batch size (not enough samples in buffer, returns None)
         await replay_buffer.add.call_one(trajectory_0)
@@ -105,6 +105,31 @@ async def test_sample_with_evictions(self, replay_buffer) -> None:
             curr_policy_version=2, batch_size=1
         )
         assert samples is not None
-        assert len(samples) == 1
-        assert samples[0] == trajectory_1
+        assert len(samples[0]) == 1
+        assert samples[0][0] == trajectory_1
         replay_buffer.clear.call_one().get()
+
+    @pytest.mark.asyncio
+    async def test_sample_dp_size(self) -> None:
+        """Test that len(samples) == dp_size when sampling."""
+        mesh = await proc_mesh(gpus=1)
+        # Create replay buffer with dp_size=3
+        replay_buffer = await mesh.spawn(
+            "replay_buffer", ReplayBuffer, batch_size=2, max_policy_age=1, dp_size=3
+        )
+        await replay_buffer.setup.call()
+
+        # Add enough trajectories to sample
+        for i in range(10):
+            trajectory = Trajectory(policy_version=0)
+            await replay_buffer.add.call_one(trajectory)
+
+        # Sample and verify len(samples) == dp_size
+        samples = await replay_buffer.sample.call_one(curr_policy_version=0)
+        assert samples is not None
+        assert len(samples) == 3  # dp_size
+        # Each sub-list should have batch_size samples
+        for dp_samples in samples:
+            assert len(dp_samples) == 2  # batch_size
+
+        replay_buffer.clear.call_one().get()