|
8 | 8 | from dataclasses import dataclass |
9 | 9 | from typing import Any |
10 | 10 |
|
11 | | -from monarch.actor import endpoint |
12 | | - |
13 | 11 | from forge.controller import ForgeActor |
14 | 12 |
|
| 13 | +from monarch.actor import endpoint |
| 14 | + |
15 | 15 |
|
16 | 16 | @dataclass |
17 | 17 | class ReplayBuffer(ForgeActor): |
@@ -57,18 +57,17 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None): |
57 | 57 |
|
58 | 58 | # TODO: Make this more efficient |
59 | 59 | idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples) |
60 | | - sorted_idxs = sorted( |
61 | | - idx_to_sample, reverse=True |
62 | | - ) # Sort in desc order to avoid shifting idxs |
63 | | - sampled_episodes = [self.buffer.pop(i) for i in sorted_idxs] |
64 | | - |
65 | | - # Reshape to (dp_size, bsz, ...) |
66 | | - reshaped_episodes = [] |
67 | | - for dp_idx in range(self.dp_size): |
68 | | - start_idx = dp_idx * bsz |
69 | | - end_idx = start_idx + bsz |
70 | | - reshaped_episodes.append(sampled_episodes[start_idx:end_idx]) |
| 60 | + sampled_episodes = [self.buffer[i] for i in idx_to_sample] |
| 61 | + |
| 62 | + # Evict sampled episodes (descending order so pops are safe) |
| 63 | + for i in sorted(idx_to_sample, reverse=True): |
| 64 | + self.buffer.pop(i) |
71 | 65 |
|
| 66 | + # Reshape into (dp_size, bsz, ...) |
| 67 | + reshaped_episodes = [ |
| 68 | + sampled_episodes[dp_idx * bsz : (dp_idx + 1) * bsz] |
| 69 | + for dp_idx in range(self.dp_size) |
| 70 | + ] |
72 | 71 | return reshaped_episodes |
73 | 72 |
|
74 | 73 | @endpoint |
|
0 commit comments