diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py
index ae4222b538..60c9cf9971 100644
--- a/tests/explorer/scheduler_test.py
+++ b/tests/explorer/scheduler_test.py
@@ -785,6 +785,18 @@ async def test_dynamic_timeout(self):
         scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
         await scheduler.start()
         tasks = []
+        tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=1))
+        for task in tasks:
+            task.is_eval = True
+        scheduler.schedule(
+            tasks, batch_id="0/eval"
+        )  # eval tasks will not count into dynamic timeout
+        statuses, exps = await scheduler.get_results(batch_id="0/eval")
+        self.assertEqual(len(statuses), 4)
+        self.assertEqual(len(exps), 0)
+        self.assertEqual(scheduler.total_running_time, 0)
+        self.assertEqual(scheduler.total_completed_tasks, 0)
+        tasks = []
         # generate 4 tasks that will run 1 second
         tasks.extend(generate_tasks(0, timeout_num=4, repeat_times=1, timeout_seconds=1))
         scheduler.schedule(tasks, batch_id=0)  # first step will not use dynamic timeout
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index a873401599..4f4d54ea22 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -1005,10 +1005,18 @@ def test_trainer(self):
         self.config.algorithm.repeat_times = 4
         self.config.buffer.batch_size = 4
         self.config.buffer.total_steps = 2
-        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config(
+            "countdown", "train"
+        )
+        self.config.buffer.explorer_input.eval_tasksets = [
+            get_unittest_dataset_config("countdown", "test")
+        ]
+        self.config.buffer.eval_interval = 4  # only eval on start
         self.config.name = f"explore-over-rollout-{datetime.now().strftime('%Y%m%d%H%M%S')}"
         self.config.explorer.over_rollout.ratio = 0.5  # set over rollout rate to 50%, which means only wait for 2 (4 * 50%) tasks in each steps
         self.config.explorer.over_rollout.wait_after_min = 0
+        self.config.explorer.dynamic_timeout.enable = True
+        self.config.explorer.dynamic_timeout.ratio = 2
         self.config.algorithm.algorithm_type = "grpo"
         self.config.algorithm.advantage_fn = "grpo"
         self.config.algorithm.advantage_fn_args = {
@@ -1022,7 +1030,7 @@ def test_trainer(self):
         rollout_metrics = parser.metric_list("rollout")
         self.assertTrue(len(rollout_metrics) > 0)
         eval_metrics = parser.metric_list("eval")
-        self.assertTrue(len(eval_metrics) == 0)
+        self.assertTrue(len(eval_metrics) > 0)
         self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
         self.assertTrue(parser.metric_exist("experience_pipeline/experience_count"))
         experience_counts = parser.metric_values("experience_pipeline/experience_count")
diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
index d86859dbc1..aa3b6539b9 100644
--- a/trinity/explorer/scheduler.py
+++ b/trinity/explorer/scheduler.py
@@ -26,12 +26,22 @@ class TaskWrapper:
     task: Task
     batch_id: Union[int, str]
-    sub_task_num: int = 1
+    sub_task_num: int = 1  # number of sub-tasks split from this task
+    # if max_repeat_times_per_runner is set, one task may be split into multiple sub-tasks
     results: List[Tuple[Status, List[Experience]]] = field(default_factory=list)


 def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]:
-    """Calculate task level metrics from experiences."""
+    """Calculate task level metrics (mean) from multiple runs of the same task.
+
+    Args:
+        metrics (`List[Dict]`): A list of metric dictionaries from multiple runs of the same task.
+
+    Returns:
+        `Dict[str, float]`: A dictionary of aggregated metrics, where each metric is averaged over all runs.
+
+    TODO: support more aggregation methods like max, min.
+    """
     if not metrics:
         return {}
     aggregated_metrics: Dict[str, List[float]] = defaultdict(list)
@@ -312,11 +322,13 @@ def task_done_callback(self, async_task: asyncio.Task):
             return
         else:
             status, exps, runner_id, run_time = async_task.result()
-            self.total_running_time += run_time
-            self.total_completed_tasks += 1
+            if not task.task.is_eval:  # only count running time for non-eval tasks
+                self.total_running_time += run_time
+                self.total_completed_tasks += 1
             task.results.append((status, exps))
             self.busy_runners.pop(runner_id)
             self.idle_runners.add(runner_id)
+            # If all sub runs in a task are completed
             if len(task.results) == task.sub_task_num:
                 task_experiences = []
                 task_metrics = []
@@ -326,6 +338,7 @@ def task_done_callback(self, async_task: asyncio.Task):
                     task_experiences.extend(exp)
                     if not s.ok:
                         all_success = False
+                # calculate task level metrics
                 task_status = Status(
                     ok=all_success, metrics=[calculate_task_level_metrics(task_metrics)]
                 )
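
For reviewers, a minimal sketch of the mean aggregation described by the new `calculate_task_level_metrics` docstring. The helper name `mean_aggregate` and the numeric-only filtering are illustrative assumptions, not the repository's actual implementation:

```python
from collections import defaultdict
from typing import Dict, List


def mean_aggregate(metrics: List[Dict]) -> Dict[str, float]:
    # Average each metric key over all runs of the same task, as the
    # docstring describes. Skipping non-numeric values is an assumption
    # made for this sketch.
    if not metrics:
        return {}
    aggregated: Dict[str, List[float]] = defaultdict(list)
    for run_metrics in metrics:
        for key, value in run_metrics.items():
            if isinstance(value, (int, float)):
                aggregated[key].append(float(value))
    return {key: sum(values) / len(values) for key, values in aggregated.items()}


# Example: two runs of the same task
# mean_aggregate([{"reward": 1.0}, {"reward": 0.0}]) -> {"reward": 0.5}
```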