From c562a7e3d694b226ff0053d5fb14c89a3343ffb2 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Fri, 14 Nov 2025 16:41:17 +0800 Subject: [PATCH 1/4] Bug fix when set `total_steps` --- tests/buffer/task_scheduler_test.py | 35 ++++++++++++++++++++++++++--- trinity/buffer/task_scheduler.py | 2 +- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/tests/buffer/task_scheduler_test.py b/tests/buffer/task_scheduler_test.py index ee25f5856e..0cb25678f4 100644 --- a/tests/buffer/task_scheduler_test.py +++ b/tests/buffer/task_scheduler_test.py @@ -41,6 +41,28 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in @parameterized.expand( [ ( + {"batch_size": 5, "total_steps": 3}, + {"selector_type": "sequential"}, + [ + {"index": 0, "taskset_id": 1}, + {"index": 1, "taskset_id": 1}, + {"index": 2, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 1, "taskset_id": 0}, + {"index": 3, "taskset_id": 1}, + {"index": 4, "taskset_id": 1}, + {"index": 5, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 3, "taskset_id": 0}, + {"index": 6, "taskset_id": 1}, + {"index": 0, "taskset_id": 1}, + {"index": 1, "taskset_id": 1}, + {"index": 4, "taskset_id": 0}, + {"index": 0, "taskset_id": 0}, + ], + ), + ( + {"batch_size": 2, "total_epochs": 2}, {"selector_type": "sequential"}, [ {"index": 0, "taskset_id": 1}, @@ -70,6 +92,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in ], ), ( + {"batch_size": 2, "total_epochs": 2}, {"selector_type": "shuffle", "seed": 42}, [ {"index": 3, "taskset_id": 1}, @@ -99,6 +122,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in ], ), ( + {"batch_size": 2, "total_epochs": 2}, {"selector_type": "random", "seed": 42}, [ {"index": 0, "taskset_id": 1}, @@ -128,6 +152,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in ], ), ( + {"batch_size": 2, "total_epochs": 2}, {"selector_type": "offline_easy2hard", "feature_keys": ["feature_offline"]}, [ {"index": 3, "taskset_id": 1}, @@ -157,6 +182,7 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in ], ), ( + {"batch_size": 2, "total_epochs": 2}, {"selector_type": "difficulty_based", "feature_keys": ["feat_1", "feat_2"]}, [ {"index": 3, "taskset_id": 1}, @@ -187,10 +213,13 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in ), ] ) - async def test_task_scheduler(self, task_selector_kwargs, batch_tasks_orders) -> None: + async def test_task_scheduler( + self, buffer_config_kwargs, task_selector_kwargs, batch_tasks_orders + ) -> None: config = get_template_config() - config.buffer.batch_size = 2 - config.buffer.total_epochs = 2 + config.mode = "explore" + for key, value in buffer_config_kwargs.items(): + setattr(config.buffer, key, value) config.buffer.explorer_input.taskset = None config.buffer.explorer_input.tasksets = [ TasksetConfig( diff --git a/trinity/buffer/task_scheduler.py b/trinity/buffer/task_scheduler.py index 7619a3702f..18d2f5ad7e 100644 --- a/trinity/buffer/task_scheduler.py +++ b/trinity/buffer/task_scheduler.py @@ -143,7 +143,7 @@ async def read_async(self) -> List: else: taskset_ids = self.orders[start:] self.epoch += 1 - if self.epoch >= self.config.buffer.total_epochs: + if not self.config.buffer.total_steps and self.epoch >= self.config.buffer.total_epochs: raise StopAsyncIteration self.orders = self.build_orders(self.epoch) taskset_ids += self.orders[: (end - len(self.base_taskset_ids))] From fedc49a41355725d0161f599beb3bb6003c65926 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Fri, 14 Nov 2025 17:03:04 +0800 Subject: [PATCH 2/4] apply suggestions --- trinity/buffer/task_scheduler.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/trinity/buffer/task_scheduler.py b/trinity/buffer/task_scheduler.py index 18d2f5ad7e..294aab68b6 100644 --- a/trinity/buffer/task_scheduler.py +++ b/trinity/buffer/task_scheduler.py @@ -108,6 +108,15 @@ def build_orders(self, epoch: int): rng.shuffle(taskset_ids) return taskset_ids + def _should_stop(self) -> bool: + if self.config.buffer.total_steps: + if self.step >= self.config.buffer.total_steps: + return True + else: + if self.epoch >= self.config.buffer.total_epochs: + return True + return False + async def read_async(self) -> List: """ Asynchronously reads a batch of tasks according to the current schedule. @@ -125,12 +134,8 @@ async def read_async(self) -> List: Returns: List[Task]: A batch of tasks from potentially multiple tasksets """ - if self.config.buffer.total_steps: - if self.step >= self.config.buffer.total_steps: - raise StopAsyncIteration - else: - if self.epoch >= self.config.buffer.total_epochs: - raise StopAsyncIteration + if self._should_stop(): + raise StopAsyncIteration batch_size = self.read_batch_size start = self.step * batch_size % len(self.base_taskset_ids) @@ -143,7 +148,7 @@ async def read_async(self) -> List: else: taskset_ids = self.orders[start:] self.epoch += 1 - if not self.config.buffer.total_steps and self.epoch >= self.config.buffer.total_epochs: + if self._should_stop(): raise StopAsyncIteration self.orders = self.build_orders(self.epoch) taskset_ids += self.orders[: (end - len(self.base_taskset_ids))] From 4462fedb719ae69a1319a10924a68e114b1b12ad Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Fri, 14 Nov 2025 17:14:16 +0800 Subject: [PATCH 3/4] refactor _should_stop --- trinity/buffer/task_scheduler.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/trinity/buffer/task_scheduler.py b/trinity/buffer/task_scheduler.py index 294aab68b6..9b525b1f89 100644 --- a/trinity/buffer/task_scheduler.py +++ b/trinity/buffer/task_scheduler.py @@ -90,6 +90,13 @@ def __init__(self, explorer_state: Dict, config: Config): self.epoch = self.step * self.read_batch_size // len(self.base_taskset_ids) self.orders = self.build_orders(self.epoch) + if self.config.buffer.total_steps: + self.max_steps = self.config.buffer.total_steps + else: + self.max_steps = ( + self.config.buffer.total_epochs * len(self.base_taskset_ids) // self.read_batch_size + ) + def build_orders(self, epoch: int): """ Creates a shuffled sequence of taskset IDs to control sampling priority per step. @@ -109,13 +116,7 @@ def build_orders(self, epoch: int): return taskset_ids def _should_stop(self) -> bool: - if self.config.buffer.total_steps: - if self.step >= self.config.buffer.total_steps: - return True - else: - if self.epoch >= self.config.buffer.total_epochs: - return True - return False + return self.step >= self.max_steps async def read_async(self) -> List: """ From e7174621766657327c63eff0daa6130795692285 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Fri, 14 Nov 2025 17:22:56 +0800 Subject: [PATCH 4/4] add unittest point --- tests/buffer/task_scheduler_test.py | 26 ++++++++++++++++++++++++++ trinity/buffer/task_scheduler.py | 2 -- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/buffer/task_scheduler_test.py b/tests/buffer/task_scheduler_test.py index 0cb25678f4..bff905a3e1 100644 --- a/tests/buffer/task_scheduler_test.py +++ b/tests/buffer/task_scheduler_test.py @@ -61,6 +61,32 @@ def _check_batch_tasks(self, batch_tasks: List[Task], indices: List[Dict[str, in {"index": 0, "taskset_id": 0}, ], ), + ( + {"batch_size": 5, "total_epochs": 2}, + {"selector_type": "sequential"}, + [ + {"index": 0, "taskset_id": 1}, + {"index": 1, "taskset_id": 1}, + {"index": 2, "taskset_id": 1}, + {"index": 0, "taskset_id": 0}, + {"index": 1, "taskset_id": 0}, + {"index": 3, "taskset_id": 1}, + {"index": 4, "taskset_id": 1}, + {"index": 5, "taskset_id": 1}, + {"index": 2, "taskset_id": 0}, + {"index": 3, "taskset_id": 0}, + {"index": 6, "taskset_id": 1}, + {"index": 0, "taskset_id": 1}, + {"index": 1, "taskset_id": 1}, + {"index": 4, "taskset_id": 0}, + {"index": 0, "taskset_id": 0}, + {"index": 2, "taskset_id": 1}, + {"index": 3, "taskset_id": 1}, + {"index": 4, "taskset_id": 1}, + {"index": 1, "taskset_id": 0}, + {"index": 2, "taskset_id": 0}, + ], + ), ( {"batch_size": 2, "total_epochs": 2}, {"selector_type": "sequential"}, diff --git a/trinity/buffer/task_scheduler.py b/trinity/buffer/task_scheduler.py index 9b525b1f89..01b6fa1a47 100644 --- a/trinity/buffer/task_scheduler.py +++ b/trinity/buffer/task_scheduler.py @@ -149,8 +149,6 @@ async def read_async(self) -> List: else: taskset_ids = self.orders[start:] self.epoch += 1 - if self._should_stop(): - raise StopAsyncIteration self.orders = self.build_orders(self.epoch) taskset_ids += self.orders[: (end - len(self.base_taskset_ids))]