88from dataclasses import dataclass
99from typing import Any
1010
11- from monarch .actor import endpoint
12-
1311from forge .controller import ForgeActor
1412
13+ from monarch .actor import endpoint
14+
1515
1616@dataclass
1717class ReplayBuffer (ForgeActor ):
1818 """Simple in-memory replay buffer implementation."""
1919
2020 batch_size : int
2121 max_policy_age : int
22+ dp_size : int = 1
2223 seed : int | None = None
2324
2425 @endpoint
@@ -43,23 +44,32 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
4344 passed in at initialization.
4445
4546 Returns:
46- A list of sampled episodes or None if there are not enough episodes in the buffer.
47+ A list of dp_size lists, each containing bsz sampled episodes, or None if the buffer holds fewer than dp_size * bsz episodes.
4748 """
4849 bsz = batch_size if batch_size is not None else self .batch_size
50+ total_samples = self .dp_size * bsz
4951
5052 # Evict old episodes
5153 self ._evict (curr_policy_version )
5254
53- if bsz > len (self .buffer ):
55+ if total_samples > len (self .buffer ):
5456 return None
5557
5658 # TODO: Make this more efficient
57- idx_to_sample = self .sampler (range (len (self .buffer )), k = bsz )
59+ idx_to_sample = self .sampler (range (len (self .buffer )), k = total_samples )
5860 sorted_idxs = sorted (
5961 idx_to_sample , reverse = True
6062 ) # Sort in desc order to avoid shifting idxs
6163 sampled_episodes = [self .buffer .pop (i ) for i in sorted_idxs ]
62- return sampled_episodes
64+
65+ # Split the flat sample into dp_size consecutive chunks of bsz episodes each
66+ reshaped_episodes = []
67+ for dp_idx in range (self .dp_size ):
68+ start_idx = dp_idx * bsz
69+ end_idx = start_idx + bsz
70+ reshaped_episodes .append (sampled_episodes [start_idx :end_idx ])
71+
72+ return reshaped_episodes
6373
6474 @endpoint
6575 async def evict (self , curr_policy_version : int ) -> None :
0 commit comments