diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py index 2d924d362e..f348590e48 100644 --- a/trinity/buffer/storage/queue.py +++ b/trinity/buffer/storage/queue.py @@ -302,6 +302,8 @@ def __init__(self, config: StorageConfig) -> None: self.logger = get_logger(f"queue_{config.name}", in_ray_actor=True) self.config = config self.capacity = config.capacity + self.staleness_limit = config.staleness_limit # Optional[int] + self.max_model_version = 0 # max model version that queue has seen so far self.queue = QueueBuffer.get_queue(config) st_config = deepcopy(config) st_config.wrap_in_ray = False @@ -351,6 +353,9 @@ async def put_batch(self, exp_list: List) -> None: await self.queue.put(exp_list) if self.writer is not None: self.writer.write(exp_list) + for exp in exp_list: + if exp.info["model_version"] > self.max_model_version: + self.max_model_version = exp.info["model_version"] async def get_batch(self, batch_size: int, timeout: float) -> List: """Get batch of experience.""" @@ -361,6 +366,13 @@ async def get_batch(self, batch_size: int, timeout: float) -> List: raise StopAsyncIteration("Queue is closed and no more items to get.") try: exp_list = await asyncio.wait_for(self.queue.get(), timeout=1.0) + if (self.staleness_limit is not None) and (self.staleness_limit > 0): + exp_list = [ + exp + for exp in exp_list + if exp.info["model_version"] + >= self.max_model_version - self.staleness_limit + ] self.exp_pool.extend(exp_list) except asyncio.TimeoutError: if time.time() - start_time > timeout: diff --git a/trinity/common/config.py b/trinity/common/config.py index c722959b96..83904336d0 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -159,6 +159,7 @@ class StorageConfig: # used for StorageType.QUEUE capacity: int = 10000 + staleness_limit: Optional[int] = None max_read_timeout: float = 1800 replay_buffer: Optional[ReplayBufferConfig] = field(default_factory=ReplayBufferConfig)