12 changes: 12 additions & 0 deletions trinity/buffer/storage/queue.py
@@ -302,6 +302,8 @@ def __init__(self, config: StorageConfig) -> None:
self.logger = get_logger(f"queue_{config.name}", in_ray_actor=True)
self.config = config
self.capacity = config.capacity
self.staleness_limit = config.staleness_limit # Optional[int]
self.max_model_version = 0 # max model version that queue has seen so far
self.queue = QueueBuffer.get_queue(config)
st_config = deepcopy(config)
st_config.wrap_in_ray = False
@@ -351,6 +353,9 @@ async def put_batch(self, exp_list: List) -> None:
await self.queue.put(exp_list)
if self.writer is not None:
self.writer.write(exp_list)
for exp in exp_list:
if exp.info["model_version"] > self.max_model_version:
self.max_model_version = exp.info["model_version"]
Comment on lines +356 to +358
Contributor
medium

The loop to update max_model_version can be made more efficient. Currently, it might perform multiple assignments to self.max_model_version within a single batch. A more efficient approach is to find the maximum version within the batch first, and then update self.max_model_version only once if needed.

Suggested change
-        for exp in exp_list:
-            if exp.info["model_version"] > self.max_model_version:
-                self.max_model_version = exp.info["model_version"]
+        if exp_list:
+            max_version_in_batch = max(exp.info["model_version"] for exp in exp_list)
+            if max_version_in_batch > self.max_model_version:
+                self.max_model_version = max_version_in_batch
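The suggested single-pass update can be exercised with mock experience objects (here `SimpleNamespace` stands in for Trinity's real experience class, and the helper function is illustrative, not part of the codebase):

```python
from types import SimpleNamespace

def update_max_model_version(current_max: int, exp_list) -> int:
    """Return the new running max after seeing a batch.

    Computes the batch max once, then does a single comparison,
    instead of comparing and assigning per experience.
    """
    if not exp_list:
        return current_max
    batch_max = max(exp.info["model_version"] for exp in exp_list)
    return max(current_max, batch_max)

batch = [SimpleNamespace(info={"model_version": v}) for v in (3, 7, 5)]
print(update_max_model_version(4, batch))  # 7
```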


async def get_batch(self, batch_size: int, timeout: float) -> List:
"""Get batch of experience."""
@@ -361,6 +366,13 @@
raise StopAsyncIteration("Queue is closed and no more items to get.")
try:
exp_list = await asyncio.wait_for(self.queue.get(), timeout=1.0)
if (self.staleness_limit is not None) and (self.staleness_limit > 0):
exp_list = [
exp
for exp in exp_list
if exp.info["model_version"]
>= self.max_model_version - self.staleness_limit
]
Comment on lines +369 to +375
Contributor

high

The current implementation for staleness control has a significant inefficiency when used with a priority queue that has sample reuse enabled (reuse_cooldown_time is not None). As noted in the PR description, stale samples are filtered out after they have been retrieved from the queue, but AsyncPriorityQueue will have already re-queued them. This means stale samples are never truly purged from the buffer, and the queue might fill up with them. This can lead to get_batch repeatedly fetching and discarding stale samples, potentially causing timeouts and degrading performance.

A better approach would be to prevent re-queuing of stale samples. This logic could be moved into AsyncPriorityQueue so that it can check for staleness before re-queuing an item. This would likely require QueueStorage to provide the max_model_version to the queue instance whenever it's updated. While this is a WIP, this is a critical design point to address for the feature to be robust.

Collaborator Author

Makes sense; we need to update AsyncPriorityQueue to account for staleness control.
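A hedged sketch of that direction, using a simplified stand-in for AsyncPriorityQueue (all names, the reuse mechanism, and the `update_max_model_version` hook are assumptions for illustration, not Trinity's actual API). Staleness is checked both on `put` and before re-queuing, so stale samples are purged instead of cycling back into the heap:

```python
import heapq
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

@dataclass
class Experience:
    priority: float
    info: Dict[str, Any] = field(default_factory=dict)

class StalenessAwarePriorityQueue:
    """Priority queue with sample reuse that drops stale samples."""

    def __init__(self, staleness_limit: Optional[int] = None,
                 reuse_cooldown: float = 1.0) -> None:
        self.staleness_limit = staleness_limit
        self.reuse_cooldown = reuse_cooldown
        self.max_model_version = 0
        self._heap: List[Tuple[float, int, Experience]] = []
        self._counter = 0  # tie-breaker so Experience never gets compared

    def update_max_model_version(self, version: int) -> None:
        # Called by the storage layer whenever put_batch sees a newer version.
        if version > self.max_model_version:
            self.max_model_version = version

    def _is_stale(self, exp: Experience) -> bool:
        if self.staleness_limit is None or self.staleness_limit <= 0:
            return False
        return exp.info["model_version"] < self.max_model_version - self.staleness_limit

    def put(self, exp: Experience) -> None:
        if self._is_stale(exp):
            return  # purge instead of enqueue
        heapq.heappush(self._heap, (exp.priority, self._counter, exp))
        self._counter += 1

    def get(self) -> Experience:
        while self._heap:
            _, _, exp = heapq.heappop(self._heap)
            if self._is_stale(exp):
                continue  # lazily purge samples that went stale while queued
            # Re-queue at a lower priority to emulate reuse with cooldown;
            # the staleness check on put guards this path as well.
            self.put(Experience(exp.priority + self.reuse_cooldown, exp.info))
            return exp
        raise IndexError("queue is empty")
```

With this shape, `QueueStorage.put_batch` would push its running max into the queue via `update_max_model_version`, and `get_batch` no longer needs to filter after retrieval.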

self.exp_pool.extend(exp_list)
except asyncio.TimeoutError:
if time.time() - start_time > timeout:
1 change: 1 addition & 0 deletions trinity/common/config.py
@@ -159,6 +159,7 @@ class StorageConfig:

# used for StorageType.QUEUE
capacity: int = 10000
staleness_limit: Optional[int] = None
max_read_timeout: float = 1800
replay_buffer: Optional[ReplayBufferConfig] = field(default_factory=ReplayBufferConfig)

Expand Down