
Commit 89dd059

Fix repeat times in evaluation (#410)
1 parent: 854b296

File tree: 9 files changed, +77 −39 lines

examples/grpo_frozen_lake/frozen_lake.yaml

Lines changed: 1 addition & 1 deletion

@@ -43,8 +43,8 @@ buffer:
     env_max_steps: 8
     agent_max_steps: 10
     is_slippery: false
+    repeat_times: 4
     rollout_args:
-      n: 4
       top_p: 0.8
       top_k: 20
     default_workflow_type: 'frozen_lake_workflow'

tests/explorer/explorer_test.py

Lines changed: 13 additions & 3 deletions

@@ -48,13 +48,16 @@ def setUp(self):
 class TestExplorerCountdownEval(BaseExplorerCase):
     def test_explorer(self):
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
-        self.config.buffer.explorer_input.eval_tasksets.extend(
+        eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
+        eval_tasksets.extend(
             [
                 get_unittest_dataset_config("countdown", "test"),
                 get_unittest_dataset_config("eval_short"),
                 get_unittest_dataset_config("eval_long"),
             ]
         )
+        eval_tasksets[1].repeat_times = 6
+        eval_tasksets[2].repeat_times = 10
         self.config.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
         self.config.check_and_update()
         explore(self.config)
@@ -65,8 +68,15 @@ def test_explorer(self):
         self.assertTrue(len(eval_metrics) > 0)
         self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
         self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
-        self.assertTrue("eval/eval_short/accuracy/max" in eval_metrics)
-        self.assertTrue("eval/eval_long/accuracy/max" in eval_metrics)
+        for eval_taskset, k_list in zip(eval_tasksets, [[1], [2, 4, 6], [2, 4, 8, 10]]):
+            for eval_stats in ["mean", "best", "worst"]:
+                for k in k_list:
+                    for stats in ["mean", "std"]:
+                        metric_name = "score" if eval_taskset.name == "countdown" else "accuracy"
+                        self.assertIn(
+                            f"eval/{eval_taskset.name}/{metric_name}/{eval_stats}@{k}/{stats}",
+                            eval_metrics,
+                        )


 class TestExplorerGSM8KRULERNoEval(BaseExplorerCase):
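The k values asserted above follow from each eval taskset's repeat_times: powers of two below the run count, then the run count itself. A minimal sketch of that schedule (the helper name expected_k_list is illustrative, not part of the repo; the loop mirrors the logic added to calculate_task_level_metrics in trinity/explorer/scheduler.py further down):

# Illustrative helper (not in the repo): the @k schedule the assertions above expect
# for a given eval repeat_times.
def expected_k_list(repeat_times: int) -> list:
    k_list, k = [], 2
    while k < repeat_times:
        k_list.append(k)
        k *= 2
    k_list.append(repeat_times)
    return k_list


print(expected_k_list(1))   # [1]   (default repeat_times for the first eval taskset)
print(expected_k_list(6))   # [2, 4, 6]
print(expected_k_list(10))  # [2, 4, 8, 10]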

tests/trainer/trainer_test.py

Lines changed: 21 additions & 16 deletions

@@ -84,12 +84,11 @@ def test_trainer(self):
         self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig(
             selector_type="shuffle", seed=42
         )
-        self.config.buffer.explorer_input.eval_tasksets.append(
-            get_unittest_dataset_config("countdown", "test")
-        )
-        self.config.buffer.explorer_input.eval_tasksets.append(
-            get_unittest_dataset_config("copy_countdown", "test")
-        )
+        eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
+        eval_tasksets.append(get_unittest_dataset_config("countdown", "test"))
+        eval_tasksets.append(get_unittest_dataset_config("copy_countdown", "test"))
+        eval_tasksets[0].repeat_times = 4
+        eval_tasksets[1].repeat_times = 4
         self.config.trainer.save_interval = 4
         self.config.check_and_update()
         _trainer_config = self.config.trainer.trainer_config
@@ -148,14 +147,15 @@ def test_trainer(self):
         bench(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
         for prefix in ["eval", "bench"]:
-            countdown_metrics = parser.metric_list(f"{prefix}/countdown")
-            copy_countdown_metrics = parser.metric_list(f"{prefix}/copy_countdown")
-            self.assertTrue(len(countdown_metrics) > 0)
-            self.assertTrue(len(copy_countdown_metrics) > 0)
-            countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
-            countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
-            self.assertEqual([0, 4, 8], countdown_metric_steps)
-            self.assertEqual([0, 4, 8], countdown_copy_metric_steps)
+            for taskset_name in ["countdown", "copy_countdown"]:
+                metrics = parser.metric_list(f"{prefix}/{taskset_name}")
+                self.assertTrue(len(metrics) > 0)
+                for eval_stats in ["mean", "best", "worst"]:
+                    for k in [2, 4]:
+                        for stats in ["mean", "std"]:
+                            metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}/{stats}"
+                            metric_steps = parser.metric_steps(metric_name)
+                            self.assertEqual(metric_steps, [0, 4, 8])

     def tearDown(self):
         # remove dir only when the test passed
@@ -969,6 +969,7 @@ def test_trainer(self):
         self.config.buffer.explorer_input.eval_tasksets.append(
             get_unittest_dataset_config("gsm8k", "test")
         )
+        self.config.buffer.explorer_input.eval_tasksets[0].repeat_times = 8
         self.config.model.model_path = get_model_path()
         self.config.algorithm.algorithm_type = "grpo"
         self.config.algorithm.advantage_fn = "grpo"
@@ -1019,8 +1020,12 @@ def test_trainer(self):
         for prefix in ["eval", "bench"]:
             gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k")
             self.assertTrue(len(gsm8k_metrics) > 0)
-            gsm8k_metric_steps = parser.metric_steps(gsm8k_metrics[0])
-            self.assertEqual([0, 2], gsm8k_metric_steps)
+            for eval_stats in ["mean", "best", "worst"]:
+                for k in [2, 4, 8]:
+                    for stats in ["mean", "std"]:
+                        metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}/{stats}"
+                        metric_steps = parser.metric_steps(metric_name)
+                        self.assertEqual(metric_steps, [0, 2])

     def tearDown(self):
         shutil.rmtree(self.config.checkpoint_job_dir)

trinity/buffer/reader/file_reader.py

Lines changed: 1 addition & 1 deletion

@@ -176,7 +176,7 @@ def __init__(self, config: StorageConfig):
             total_epochs=self.config.total_epochs if not self.config.is_eval else 1,
             offset=self.config.index,
             drop_last=not self.config.is_eval,
-            total_steps=self.config.total_steps,
+            total_steps=self.config.total_steps if not self.config.is_eval else None,
             enable_progress_bar=self.config.enable_progress_bar,
         )
         self.formatter = FORMATTER.get("task")(config)
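With this change an eval pass is no longer clipped by the training step budget: eval tasksets are read for exactly one epoch with no step limit. A hypothetical sketch of the behaviour (iter_batches and its arguments are illustrative, not the repo's reader API):

from typing import Iterable, Iterator, Optional

# Hypothetical reader loop: a step-limited reader stops after `total_steps` batches,
# which is what training wants; an eval pass should sweep the whole taskset once,
# so the eval reader now passes total_steps=None.
def iter_batches(batches: Iterable, total_steps: Optional[int]) -> Iterator:
    for step, batch in enumerate(batches):
        if total_steps is not None and step >= total_steps:
            break
        yield batch


print(list(iter_batches(["b0", "b1", "b2", "b3"], total_steps=2)))     # training-style: ['b0', 'b1']
print(list(iter_batches(["b0", "b1", "b2", "b3"], total_steps=None)))  # eval-style: all four batches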

trinity/common/config.py

Lines changed: 4 additions & 4 deletions

@@ -84,7 +84,7 @@ class GenerationConfig:
     logprobs: Optional[int] = None  # 0 # vLLM return `logprobs + 1` elements
     max_tokens: Optional[int] = None  # if None, use model.max_response_tokens
     # repeat each task for `n` times
-    # ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args`
+    # ! DO NOT SET, it will be set by `algorithm.repeat_times` or `buffer.explorer_input.eval_tasksets[i].repeat_times`
     n: int = 1


@@ -249,10 +249,10 @@ class TasksetConfig:

     enable_progress_bar: bool = False

+    # ! This setting is only valid for `eval_tasksets`; for other tasksets, it will be overridden by `algorithm.repeat_times`.
+    repeat_times: int = 1
     # ! DO NOT SET, automatically load from checkpoint
     index: int = 0
-    # ! DO NOT SET, automatically set from algorithm.repeat_times
-    repeat_times: int = 1
     # ! DO NOT SET, automatically set based on train/eval
     is_eval: bool = False
     # ! DO NOT SET, automatically set from buffer.batch_size
@@ -927,7 +927,7 @@ def _check_explorer_input(self) -> None:
             dataset.batch_size = self.buffer.batch_size
             if not dataset.name:
                 dataset.name = f"eval_taskset_{idx}"
-            set_if_none(dataset, "repeat_times", 1)
+
             # eval_workflow has higher priority than workflow in eval tasksets, so we set it first
             set_if_none(dataset, "default_workflow_type", explorer_input.default_eval_workflow_type)
             set_if_none(dataset, "default_workflow_type", explorer_input.default_workflow_type)
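Per the config comments above, repeat_times is now honoured only on eval tasksets, while training tasksets keep following algorithm.repeat_times. A minimal sketch of that precedence (TasksetSketch and resolve_repeat_times are illustrative stand-ins, not the repo's check_and_update logic):

from dataclasses import dataclass

# Stand-in for TasksetConfig with only the fields relevant here (taken from the diff above).
@dataclass
class TasksetSketch:
    name: str
    is_eval: bool = False
    repeat_times: int = 1


def resolve_repeat_times(taskset: TasksetSketch, algorithm_repeat_times: int) -> int:
    # Eval tasksets keep their own repeat_times; other tasksets follow the algorithm setting.
    return taskset.repeat_times if taskset.is_eval else algorithm_repeat_times


print(resolve_repeat_times(TasksetSketch("countdown"), algorithm_repeat_times=4))                            # 4
print(resolve_repeat_times(TasksetSketch("gsm8k", is_eval=True, repeat_times=8), algorithm_repeat_times=4))  # 8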

trinity/explorer/explorer.py

Lines changed: 3 additions & 1 deletion

@@ -397,7 +397,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
         metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses)
         metric.update(
             gather_metrics(
-                [status.metrics[0] for status in statuses], f"{prefix}/{eval_task_name}"
+                [status.metrics[0] for status in statuses],
+                f"{prefix}/{eval_task_name}",
+                output_stats=["mean", "std"],
             )
         )
         if self.eval_start_time is not None:

trinity/explorer/scheduler.py

Lines changed: 25 additions & 5 deletions

@@ -31,16 +31,15 @@ class TaskWrapper:
     results: List[Tuple[Status, List[Experience]]] = field(default_factory=list)


-def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]:
+def calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str, float]:
     """Calculate task level metrics (mean) from multiple runs of the same task.

     Args:
         metrics (`List[Dict]`): A list of metric dictionaries from multiple runs of the same task.
+        is_eval (`bool`): Whether this is an evaluation task.

     Returns:
         `Dict[str, float]`: A dictionary of aggregated metrics, where each metric is averaged over all runs.
-
-    TODO: support more aggregation methods like max, min.
     """
     if not metrics:
         return {}
@@ -49,7 +48,27 @@ def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]:
         for key, value in m.items():
             if isinstance(value, (int, float)):
                 aggregated_metrics[key].append(value)
-    return {key: sum(values) / len(values) for key, values in aggregated_metrics.items() if values}
+    if is_eval:
+        result = {}
+        for key, values in aggregated_metrics.items():
+            if "time/task_execution" in key or "time/run_execution" in key:
+                result[key] = sum(values) / len(values)
+                continue
+            k_list = []
+            k = 2
+            while k < len(values):
+                k_list.append(k)
+                k *= 2
+            k_list.append(len(values))
+            for k in k_list:
+                result[f"{key}/mean@{k}"] = sum(values[:k]) / k
+                result[f"{key}/best@{k}"] = max(values[:k])
+                result[f"{key}/worst@{k}"] = min(values[:k])
+        return result
+    else:
+        return {
+            key: sum(values) / len(values) for key, values in aggregated_metrics.items() if values
+        }


 class RunnerWrapper:
@@ -340,7 +359,8 @@ def task_done_callback(self, async_task: asyncio.Task):
                 all_success = False
         # calculate task level metrics
         task_status = Status(
-            ok=all_success, metrics=[calculate_task_level_metrics(task_metrics)]
+            ok=all_success,
+            metrics=[calculate_task_level_metrics(task_metrics, task.task.is_eval)],
         )
         self.completed_tasks[task.batch_id].appendleft((task_status, task_experiences))
         self.logger.debug(f"Task completed (batch_id {task.batch_id}).")
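For a taskset with repeat_times=6, the eval branch above therefore reports each metric at k = 2, 4, 6 (see the schedule sketch after the explorer test), while time/* keys stay a plain mean. A toy illustration with made-up accuracy values:

# Six made-up run-level accuracy values for one task; the aggregation above would
# emit accuracy/mean@k, accuracy/best@k and accuracy/worst@k for k in (2, 4, 6).
values = [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
for k in (2, 4, 6):
    print(f"accuracy/mean@{k}", sum(values[:k]) / k)
    print(f"accuracy/best@{k}", max(values[:k]))
    print(f"accuracy/worst@{k}", min(values[:k]))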

trinity/explorer/workflow_runner.py

Lines changed: 3 additions & 3 deletions

@@ -141,7 +141,7 @@ async def _run_task(
             # repeatable workflow cannot calculate run level metrics, we use experience level metrics directly
             run_metrics = [exp.metrics for exp in exps if exp.metrics]
             for metric in run_metrics:
-                metric["time/task_execution"] = et - st
+                metric["time/run_execution"] = et - st
         else:
             exps = []
             run_metrics = []
@@ -155,7 +155,7 @@ async def _run_task(
             et = time.time()
             self.runner_state["terminate_time"] = et
             run_metric = calculate_run_level_metrics(new_exps)
-            run_metric["time/task_execution"] = et - st
+            run_metric["time/run_execution"] = et - st
             run_metrics.append(run_metric)
             for exp in new_exps:
                 exp.eid.run = run_id_base + i
@@ -209,7 +209,7 @@ async def run_task(
             error_trace_back = traceback.format_exc()
             self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}")
             return (
-                Status(False, metrics=[{"time/task_execution": time.time() - st}], message=str(e)),
+                Status(False, metrics=[{"time/run_execution": time.time() - st}], message=str(e)),
                 [],
             )

trinity/utils/monitor.py

Lines changed: 6 additions & 5 deletions

@@ -25,18 +25,19 @@
 MONITOR = Registry("monitor")


-def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict:
+def gather_metrics(
+    metric_list: List[Dict], prefix: str, output_stats: List[str] = ["mean", "max", "min"]
+) -> Dict:
     if not metric_list:
         return {}
     try:
         df = pd.DataFrame(metric_list)
         numeric_df = df.select_dtypes(include=[np.number])
-        stats_df = numeric_df.agg(["mean", "max", "min"])
+        stats_df = numeric_df.agg(output_stats)
         metric = {}
         for col in stats_df.columns:
-            metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col].item()
-            metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col].item()
-            metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col].item()
+            for stats in output_stats:
+                metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item()
         return metric
     except Exception as e:
         raise ValueError(f"Failed to gather metrics: {e}") from e
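A self-contained usage sketch of the updated gather_metrics (the body is copied from the diff above with Registry/MONITOR and the try/except omitted; the per-task input values are made up). With output_stats=["mean", "std"], as the explorer now passes for eval, the resulting keys match the names asserted in the tests, e.g. eval/gsm8k/accuracy/mean@2/mean and eval/gsm8k/accuracy/mean@2/std:

from typing import Dict, List

import numpy as np
import pandas as pd

# Copy of the updated aggregation: per-column statistics over a list of metric dicts.
def gather_metrics(
    metric_list: List[Dict], prefix: str, output_stats: List[str] = ["mean", "max", "min"]
) -> Dict:
    if not metric_list:
        return {}
    df = pd.DataFrame(metric_list)
    numeric_df = df.select_dtypes(include=[np.number])
    stats_df = numeric_df.agg(output_stats)
    metric = {}
    for col in stats_df.columns:
        for stats in output_stats:
            metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item()
    return metric


# Two eval tasks, each already reduced to task-level "@k" metrics (toy values).
per_task = [
    {"accuracy/mean@2": 0.5, "accuracy/best@2": 1.0},
    {"accuracy/mean@2": 1.0, "accuracy/best@2": 1.0},
]
print(gather_metrics(per_task, "eval/gsm8k", output_stats=["mean", "std"]))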
