3 changes: 2 additions & 1 deletion apps/grpo/main.py
@@ -438,7 +438,7 @@ async def continuous_rollouts():
print(
f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
)
logger.log("reward/rollout", avg_reward, rollout_count)
logger.log("reward_per_rollout", avg_reward, rollout_count)

async def continuous_training():
training_step = 0
@@ -447,6 +447,7 @@ async def continuous_training():
if batch is None:
await asyncio.sleep(0.1)
else:
+                batch = batch[0]  # Hard coded because we are not doing data parallel
training_result = await trainer.train_step.choose(batch)
training_step += 1
if training_step % 10 == 0:
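The `batch = batch[0]` line above reflects the new return shape of `ReplayBuffer.sample`: one episode group per data-parallel rank. A minimal sketch of that consumer-side pattern, assuming handles with the same `.choose(...)` endpoint convention as the `trainer.train_step` call in this diff (the loop structure and names are illustrative, not the full GRPO app):

```python
import asyncio


async def continuous_training(trainer, replay_buffer, curr_policy_version: int = 0):
    """Illustrative consumer loop: sample() yields one group per DP rank."""
    training_step = 0
    while True:
        # Hypothetical call; the sample endpoint is assumed to follow the
        # same .choose(...) convention as trainer.train_step in the diff.
        batch = await replay_buffer.sample.choose(curr_policy_version=curr_policy_version)
        if batch is None:
            await asyncio.sleep(0.1)  # buffer does not have enough episodes yet
            continue
        # With dp_size == 1 there is exactly one group, so unwrap index 0.
        batch = batch[0]
        await trainer.train_step.choose(batch)
        training_step += 1
```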
1 change: 1 addition & 0 deletions apps/grpo/qwen3_1_7b.yaml
@@ -47,6 +47,7 @@ trainer:
replay_buffer:
batch_size: ${batch_size}
max_policy_age: 0
+  dp_size: 1
service:
procs_per_replica: 1
num_replicas: 1
2 changes: 1 addition & 1 deletion apps/toy_rl/main.py
@@ -208,7 +208,7 @@ async def replay_buffer_sampler_task():
) # Update with true policy version when available
)
if trajectory is not None:
-            trajectories += trajectory
+            trajectories += trajectory[0]

# Most of the rest of this is just boilerplate for pretty printing.
if not trajectories:
29 changes: 19 additions & 10 deletions src/forge/actors/replay_buffer.py
@@ -17,8 +17,9 @@
class ReplayBuffer(ForgeActor):
"""Simple in-memory replay buffer implementation."""

-    batch_size: int = 4
-    max_policy_age: int = 0
+    batch_size: int
+    max_policy_age: int
+    dp_size: int = 1
seed: int | None = None

@endpoint
@@ -43,23 +44,31 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
passed in at initialization.

Returns:
-            A list of sampled episodes or None if there are not enough episodes in the buffer.
+            A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer.
"""
bsz = batch_size if batch_size is not None else self.batch_size
+        total_samples = self.dp_size * bsz

# Evict old episodes
self._evict(curr_policy_version)

-        if bsz > len(self.buffer):
+        if total_samples > len(self.buffer):
return None

# TODO: Make this more efficient
-        idx_to_sample = self.sampler(range(len(self.buffer)), k=bsz)
-        sorted_idxs = sorted(
-            idx_to_sample, reverse=True
-        )  # Sort in desc order to avoid shifting idxs
-        sampled_episodes = [self.buffer.pop(i) for i in sorted_idxs]
-        return sampled_episodes
+        idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples)
+        sampled_episodes = [self.buffer[i] for i in idx_to_sample]

+        # Evict sampled episodes (descending order so pops are safe)
+        for i in sorted(idx_to_sample, reverse=True):
Review thread on this line:
Contributor: This is cleaner but is it moving the data twice? It's probably fine.
Member Author: I have updated the code to make it more efficient.
+            self.buffer.pop(i)

+        # Reshape into (dp_size, bsz, ...)
+        reshaped_episodes = [
+            sampled_episodes[dp_idx * bsz : (dp_idx + 1) * bsz]
+            for dp_idx in range(self.dp_size)
+        ]
+        return reshaped_episodes

@endpoint
async def evict(self, curr_policy_version: int) -> None:
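Putting the replay-buffer change together: `sample()` now draws `dp_size * batch_size` episodes, evicts exactly those episodes, and regroups them per data-parallel rank. The following is a self-contained sketch of that sampling path with the actor/endpoint plumbing and the policy-age eviction stripped out (plain Python, integers standing in for episodes); it is an illustration of the logic in this diff, not the actual `ForgeActor` implementation:

```python
import random
from dataclasses import dataclass, field


@dataclass
class SimpleReplayBuffer:
    """Plain-Python sketch of the dp-aware sampling introduced in this PR."""

    batch_size: int
    max_policy_age: int
    dp_size: int = 1
    seed: int | None = None
    buffer: list = field(default_factory=list)

    def __post_init__(self) -> None:
        self.sampler = random.Random(self.seed).sample

    def add(self, episode) -> None:
        self.buffer.append(episode)

    def sample(self, curr_policy_version: int, batch_size: int | None = None):
        # Age-based eviction via curr_policy_version is omitted in this sketch.
        bsz = batch_size if batch_size is not None else self.batch_size
        total_samples = self.dp_size * bsz
        if total_samples > len(self.buffer):
            return None

        # Draw indices without replacement and copy the episodes out first...
        idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples)
        sampled_episodes = [self.buffer[i] for i in idx_to_sample]

        # ...then evict them, popping in descending index order so earlier
        # pops do not shift the positions of indices still to be popped.
        for i in sorted(idx_to_sample, reverse=True):
            self.buffer.pop(i)

        # Regroup into one batch per data-parallel rank: shape (dp_size, bsz, ...).
        return [
            sampled_episodes[dp_idx * bsz : (dp_idx + 1) * bsz]
            for dp_idx in range(self.dp_size)
        ]


buf = SimpleReplayBuffer(batch_size=2, max_policy_age=0, dp_size=2, seed=0)
for episode in range(8):
    buf.add(episode)

groups = buf.sample(curr_policy_version=0)
assert groups is not None
assert len(groups) == 2 and all(len(g) == 2 for g in groups)  # (dp_size, bsz)
```

With `dp_size: 1` the result is a single group, which is why callers such as `apps/grpo/main.py` and the updated tests index `[0]` before using the batch.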
1 change: 1 addition & 0 deletions tests/unit_tests/rl/test_toy_rl.py
@@ -255,6 +255,7 @@ async def replay_buffer_sampler_task():
42 # curr_policy_version
)
if sampled_trajectory is not None:
+                sampled_trajectory = sampled_trajectory[0]
sampled_trajectories.append(sampled_trajectory[0])
samples_collected += 1

13 changes: 9 additions & 4 deletions tests/unit_tests/test_replay_buffer.py
@@ -76,15 +76,15 @@ async def test_sample(self, replay_buffer) -> None:
# Test a simple sampling w/ no evictions
samples = await replay_buffer.sample.call_one(curr_policy_version=1)
assert samples is not None
-        assert len(samples) == 2
+        assert len(samples[0]) == 2

# Test sampling with overriding batch size
await replay_buffer.add.call_one(trajectory_0)
samples = await replay_buffer.sample.call_one(
curr_policy_version=1, batch_size=1
)
assert samples is not None
-        assert len(samples) == 1
+        assert len(samples[0]) == 1

# Test sampling w/ overriding batch size (not enough samples in buffer, returns None)
await replay_buffer.add.call_one(trajectory_0)
@@ -105,6 +105,11 @@ async def test_sample_with_evictions(self, replay_buffer) -> None:
curr_policy_version=2, batch_size=1
)
assert samples is not None
-        assert len(samples) == 1
-        assert samples[0] == trajectory_1
+        assert len(samples[0]) == 1
+        assert samples[0][0] == trajectory_1
+        replay_buffer.clear.call_one().get()


if __name__ == "__main__":
# Run tests with pytest
pytest.main([__file__, "-v"])