diff --git a/examples/grpo_frozen_lake/frozen_lake.yaml b/examples/grpo_frozen_lake/frozen_lake.yaml
index 648bab71bc..84c5022374 100644
--- a/examples/grpo_frozen_lake/frozen_lake.yaml
+++ b/examples/grpo_frozen_lake/frozen_lake.yaml
@@ -43,8 +43,8 @@ buffer:
         env_max_steps: 8
         agent_max_steps: 10
         is_slippery: false
+      repeat_times: 4
       rollout_args:
-        n: 4
         top_p: 0.8
         top_k: 20
       default_workflow_type: 'frozen_lake_workflow'
diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py
index 3470ed3456..c061099437 100644
--- a/tests/explorer/explorer_test.py
+++ b/tests/explorer/explorer_test.py
@@ -48,13 +48,16 @@ def setUp(self):
 class TestExplorerCountdownEval(BaseExplorerCase):
     def test_explorer(self):
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        self.config.buffer.explorer_input.eval_tasksets.extend(
+        eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
+        eval_tasksets.extend(
             [
                 get_unittest_dataset_config("countdown", "test"),
                 get_unittest_dataset_config("eval_short"),
                 get_unittest_dataset_config("eval_long"),
             ]
         )
+        eval_tasksets[1].repeat_times = 6
+        eval_tasksets[2].repeat_times = 10
         self.config.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
         self.config.check_and_update()
         explore(self.config)
@@ -65,8 +68,15 @@ def test_explorer(self):
         self.assertTrue(len(eval_metrics) > 0)
         self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
         self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
-        self.assertTrue("eval/eval_short/accuracy/max" in eval_metrics)
-        self.assertTrue("eval/eval_long/accuracy/max" in eval_metrics)
+        for eval_taskset, k_list in zip(eval_tasksets, [[1], [2, 4, 6], [2, 4, 8, 10]]):
+            for eval_stats in ["mean", "best", "worst"]:
+                for k in k_list:
+                    for stats in ["mean", "std"]:
+                        metric_name = "score" if eval_taskset.name == "countdown" else "accuracy"
+                        self.assertIn(
+                            f"eval/{eval_taskset.name}/{metric_name}/{eval_stats}@{k}/{stats}",
+                            eval_metrics,
+                        )
 
 
 class TestExplorerGSM8KRULERNoEval(BaseExplorerCase):
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index 9d8cebbd48..a42b85a3e0 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -84,12 +84,11 @@ def test_trainer(self):
         self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig(
             selector_type="shuffle", seed=42
         )
-        self.config.buffer.explorer_input.eval_tasksets.append(
-            get_unittest_dataset_config("countdown", "test")
-        )
-        self.config.buffer.explorer_input.eval_tasksets.append(
-            get_unittest_dataset_config("copy_countdown", "test")
-        )
+        eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
+        eval_tasksets.append(get_unittest_dataset_config("countdown", "test"))
+        eval_tasksets.append(get_unittest_dataset_config("copy_countdown", "test"))
+        eval_tasksets[0].repeat_times = 4
+        eval_tasksets[1].repeat_times = 4
         self.config.trainer.save_interval = 4
         self.config.check_and_update()
         _trainer_config = self.config.trainer.trainer_config
@@ -148,14 +147,15 @@ def test_trainer(self):
         bench(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
         for prefix in ["eval", "bench"]:
-            countdown_metrics = parser.metric_list(f"{prefix}/countdown")
-            copy_countdown_metrics = parser.metric_list(f"{prefix}/copy_countdown")
-            self.assertTrue(len(countdown_metrics) > 0)
-            self.assertTrue(len(copy_countdown_metrics) > 0)
-            countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
-            countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
-            self.assertEqual([0, 4, 8], countdown_metric_steps)
-            self.assertEqual([0, 4, 8], countdown_copy_metric_steps)
+            for taskset_name in ["countdown", "copy_countdown"]:
+                metrics = parser.metric_list(f"{prefix}/{taskset_name}")
+                self.assertTrue(len(metrics) > 0)
+                for eval_stats in ["mean", "best", "worst"]:
+                    for k in [2, 4]:
+                        for stats in ["mean", "std"]:
+                            metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}/{stats}"
+                            metric_steps = parser.metric_steps(metric_name)
+                            self.assertEqual(metric_steps, [0, 4, 8])
 
     def tearDown(self):
         # remove dir only when the test passed
@@ -969,6 +969,7 @@ def test_trainer(self):
         self.config.buffer.explorer_input.eval_tasksets.append(
             get_unittest_dataset_config("gsm8k", "test")
         )
+        self.config.buffer.explorer_input.eval_tasksets[0].repeat_times = 8
         self.config.model.model_path = get_model_path()
         self.config.algorithm.algorithm_type = "grpo"
         self.config.algorithm.advantage_fn = "grpo"
@@ -1019,8 +1020,12 @@ def test_trainer(self):
         for prefix in ["eval", "bench"]:
             gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k")
             self.assertTrue(len(gsm8k_metrics) > 0)
-            gsm8k_metric_steps = parser.metric_steps(gsm8k_metrics[0])
-            self.assertEqual([0, 2], gsm8k_metric_steps)
+            for eval_stats in ["mean", "best", "worst"]:
+                for k in [2, 4, 8]:
+                    for stats in ["mean", "std"]:
+                        metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}/{stats}"
+                        metric_steps = parser.metric_steps(metric_name)
+                        self.assertEqual(metric_steps, [0, 2])
 
     def tearDown(self):
         shutil.rmtree(self.config.checkpoint_job_dir)
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index c61e25e831..8fa3d4c03d 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -176,7 +176,7 @@ def __init__(self, config: StorageConfig):
             total_epochs=self.config.total_epochs if not self.config.is_eval else 1,
             offset=self.config.index,
             drop_last=not self.config.is_eval,
-            total_steps=self.config.total_steps,
+            total_steps=self.config.total_steps if not self.config.is_eval else None,
             enable_progress_bar=self.config.enable_progress_bar,
         )
         self.formatter = FORMATTER.get("task")(config)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 8ea02095f7..158b82a170 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -84,7 +84,7 @@ class GenerationConfig:
     logprobs: Optional[int] = None  # 0  # vLLM return `logprobs + 1` elements
     max_tokens: Optional[int] = None  # if None, use model.max_response_tokens
     # repeat each task for `n` times
-    # ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args`
+    # ! DO NOT SET, it will be set by `algorithm.repeat_times` or `buffer.explorer_input.eval_tasksets[i].repeat_times`
     n: int = 1
 
 
@@ -249,10 +249,10 @@ class TasksetConfig:
 
     enable_progress_bar: bool = False
 
+    # ! This setting is only valid for `eval_tasksets`; for other tasksets, it will be overridden by `algorithm.repeat_times`.
+    repeat_times: int = 1
     # ! DO NOT SET, automatically load from checkpoint
     index: int = 0
-    # ! DO NOT SET, automatically set from algorithm.repeat_times
-    repeat_times: int = 1
     # ! DO NOT SET, automatically set based on train/eval
     is_eval: bool = False
     # ! DO NOT SET, automatically set from buffer.batch_size
@@ -927,7 +927,7 @@ def _check_explorer_input(self) -> None:
             dataset.batch_size = self.buffer.batch_size
             if not dataset.name:
                 dataset.name = f"eval_taskset_{idx}"
-            set_if_none(dataset, "repeat_times", 1)
+            # eval_workflow has higher priority than workflow in eval tasksets, so we set it first
             set_if_none(dataset, "default_workflow_type", explorer_input.default_eval_workflow_type)
             set_if_none(dataset, "default_workflow_type", explorer_input.default_workflow_type)
 
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 30ece621d3..cb421c2fce 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -397,7 +397,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
             metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses)
             metric.update(
                 gather_metrics(
-                    [status.metrics[0] for status in statuses], f"{prefix}/{eval_task_name}"
+                    [status.metrics[0] for status in statuses],
+                    f"{prefix}/{eval_task_name}",
+                    output_stats=["mean", "std"],
                 )
             )
         if self.eval_start_time is not None:
diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py
index aa3b6539b9..e06b783e27 100644
--- a/trinity/explorer/scheduler.py
+++ b/trinity/explorer/scheduler.py
@@ -31,16 +31,15 @@ class TaskWrapper:
     results: List[Tuple[Status, List[Experience]]] = field(default_factory=list)
 
 
-def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]:
+def calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str, float]:
     """Calculate task level metrics (mean) from multiple runs of the same task.
 
     Args:
         metrics (`List[Dict]`): A list of metric dictionaries from multiple runs of the same task.
+        is_eval (`bool`): Whether this is an evaluation task.
 
     Returns:
         `Dict[str, float]`: A dictionary of aggregated metrics, where each metric is averaged over all runs.
-
-    TODO: support more aggregation methods like max, min.
""" if not metrics: return {} @@ -49,7 +48,27 @@ def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]: for key, value in m.items(): if isinstance(value, (int, float)): aggregated_metrics[key].append(value) - return {key: sum(values) / len(values) for key, values in aggregated_metrics.items() if values} + if is_eval: + result = {} + for key, values in aggregated_metrics.items(): + if "time/task_execution" in key or "time/run_execution" in key: + result[key] = sum(values) / len(values) + continue + k_list = [] + k = 2 + while k < len(values): + k_list.append(k) + k *= 2 + k_list.append(len(values)) + for k in k_list: + result[f"{key}/mean@{k}"] = sum(values[:k]) / k + result[f"{key}/best@{k}"] = max(values[:k]) + result[f"{key}/worst@{k}"] = min(values[:k]) + return result + else: + return { + key: sum(values) / len(values) for key, values in aggregated_metrics.items() if values + } class RunnerWrapper: @@ -340,7 +359,8 @@ def task_done_callback(self, async_task: asyncio.Task): all_success = False # calculate task level metrics task_status = Status( - ok=all_success, metrics=[calculate_task_level_metrics(task_metrics)] + ok=all_success, + metrics=[calculate_task_level_metrics(task_metrics, task.task.is_eval)], ) self.completed_tasks[task.batch_id].appendleft((task_status, task_experiences)) self.logger.debug(f"Task completed (batch_id {task.batch_id}).") diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 314f8b94e0..8917274dd5 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -141,7 +141,7 @@ async def _run_task( # repeatable workflow cannot calculate run level metrics, we use experience level metrics directly run_metrics = [exp.metrics for exp in exps if exp.metrics] for metric in run_metrics: - metric["time/task_execution"] = et - st + metric["time/run_execution"] = et - st else: exps = [] run_metrics = [] @@ -155,7 +155,7 @@ async def _run_task( et = time.time() self.runner_state["terminate_time"] = et run_metric = calculate_run_level_metrics(new_exps) - run_metric["time/task_execution"] = et - st + run_metric["time/run_execution"] = et - st run_metrics.append(run_metric) for exp in new_exps: exp.eid.run = run_id_base + i @@ -209,7 +209,7 @@ async def run_task( error_trace_back = traceback.format_exc() self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}") return ( - Status(False, metrics=[{"time/task_execution": time.time() - st}], message=str(e)), + Status(False, metrics=[{"time/run_execution": time.time() - st}], message=str(e)), [], ) diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 0eee105608..8ea9446061 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -25,18 +25,19 @@ MONITOR = Registry("monitor") -def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict: +def gather_metrics( + metric_list: List[Dict], prefix: str, output_stats: List[str] = ["mean", "max", "min"] +) -> Dict: if not metric_list: return {} try: df = pd.DataFrame(metric_list) numeric_df = df.select_dtypes(include=[np.number]) - stats_df = numeric_df.agg(["mean", "max", "min"]) + stats_df = numeric_df.agg(output_stats) metric = {} for col in stats_df.columns: - metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col].item() - metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col].item() - metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col].item() + for stats in output_stats: + metric[f"{prefix}/{col}/{stats}"] = 
+                metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item()
         return metric
     except Exception as e:
         raise ValueError(f"Failed to gather metrics: {e}") from e
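
Reviewer note, not part of the patch: a self-contained sketch of the per-task aggregation that `calculate_task_level_metrics` performs when `is_eval` is true. The k values double from 2 until they reach the number of runs, which is always appended last, so 10 eval runs yield mean/best/worst at k in [2, 4, 8, 10], matching the assertions in the tests above. The function and variable names below are illustrative only; timing keys, which the patch simply averages, are omitted.

from collections import defaultdict
from typing import Dict, List


def aggregate_eval_runs(run_metrics: List[Dict[str, float]]) -> Dict[str, float]:
    """Illustrative re-implementation of the eval branch added in scheduler.py."""
    # Group numeric values per metric key across all runs of one task.
    grouped: Dict[str, List[float]] = defaultdict(list)
    for m in run_metrics:
        for key, value in m.items():
            if isinstance(value, (int, float)):
                grouped[key].append(value)

    result: Dict[str, float] = {}
    for key, values in grouped.items():
        # k doubles from 2 while it stays below the number of runs;
        # the total number of runs is always used as the final k.
        k_list, k = [], 2
        while k < len(values):
            k_list.append(k)
            k *= 2
        k_list.append(len(values))
        for k in k_list:
            result[f"{key}/mean@{k}"] = sum(values[:k]) / k
            result[f"{key}/best@{k}"] = max(values[:k])
            result[f"{key}/worst@{k}"] = min(values[:k])
    return result


# 10 eval runs of a single task -> accuracy/{mean,best,worst}@{2,4,8,10}.
print(sorted(aggregate_eval_runs([{"accuracy": i / 10} for i in range(10)])))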
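
A second illustrative sketch, also not part of the patch: the explorer then feeds one such per-task dictionary per eval task into `gather_metrics(..., output_stats=["mean", "std"])`, which reduces each column across tasks and appends the final `/mean` and `/std` suffixes asserted in the tests (for example `eval/gsm8k/accuracy/best@4/std`). The helper name `gather` and the sample values are hypothetical.

from typing import Dict, List

import numpy as np
import pandas as pd


def gather(metric_list: List[Dict], prefix: str, output_stats: List[str]) -> Dict[str, float]:
    # One row per eval task; columns are per-task metric names such as "accuracy/best@2".
    df = pd.DataFrame(metric_list)
    stats_df = df.select_dtypes(include=[np.number]).agg(output_stats)
    return {
        f"{prefix}/{col}/{stats}": stats_df.loc[stats, col].item()
        for col in stats_df.columns
        for stats in output_stats
    }


per_task = [
    {"accuracy/mean@2": 0.5, "accuracy/best@2": 1.0},
    {"accuracy/mean@2": 0.0, "accuracy/best@2": 0.0},
]
print(gather(per_task, "eval/gsm8k", ["mean", "std"]))
# Keys: eval/gsm8k/accuracy/mean@2/mean, .../mean@2/std, .../best@2/mean, .../best@2/std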