diff --git a/tests/buffer/task_scheduler_test.py b/tests/buffer/task_scheduler_test.py index ee25f5856e..bff905a3e1 100644 --- a/tests/buffer/task_scheduler_test.py +++ b/tests/buffer/task_scheduler_test.py @@ -41,6 +41,54 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in @parameterized.expand( [ ( + {"batch_size": 5, "total_steps": 3}, + {"selector_type": "sequential"}, + [ + {"index": 0, "taskset_id": 1}, + {"index": 1, "taskset_id": 1}, + {"index": 2, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 1, "taskset_id": 0}, + {"index": 3, "taskset_id": 1}, + {"index": 4, "taskset_id": 1}, + {"index": 5, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 3, "taskset_id": 0}, + {"index": 6, "taskset_id": 1}, + {"index": 0, "taskset_id": 1}, + {"index": 1, "taskset_id": 1}, + {"index": 4, "taskset_id": 0}, + {"index": 0, "taskset_id": 0}, + ], + ), + ( + {"batch_size": 5, "total_epochs": 2}, + {"selector_type": "sequential"}, + [ + {"index": 0, "taskset_id": 1}, + {"index": 1, "taskset_id": 1}, + {"index": 2, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 1, "taskset_id": 0}, + {"index": 3, "taskset_id": 1}, + {"index": 4, "taskset_id": 1}, + {"index": 5, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 3, "taskset_id": 0}, + {"index": 6, "taskset_id": 1}, + {"index": 0, "taskset_id": 1}, + {"index": 1, "taskset_id": 1}, + {"index": 4, "taskset_id": 0}, + {"index": 0, "taskset_id": 0}, + {"index": 2, "taskset_id": 1}, + {"index": 3, "taskset_id": 1}, + {"index": 4, "taskset_id": 1}, + {"index": 1, "taskset_id": 0}, + {"index": 2, "taskset_id": 0}, + ], + ), + ( + {"batch_size": 2, "total_epochs": 2}, {"selector_type": "sequential"}, [ {"index": 0, "taskset_id": 1}, @@ -70,6 +118,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in ], ), ( + {"batch_size": 2, "total_epochs": 2}, {"selector_type": "shuffle", "seed": 42}, [ {"index": 3, "taskset_id": 1}, @@ -99,6 +148,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in ], ), ( + {"batch_size": 2, "total_epochs": 2}, {"selector_type": "random", "seed": 42}, [ {"index": 0, "taskset_id": 1}, @@ -128,6 +178,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in ], ), ( + {"batch_size": 2, "total_epochs": 2}, {"selector_type": "offline_easy2hard", "feature_keys": ["feature_offline"]}, [ {"index": 3, "taskset_id": 1}, @@ -157,6 +208,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in ], ), ( + {"batch_size": 2, "total_epochs": 2}, {"selector_type": "difficulty_based", "feature_keys": ["feat_1", "feat_2"]}, [ {"index": 3, "taskset_id": 1}, @@ -187,10 +239,13 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in ), ] ) - async def test_task_scheduler(self, task_selector_kwargs, batch_tasks_orders) -> None: + async def test_task_scheduler( + self, buffer_config_kwargs, task_selector_kwargs, batch_tasks_orders + ) -> None: config = get_template_config() - config.buffer.batch_size = 2 - config.buffer.total_epochs = 2 + config.mode = "explore" + for key, value in buffer_config_kwargs.items(): + setattr(config.buffer, key, value) config.buffer.explorer_input.taskset = None config.buffer.explorer_input.tasksets = [ TasksetConfig( diff --git a/trinity/buffer/task_scheduler.py b/trinity/buffer/task_scheduler.py index 7619a3702f..01b6fa1a47 100644 --- a/trinity/buffer/task_scheduler.py +++ b/trinity/buffer/task_scheduler.py @@ -90,6 +90,13 @@ def __init__(self, explorer_state: Dict, config: Config): self.epoch = self.step * self.read_batch_size // len(self.base_taskset_ids) self.orders = self.build_orders(self.epoch) + if self.config.buffer.total_steps: + self.max_steps = self.config.buffer.total_steps + else: + self.max_steps = ( + self.config.buffer.total_epochs * len(self.base_taskset_ids) // self.read_batch_size + ) + def build_orders(self, epoch: int): """ Creates a shuffled sequence of taskset IDs to control sampling priority per step. @@ -108,6 +115,9 @@ def build_orders(self, epoch: int): rng.shuffle(taskset_ids) return taskset_ids + def _should_stop(self) -> bool: + return self.step >= self.max_steps + async def read_async(self) -> List: """ Asynchronously reads a batch of tasks according to the current schedule. @@ -125,12 +135,8 @@ async def read_async(self) -> List: Returns: List[Task]: A batch of tasks from potentially multiple tasksets """ - if self.config.buffer.total_steps: - if self.step >= self.config.buffer.total_steps: - raise StopAsyncIteration - else: - if self.epoch >= self.config.buffer.total_epochs: - raise StopAsyncIteration + if self._should_stop(): + raise StopAsyncIteration batch_size = self.read_batch_size start = self.step * batch_size % len(self.base_taskset_ids) @@ -143,8 +149,6 @@ async def read_async(self) -> List: else: taskset_ids = self.orders[start:] self.epoch += 1 - if self.epoch >= self.config.buffer.total_epochs: - raise StopAsyncIteration self.orders = self.build_orders(self.epoch) taskset_ids += self.orders[: (end - len(self.base_taskset_ids))]