-
Notifications
You must be signed in to change notification settings - Fork 48
[WIP] A minimal implementation of staleness control #382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -302,6 +302,8 @@ def __init__(self, config: StorageConfig) -> None: | |
| self.logger = get_logger(f"queue_{config.name}", in_ray_actor=True) | ||
| self.config = config | ||
| self.capacity = config.capacity | ||
| self.staleness_limit = config.staleness_limit # Optional[int] | ||
| self.max_model_version = 0 # max model version that queue has seen so far | ||
| self.queue = QueueBuffer.get_queue(config) | ||
| st_config = deepcopy(config) | ||
| st_config.wrap_in_ray = False | ||
|
|
@@ -351,6 +353,9 @@ async def put_batch(self, exp_list: List) -> None: | |
| await self.queue.put(exp_list) | ||
| if self.writer is not None: | ||
| self.writer.write(exp_list) | ||
| for exp in exp_list: | ||
| if exp.info["model_version"] > self.max_model_version: | ||
| self.max_model_version = exp.info["model_version"] | ||
|
|
||
| async def get_batch(self, batch_size: int, timeout: float) -> List: | ||
| """Get batch of experience.""" | ||
|
|
@@ -361,6 +366,13 @@ async def get_batch(self, batch_size: int, timeout: float) -> List: | |
| raise StopAsyncIteration("Queue is closed and no more items to get.") | ||
| try: | ||
| exp_list = await asyncio.wait_for(self.queue.get(), timeout=1.0) | ||
| if (self.staleness_limit is not None) and (self.staleness_limit > 0): | ||
| exp_list = [ | ||
| exp | ||
| for exp in exp_list | ||
| if exp.info["model_version"] | ||
| >= self.max_model_version - self.staleness_limit | ||
| ] | ||
|
Comment on lines
+369
to
+375
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current implementation for staleness control has a significant inefficiency when used with a priority queue that has sample reuse enabled ( A better approach would be to prevent re-queuing of stale samples. This logic could be moved into
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make sense, need to update |
||
| self.exp_pool.extend(exp_list) | ||
| except asyncio.TimeoutError: | ||
| if time.time() - start_time > timeout: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The loop to update
max_model_versioncan be made more efficient. Currently, it might perform multiple assignments toself.max_model_versionwithin a single batch. A more efficient approach is to find the maximum version within the batch first, and then updateself.max_model_versiononly once if needed.