3 changes: 2 additions & 1 deletion apps/grpo/main.py
@@ -438,7 +438,7 @@ async def continuous_rollouts():
print(
f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
)
logger.log("reward/rollout", avg_reward, rollout_count)
logger.log("reward_per_rollout", avg_reward, rollout_count)

async def continuous_training():
training_step = 0
@@ -447,6 +447,7 @@ async def continuous_training():
if batch is None:
await asyncio.sleep(0.1)
else:
+ batch = batch[0] # Hard coded because we are not doing data parallel
training_result = await trainer.train_step.choose(batch)
training_step += 1
if training_step % 10 == 0:
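Note on the change above: sample() on the replay buffer now returns episodes nested as (dp_size, bsz, ...), so the training loop unwraps batch[0] to take the single rank's batch when data parallelism is not used. A minimal sketch of that consumer side, assuming dp_size = 1; the replay_buffer.sample.choose call and the train_once wrapper are illustrative, not the app's exact code:

    import asyncio

    async def train_once(replay_buffer, trainer, policy_version: int):
        # sample() returns a list shaped (dp_size, bsz, ...) or None when the
        # buffer holds fewer than dp_size * bsz episodes.
        batch = await replay_buffer.sample.choose(curr_policy_version=policy_version)
        if batch is None:
            await asyncio.sleep(0.1)
            return None
        # With dp_size == 1 there is a single data-parallel rank, so its
        # batch of bsz episodes is batch[0].
        return await trainer.train_step.choose(batch[0])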
1 change: 1 addition & 0 deletions apps/grpo/qwen3_1_7b.yaml
@@ -47,6 +47,7 @@ trainer:
replay_buffer:
batch_size: ${batch_size}
max_policy_age: 0
+ dp_size: 1
service:
procs_per_replica: 1
num_replicas: 1
2 changes: 1 addition & 1 deletion apps/toy_rl/main.py
@@ -208,7 +208,7 @@ async def replay_buffer_sampler_task():
) # Update with true policy version when available
)
if trajectory is not None:
- trajectories += trajectory
+ trajectories += trajectory[0]

# Most of the rest of this is just boilerplate for pretty printing.
if not trajectories:
22 changes: 16 additions & 6 deletions src/forge/actors/replay_buffer.py
@@ -17,8 +17,9 @@
class ReplayBuffer(ForgeActor):
"""Simple in-memory replay buffer implementation."""

- batch_size: int = 4
- max_policy_age: int = 0
+ batch_size: int
+ max_policy_age: int
+ dp_size: int = 1
seed: int | None = None

@endpoint
@@ -43,23 +44,32 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
passed in at initialization.

Returns:
- A list of sampled episodes or None if there are not enough episodes in the buffer.
+ A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer.
"""
bsz = batch_size if batch_size is not None else self.batch_size
+ total_samples = self.dp_size * bsz

# Evict old episodes
self._evict(curr_policy_version)

- if bsz > len(self.buffer):
+ if total_samples > len(self.buffer):
return None

# TODO: Make this more efficient
- idx_to_sample = self.sampler(range(len(self.buffer)), k=bsz)
+ idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples)
sorted_idxs = sorted(
idx_to_sample, reverse=True
) # Sort in desc order to avoid shifting idxs
sampled_episodes = [self.buffer.pop(i) for i in sorted_idxs]
- return sampled_episodes

+ # Reshape to (dp_size, bsz, ...)
+ reshaped_episodes = []
+ for dp_idx in range(self.dp_size):
Contributor:

We don't want to return a sorted sample here, as that reduces variability in the sample. You need to map each position back through the index order used for sorting, and this is probably easier to read as a nested for loop:

    batch = []
    for rank in range(self.dp_size):
        local_batch = []
        for i in range(bsz):
            e = sampled_episodes[sort_order[rank * bsz + i]]
            local_batch.append(e)
        batch.append(local_batch)

Member Author:

Thanks for pointing out this issue. I have updated this part. Please review.
+     start_idx = dp_idx * bsz
+     end_idx = start_idx + bsz
+     reshaped_episodes.append(sampled_episodes[start_idx:end_idx])
+
+ return reshaped_episodes

@endpoint
async def evict(self, curr_policy_version: int) -> None:
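To make the new return contract concrete, a small sketch run inside an async task (the dp_size/batch_size values and the dispatch_to_rank helper are invented for illustration; the .call_one endpoint style mirrors the unit tests below):

    # Suppose the ReplayBuffer was created with dp_size=2 and batch_size=3.
    samples = await replay_buffer.sample.call_one(curr_policy_version=0)
    if samples is not None:
        assert len(samples) == 2                          # one slice per data-parallel rank
        assert all(len(batch) == 3 for batch in samples)  # each slice holds bsz episodes
        for dp_rank, local_batch in enumerate(samples):
            dispatch_to_rank(dp_rank, local_batch)        # hypothetical per-rank hand-off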
1 change: 1 addition & 0 deletions tests/unit_tests/rl/test_toy_rl.py
@@ -255,6 +255,7 @@ async def replay_buffer_sampler_task():
42 # curr_policy_version
)
if sampled_trajectory is not None:
+ sampled_trajectory = sampled_trajectory[0]
sampled_trajectories.append(sampled_trajectory[0])
samples_collected += 1

13 changes: 9 additions & 4 deletions tests/unit_tests/test_replay_buffer.py
@@ -76,15 +76,15 @@ async def test_sample(self, replay_buffer) -> None:
# Test a simple sampling w/ no evictions
samples = await replay_buffer.sample.call_one(curr_policy_version=1)
assert samples is not None
- assert len(samples) == 2
+ assert len(samples[0]) == 2

# Test sampling with overriding batch size
await replay_buffer.add.call_one(trajectory_0)
samples = await replay_buffer.sample.call_one(
curr_policy_version=1, batch_size=1
)
assert samples is not None
- assert len(samples) == 1
+ assert len(samples[0]) == 1

# Test sampling w/ overriding batch size (not enough samples in buffer, returns None)
await replay_buffer.add.call_one(trajectory_0)
@@ -105,6 +105,11 @@ async def test_sample_with_evictions(self, replay_buffer) -> None:
curr_policy_version=2, batch_size=1
)
assert samples is not None
- assert len(samples) == 1
- assert samples[0] == trajectory_1
+ assert len(samples[0]) == 1
+ assert samples[0][0] == trajectory_1
+ replay_buffer.clear.call_one().get()


if __name__ == "__main__":
# Run tests with pytest
pytest.main([__file__, "-v"])