Skip to content

Commit 1371400

Browse files
committed
Updated sampling logic to return episodes in the order they were sampled instead of sorted by buffer index
1 parent 3cb5d23 commit 1371400

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

src/forge/actors/replay_buffer.py

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,10 @@
88
from dataclasses import dataclass
99
from typing import Any
1010

11-
from monarch.actor import endpoint
12-
1311
from forge.controller import ForgeActor
1412

13+
from monarch.actor import endpoint
14+
1515

1616
@dataclass
1717
class ReplayBuffer(ForgeActor):
@@ -57,18 +57,17 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None):
5757

5858
# TODO: Make this more efficient
5959
idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples)
60-
sorted_idxs = sorted(
61-
idx_to_sample, reverse=True
62-
) # Sort in desc order to avoid shifting idxs
63-
sampled_episodes = [self.buffer.pop(i) for i in sorted_idxs]
64-
65-
# Reshape to (dp_size, bsz, ...)
66-
reshaped_episodes = []
67-
for dp_idx in range(self.dp_size):
68-
start_idx = dp_idx * bsz
69-
end_idx = start_idx + bsz
70-
reshaped_episodes.append(sampled_episodes[start_idx:end_idx])
60+
sampled_episodes = [self.buffer[i] for i in idx_to_sample]
61+
62+
# Evict sampled episodes (descending order so pops are safe)
63+
for i in sorted(idx_to_sample, reverse=True):
64+
self.buffer.pop(i)
7165

66+
# Reshape into (dp_size, bsz, ...)
67+
reshaped_episodes = [
68+
sampled_episodes[dp_idx * bsz : (dp_idx + 1) * bsz]
69+
for dp_idx in range(self.dp_size)
70+
]
7271
return reshaped_episodes
7372

7473
@endpoint

0 commit comments

Comments
 (0)