-
Notifications
You must be signed in to change notification settings - Fork 16
Support dp_size in replay buffer #93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 10 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
494654b
add dp size in replay buffer
DNXie 88f0672
update metric name
DNXie 4647052
fix test
DNXie f55b4a4
fix replay buffer tests
DNXie 98649c1
fix inconsistencies
DNXie 0b008ac
Merge branch 'main' into replay_buffer_dp_size
DNXie 4a64f9a
update config.
DNXie 3cb5d23
fix lint
DNXie 1371400
updated sampling logic to not return sorted samples
DNXie 9357cab
fix lint
DNXie 16a1b97
make sampling more efficient
DNXie 7605132
add test case
DNXie 2a4a7ac
add dprank to trainer
DNXie File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,8 +17,9 @@ | |
class ReplayBuffer(ForgeActor): | ||
"""Simple in-memory replay buffer implementation.""" | ||
|
||
batch_size: int = 4 | ||
max_policy_age: int = 0 | ||
batch_size: int | ||
max_policy_age: int | ||
dp_size: int = 1 | ||
seed: int | None = None | ||
|
||
@endpoint | ||
|
@@ -43,23 +44,31 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None): | |
passed in at initialization. | ||
|
||
Returns: | ||
A list of sampled episodes or None if there are not enough episodes in the buffer. | ||
A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer. | ||
""" | ||
bsz = batch_size if batch_size is not None else self.batch_size | ||
total_samples = self.dp_size * bsz | ||
|
||
# Evict old episodes | ||
self._evict(curr_policy_version) | ||
|
||
if bsz > len(self.buffer): | ||
if total_samples > len(self.buffer): | ||
return None | ||
|
||
# TODO: Make this more efficient | ||
idx_to_sample = self.sampler(range(len(self.buffer)), k=bsz) | ||
sorted_idxs = sorted( | ||
idx_to_sample, reverse=True | ||
) # Sort in desc order to avoid shifting idxs | ||
sampled_episodes = [self.buffer.pop(i) for i in sorted_idxs] | ||
return sampled_episodes | ||
idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples) | ||
sampled_episodes = [self.buffer[i] for i in idx_to_sample] | ||
|
||
# Evict sampled episodes (descending order so pops are safe) | ||
for i in sorted(idx_to_sample, reverse=True): | ||
|
||
self.buffer.pop(i) | ||
|
||
# Reshape into (dp_size, bsz, ...) | ||
reshaped_episodes = [ | ||
sampled_episodes[dp_idx * bsz : (dp_idx + 1) * bsz] | ||
for dp_idx in range(self.dp_size) | ||
] | ||
return reshaped_episodes | ||
|
||
@endpoint | ||
async def evict(self, curr_policy_version: int) -> None: | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.