
Commit efd7bd2

Fix repeat times in evaluation
1 parent 404bc13 commit efd7bd2

9 files changed: +67 −27 lines changed


tests/trainer/trainer_test.py

Lines changed: 20 additions & 10 deletions
@@ -90,6 +90,8 @@ def test_trainer(self):
         self.config.buffer.explorer_input.eval_tasksets.append(
             get_unittest_dataset_config("copy_countdown", "test")
         )
+        self.config.buffer.explorer_input.eval_tasksets[0].eval_at_k = [1, 2]
+        self.config.buffer.explorer_input.eval_tasksets[1].eval_at_k = [3, 4]
         self.config.trainer.save_interval = 4
         self.config.check_and_update()
         _trainer_config = self.config.trainer.trainer_config

@@ -147,15 +149,17 @@ def test_trainer(self):
         self.config.check_and_update()
         bench(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
         for prefix in ["eval", "bench"]:
-            countdown_metrics = parser.metric_list(f"{prefix}/countdown")
-            copy_countdown_metrics = parser.metric_list(f"{prefix}/copy_countdown")
-            self.assertTrue(len(countdown_metrics) > 0)
-            self.assertTrue(len(copy_countdown_metrics) > 0)
-            countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
-            countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
-            self.assertEqual([0, 4, 8], countdown_metric_steps)
-            self.assertEqual([0, 4, 8], countdown_copy_metric_steps)
+            for eval_taskset, taskset_name in zip(eval_tasksets, ["countdown", "copy_countdown"]):
+                metrics = parser.metric_list(f"{prefix}/{taskset_name}")
+                self.assertTrue(len(metrics) > 0)
+                for eval_stats in ["mean", "best", "worst"]:
+                    for k in eval_taskset.eval_at_k:
+                        for stats in ["mean", "std"]:
+                            metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}/{stats}"
+                            metric_steps = parser.metric_steps(metric_name)
+                            self.assertEqual(metric_steps, [0, 4, 8])

     def tearDown(self):
         # remove dir only when the test passed

@@ -969,6 +973,7 @@ def test_trainer(self):
         self.config.buffer.explorer_input.eval_tasksets.append(
             get_unittest_dataset_config("gsm8k", "test")
         )
+        self.config.buffer.explorer_input.eval_tasksets[0].eval_at_k = [1, 2, 4]
         self.config.model.model_path = get_model_path()
         self.config.algorithm.algorithm_type = "grpo"
         self.config.algorithm.advantage_fn = "grpo"

@@ -1016,11 +1021,16 @@ def test_trainer(self):
         self.config.check_and_update()
         bench(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        eval_taskset = self.config.buffer.explorer_input.eval_tasksets[0]
         for prefix in ["eval", "bench"]:
             gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k")
             self.assertTrue(len(gsm8k_metrics) > 0)
-            gsm8k_metric_steps = parser.metric_steps(gsm8k_metrics[0])
-            self.assertEqual([0, 2], gsm8k_metric_steps)
+            for eval_stats in ["mean", "best", "worst"]:
+                for k in eval_taskset.eval_at_k:
+                    for stats in ["mean", "std"]:
+                        metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}/{stats}"
+                        metric_steps = parser.metric_steps(metric_name)
+                        self.assertEqual(metric_steps, [0, 2])

     def tearDown(self):
         shutil.rmtree(self.config.checkpoint_job_dir)
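For reference, a minimal standalone sketch (not part of the test file) of the metric names these nested loops expand to, assuming the countdown workflow reports a "score" metric:

for prefix in ["eval", "bench"]:
    for eval_stats in ["mean", "best", "worst"]:
        for k in [1, 2]:  # eval_at_k of the first eval taskset in this test
            for stats in ["mean", "std"]:
                print(f"{prefix}/countdown/score/{eval_stats}@{k}/{stats}")
# e.g. "eval/countdown/score/best@2/std"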

trinity/algorithm/sample_strategy/mix_sample_strategy.py

Lines changed: 3 additions & 3 deletions
@@ -114,12 +114,12 @@ def default_args(cls) -> Dict:

     def state_dict(self) -> dict:
         return {
-            "usal_buffer": self.usual_exp_buffer.state_dict(),
+            "usual_buffer": self.usual_exp_buffer.state_dict(),
             "expert_buffer": self.expert_exp_buffer.state_dict(),
         }

     def load_state_dict(self, state_dict: dict) -> None:
-        if state_dict.get("usal_buffer", None):
-            self.usual_exp_buffer.load_state_dict(state_dict["usal_buffer"])
+        if state_dict.get("usual_buffer", None):
+            self.usual_exp_buffer.load_state_dict(state_dict["usual_buffer"])
         if state_dict.get("expert_buffer", None):
             self.expert_exp_buffer.load_state_dict(state_dict["expert_buffer"])

trinity/buffer/reader/file_reader.py

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ def __init__(self, config: StorageConfig):
             total_epochs=self.config.total_epochs if not self.config.is_eval else 1,
             offset=self.config.index,
             drop_last=not self.config.is_eval,
-            total_steps=self.config.total_steps,
+            total_steps=self.config.total_steps if not self.config.is_eval else None,
             enable_progress_bar=self.config.enable_progress_bar,
         )
         self.formatter = FORMATTER.get("task")(config)

trinity/buffer/schema/formatter.py

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ def format(self, sample: Dict) -> Task:
             workflow_args=self.config.workflow_args,
             reward_fn_args=self.config.reward_fn_args,
             is_eval=self.config.is_eval,
+            eval_at_k=self.config.eval_at_k,
             raw_task=sample,
         )

trinity/common/config.py

Lines changed: 9 additions & 4 deletions
@@ -84,7 +84,7 @@ class GenerationConfig:
     logprobs: Optional[int] = None  # 0 # vLLM return `logprobs + 1` elements
     max_tokens: Optional[int] = None  # if None, use model.max_response_tokens
     # repeat each task for `n` times
-    # ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args`
+    # ! DO NOT SET, it will be set by `algorithm.repeat_times` or `max(buffer.explorer_input.eval_tasksets[i].eval_at_k)`
     n: int = 1

@@ -196,6 +196,7 @@ class StorageConfig:
     workflow_args: dict = field(default_factory=dict)
     reward_fn_args: dict = field(default_factory=dict)
     task_selector: TaskSelectorConfig = field(default_factory=TaskSelectorConfig)
+    eval_at_k: List[int] = field(default_factory=lambda: [1])

     # enable progress bar (tqdm) for _HFBatchReader
     enable_progress_bar: Optional[bool] = False

@@ -237,6 +238,7 @@ class TasksetConfig:
     workflow_args: dict = field(default_factory=dict)
     reward_fn_args: dict = field(default_factory=dict)
     task_selector: TaskSelectorConfig = field(default_factory=TaskSelectorConfig)
+    eval_at_k: List[int] = field(default_factory=lambda: [1])

     # used for StorageType.FILE
     split: str = "train"

@@ -251,8 +253,8 @@ class TasksetConfig:

     # ! DO NOT SET, automatically load from checkpoint
     index: int = 0
-    # ! DO NOT SET, automatically set from algorithm.repeat_times
-    repeat_times: int = 1
+    # ! DO NOT SET, automatically set from `algorithm.repeat_times` or `max(buffer.explorer_input.eval_tasksets[i].eval_at_k)`
+    repeat_times: Optional[int] = None
     # ! DO NOT SET, automatically set based on train/eval
     is_eval: bool = False
     # ! DO NOT SET, automatically set from buffer.batch_size

@@ -268,6 +270,7 @@ def to_storage_config(self) -> StorageConfig:
             storage_type=self.storage_type,
             path=self.path,
             task_selector=self.task_selector,
+            eval_at_k=self.eval_at_k,
             repeat_times=self.repeat_times,
             split=self.split,
             subset_name=self.subset_name,

@@ -907,6 +910,7 @@ def _check_explorer_input(self) -> None:
                 "`buffer.explorer_input.taskset.repeat_times` is set to `algorithm.repeat_times`"
                 f" (={self.algorithm.repeat_times})."
             )
+        taskset.eval_at_k = []
        taskset.total_epochs = self.buffer.total_epochs
         taskset.total_steps = self.buffer.total_steps
         taskset.batch_size = self.buffer.batch_size

@@ -927,7 +931,8 @@ def _check_explorer_input(self) -> None:
             dataset.batch_size = self.buffer.batch_size
             if not dataset.name:
                 dataset.name = f"eval_taskset_{idx}"
-            set_if_none(dataset, "repeat_times", 1)
+            dataset.repeat_times = max(dataset.eval_at_k)
+
             # eval_workflow has higher priority than workflow in eval tasksets, so we set it first
             set_if_none(dataset, "default_workflow_type", explorer_input.default_eval_workflow_type)
             set_if_none(dataset, "default_workflow_type", explorer_input.default_workflow_type)
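A minimal sketch of the new repeat-times derivation for eval tasksets, using a stand-in dataclass (only the two fields touched here; the real TasksetConfig has many more):

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class EvalTasksetSketch:
    # Stand-in for the TasksetConfig fields added/changed in this commit.
    eval_at_k: List[int] = field(default_factory=lambda: [1])
    repeat_times: Optional[int] = None  # resolved during check_and_update()


dataset = EvalTasksetSketch(eval_at_k=[1, 2, 4])
# Mirrors `dataset.repeat_times = max(dataset.eval_at_k)` in _check_explorer_input,
# so each eval task is rolled out enough times to compute every requested @k statistic.
dataset.repeat_times = max(dataset.eval_at_k)
assert dataset.repeat_times == 4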

trinity/common/workflows/workflow.py

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ class Task(dict):
     workflow_args: dict = field(default_factory=dict)
     reward_fn_args: dict = field(default_factory=dict)
     is_eval: bool = False
+    eval_at_k: List[int] = field(default_factory=lambda: [1])
     reward_fn: Optional[Type[RewardFn]] = None
     raw_task: Optional[dict] = None  # The raw data sample

trinity/explorer/explorer.py

Lines changed: 3 additions & 1 deletion
@@ -397,7 +397,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval"):
         metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses)
         metric.update(
             gather_metrics(
-                [status.metrics[0] for status in statuses], f"{prefix}/{eval_task_name}"
+                [status.metrics[0] for status in statuses],
+                f"{prefix}/{eval_task_name}",
+                output_stats=["mean", "std"],
             )
         )
         if self.eval_start_time is not None:
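The status metrics passed in here already carry the per-task "@k" suffixes produced by the scheduler (next file), so aggregating them with output_stats=["mean", "std"] yields the two-level names asserted in the tests. A rough sketch with made-up values:

# Hypothetical task-level metrics for two eval tasks of the same taskset.
task_level_metrics = [
    {"score/mean@2": 0.5, "score/best@2": 1.0},
    {"score/mean@2": 1.0, "score/best@2": 1.0},
]
# gather_metrics(task_level_metrics, "eval/countdown", output_stats=["mean", "std"])
# would emit keys such as "eval/countdown/score/mean@2/mean" and
# "eval/countdown/score/mean@2/std".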

trinity/explorer/scheduler.py

Lines changed: 23 additions & 3 deletions
@@ -30,7 +30,9 @@ class TaskWrapper:
     results: List[Tuple[Status, List[Experience]]] = field(default_factory=list)


-def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]:
+def calculate_task_level_metrics(
+    metrics: List[Dict], is_eval: bool, eval_at_k: List[int]
+) -> Dict[str, float]:
     """Calculate task level metrics from experiences."""
     if not metrics:
         return {}

@@ -39,7 +41,20 @@ def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]:
         for key, value in m.items():
             if isinstance(value, (int, float)):
                 aggregated_metrics[key].append(value)
-    return {key: sum(values) / len(values) for key, values in aggregated_metrics.items() if values}
+    if is_eval:
+        result = {}
+        for key, values in aggregated_metrics.items():
+            for k in eval_at_k:
+                if k > len(values):
+                    continue
+                result[f"{key}/mean@{k}"] = sum(values[:k]) / k
+                result[f"{key}/best@{k}"] = max(values[:k]) / k
+                result[f"{key}/worst@{k}"] = min(values[:k]) / k
+        return result
+    else:
+        return {
+            key: sum(values) / len(values) for key, values in aggregated_metrics.items() if values
+        }


 class RunnerWrapper:

@@ -327,7 +342,12 @@ def task_done_callback(self, async_task: asyncio.Task):
             if not s.ok:
                 all_success = False
         task_status = Status(
-            ok=all_success, metrics=[calculate_task_level_metrics(task_metrics)]
+            ok=all_success,
+            metrics=[
+                calculate_task_level_metrics(
+                    task_metrics, task.task.is_eval, task.task.eval_at_k
+                )
+            ],
         )
         self.completed_tasks[task.batch_id].appendleft((task_status, task_experiences))
         self.logger.debug(f"Task completed (batch_id {task.batch_id}).")
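A rough usage sketch of the updated helper, showing only the key names it emits in eval mode (rollout values are made up):

from trinity.explorer.scheduler import calculate_task_level_metrics  # module path from this diff

# Hypothetical per-rollout metrics for one eval task repeated 4 times.
rollout_metrics = [{"score": 1.0}, {"score": 0.0}, {"score": 1.0}, {"score": 1.0}]
result = calculate_task_level_metrics(rollout_metrics, is_eval=True, eval_at_k=[1, 2, 4])
# Keys: "score/mean@1", "score/best@1", "score/worst@1", ..., "score/worst@4";
# any k larger than the number of collected rollouts is skipped.
print(sorted(result))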

trinity/utils/monitor.py

Lines changed: 6 additions & 5 deletions
@@ -25,18 +25,19 @@
 MONITOR = Registry("monitor")


-def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict:
+def gather_metrics(
+    metric_list: List[Dict], prefix: str, output_stats: List[str] = ["mean", "max", "min"]
+) -> Dict:
     if not metric_list:
         return {}
     try:
         df = pd.DataFrame(metric_list)
         numeric_df = df.select_dtypes(include=[np.number])
-        stats_df = numeric_df.agg(["mean", "max", "min"])
+        stats_df = numeric_df.agg(output_stats)
         metric = {}
         for col in stats_df.columns:
-            metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col].item()
-            metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col].item()
-            metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col].item()
+            for stats in output_stats:
+                metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item()
         return metric
     except Exception as e:
         raise ValueError(f"Failed to gather metrics: {e}") from e
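A quick sketch of the widened gather_metrics call as the explorer now uses it (the input dicts are invented):

from trinity.utils.monitor import gather_metrics  # module path from this diff

task_metrics = [{"score/mean@2": 0.5}, {"score/mean@2": 1.0}]  # hypothetical per-task values
metrics = gather_metrics(task_metrics, "eval/countdown", output_stats=["mean", "std"])
# -> {"eval/countdown/score/mean@2/mean": 0.75, "eval/countdown/score/mean@2/std": ...}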
