
Commit fa2dbf2

remove eval_at_k
1 parent aaa6aae commit fa2dbf2

File tree: 5 files changed (+21 additions, -38 deletions)

tests/trainer/trainer_test.py

Lines changed: 9 additions & 14 deletions
@@ -84,14 +84,11 @@ def test_trainer(self):
         self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig(
             selector_type="shuffle", seed=42
         )
-        self.config.buffer.explorer_input.eval_tasksets.append(
-            get_unittest_dataset_config("countdown", "test")
-        )
-        self.config.buffer.explorer_input.eval_tasksets.append(
-            get_unittest_dataset_config("copy_countdown", "test")
-        )
-        self.config.buffer.explorer_input.eval_tasksets[0].eval_at_k = [1, 2]
-        self.config.buffer.explorer_input.eval_tasksets[1].eval_at_k = [3, 4]
+        eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
+        eval_tasksets.append(get_unittest_dataset_config("countdown", "test"))
+        eval_tasksets.append(get_unittest_dataset_config("copy_countdown", "test"))
+        eval_tasksets[0].repeat_times = 4
+        eval_tasksets[1].repeat_times = 4
         self.config.trainer.save_interval = 4
         self.config.check_and_update()
         _trainer_config = self.config.trainer.trainer_config
@@ -149,13 +146,12 @@ def test_trainer(self):
         self.config.check_and_update()
         bench(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
-        eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
         for prefix in ["eval", "bench"]:
-            for eval_taskset, taskset_name in zip(eval_tasksets, ["countdown", "copy_countdown"]):
+            for taskset_name in ["countdown", "copy_countdown"]:
                 metrics = parser.metric_list(f"{prefix}/{taskset_name}")
                 self.assertTrue(len(metrics) > 0)
                 for eval_stats in ["mean", "best", "worst"]:
-                    for k in eval_taskset.eval_at_k:
+                    for k in [2, 4]:
                         for stats in ["mean", "std"]:
                             metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}/{stats}"
                             metric_steps = parser.metric_steps(metric_name)
@@ -973,7 +969,7 @@ def test_trainer(self):
         self.config.buffer.explorer_input.eval_tasksets.append(
             get_unittest_dataset_config("gsm8k", "test")
         )
-        self.config.buffer.explorer_input.eval_tasksets[0].eval_at_k = [1, 2, 4]
+        self.config.buffer.explorer_input.eval_tasksets[0].repeat_times = 8
         self.config.model.model_path = get_model_path()
         self.config.algorithm.algorithm_type = "grpo"
         self.config.algorithm.advantage_fn = "grpo"
@@ -1021,12 +1017,11 @@ def test_trainer(self):
         self.config.check_and_update()
         bench(self.config)
         parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
-        eval_taskset = self.config.buffer.explorer_input.eval_tasksets[0]
         for prefix in ["eval", "bench"]:
             gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k")
             self.assertTrue(len(gsm8k_metrics) > 0)
             for eval_stats in ["mean", "best", "worst"]:
-                for k in eval_taskset.eval_at_k:
+                for k in [2, 4, 8]:
                     for stats in ["mean", "std"]:
                         metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}/{stats}"
                         metric_steps = parser.metric_steps(metric_name)
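
Note: the k values hard-coded in these assertions (2 and 4 for repeat_times = 4; 2, 4, 8 for repeat_times = 8 in the gsm8k block) follow from the new schedule in trinity/explorer/scheduler.py further down: powers of two strictly below the run count, plus the run count itself. A minimal sketch of that derivation (the helper name expected_eval_ks is only illustrative, not part of the codebase):

    def expected_eval_ks(repeat_times: int) -> list:
        """Mirror the k schedule used by calculate_task_level_metrics."""
        k_list, k = [], 2
        while k < repeat_times:      # powers of two strictly below the run count
            k_list.append(k)
            k *= 2
        k_list.append(repeat_times)  # the run count itself is always reported
        return k_list

    assert expected_eval_ks(4) == [2, 4]      # countdown / copy_countdown tasksets
    assert expected_eval_ks(8) == [2, 4, 8]   # gsm8k taskset later in this test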

trinity/buffer/schema/formatter.py

Lines changed: 0 additions & 1 deletion
@@ -65,7 +65,6 @@ def format(self, sample: Dict) -> Task:
             workflow_args=self.config.workflow_args,
             reward_fn_args=self.config.reward_fn_args,
             is_eval=self.config.is_eval,
-            eval_at_k=self.config.eval_at_k,
             raw_task=sample,
         )
 
trinity/common/config.py

Lines changed: 3 additions & 10 deletions
@@ -84,7 +84,7 @@ class GenerationConfig:
     logprobs: Optional[int] = None  # 0 # vLLM return `logprobs + 1` elements
     max_tokens: Optional[int] = None  # if None, use model.max_response_tokens
     # repeat each task for `n` times
-    # ! DO NOT SET, it will be set by `algorithm.repeat_times` or `max(buffer.explorer_input.eval_tasksets[i].eval_at_k)`
+    # ! DO NOT SET, it will be set by `algorithm.repeat_times` or `max(buffer.explorer_input.eval_tasksets[i].repeat_times)`
     n: int = 1
 
 
@@ -196,7 +196,6 @@ class StorageConfig:
     workflow_args: dict = field(default_factory=dict)
     reward_fn_args: dict = field(default_factory=dict)
     task_selector: TaskSelectorConfig = field(default_factory=TaskSelectorConfig)
-    eval_at_k: List[int] = field(default_factory=lambda: [1])
 
     # enable progress bar (tqdm) for _HFBatchReader
     enable_progress_bar: Optional[bool] = False
@@ -238,7 +237,6 @@ class TasksetConfig:
     workflow_args: dict = field(default_factory=dict)
     reward_fn_args: dict = field(default_factory=dict)
     task_selector: TaskSelectorConfig = field(default_factory=TaskSelectorConfig)
-    eval_at_k: List[int] = field(default_factory=lambda: [1])
 
     # used for StorageType.FILE
     split: str = "train"
@@ -253,8 +251,8 @@ class TasksetConfig:
 
     # ! DO NOT SET, automatically load from checkpoint
     index: int = 0
-    # ! DO NOT SET, automatically set from `algorithm.repeat_times` or `max(buffer.explorer_input.eval_tasksets[i].eval_at_k)`
-    repeat_times: Optional[int] = None
+    # ! DO NOT SET in trainer_input, automatically set from `algorithm.repeat_times`
+    repeat_times: int = 1
     # ! DO NOT SET, automatically set based on train/eval
     is_eval: bool = False
     # ! DO NOT SET, automatically set from buffer.batch_size
@@ -270,7 +268,6 @@ def to_storage_config(self) -> StorageConfig:
             storage_type=self.storage_type,
             path=self.path,
             task_selector=self.task_selector,
-            eval_at_k=self.eval_at_k,
             repeat_times=self.repeat_times,
             split=self.split,
             subset_name=self.subset_name,
@@ -910,7 +907,6 @@ def _check_explorer_input(self) -> None:
                 "`buffer.explorer_input.taskset.repeat_times` is set to `algorithm.repeat_times`"
                 f" (={self.algorithm.repeat_times})."
             )
-        taskset.eval_at_k = []
         taskset.total_epochs = self.buffer.total_epochs
         taskset.total_steps = self.buffer.total_steps
         taskset.batch_size = self.buffer.batch_size
@@ -931,9 +927,6 @@ def _check_explorer_input(self) -> None:
             dataset.batch_size = self.buffer.batch_size
             if not dataset.name:
                 dataset.name = f"eval_taskset_{idx}"
-            if len(dataset.eval_at_k) == 0:
-                dataset.eval_at_k = [1]
-            dataset.repeat_times = max(dataset.eval_at_k)
 
             # eval_workflow has higher priority than workflow in eval tasksets, so we set it first
             set_if_none(dataset, "default_workflow_type", explorer_input.default_eval_workflow_type)
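
With eval_at_k removed from TasksetConfig and StorageConfig, an evaluation taskset now requests its rollout count directly through repeat_times, and the reported k values are derived from that count rather than from an explicit list. A hedged usage sketch, assuming config is an already-built Config instance with one eval taskset appended (as in the unit test above):

    eval_taskset = config.buffer.explorer_input.eval_tasksets[0]

    # Old style (removed by this commit):
    #   eval_taskset.eval_at_k = [1, 2, 4]  # repeat_times was derived as max(eval_at_k)

    # New style: ask for the number of runs per eval task directly.
    eval_taskset.repeat_times = 8  # eval metrics are then reported at k = 2, 4, 8
    config.check_and_update()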

trinity/common/workflows/workflow.py

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ class Task(dict):
     workflow_args: dict = field(default_factory=dict)
     reward_fn_args: dict = field(default_factory=dict)
     is_eval: bool = False
-    eval_at_k: List[int] = field(default_factory=lambda: [1])
     reward_fn: Optional[Type[RewardFn]] = None
     raw_task: Optional[dict] = None  # The raw data sample
 
trinity/explorer/scheduler.py

Lines changed: 9 additions & 12 deletions
@@ -31,15 +31,12 @@ class TaskWrapper:
     results: List[Tuple[Status, List[Experience]]] = field(default_factory=list)
 
 
-def calculate_task_level_metrics(
-    metrics: List[Dict], is_eval: bool, eval_at_k: List[int]
-) -> Dict[str, float]:
+def calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str, float]:
     """Calculate task level metrics (mean) from multiple runs of the same task.
 
     Args:
         metrics (`List[Dict]`): A list of metric dictionaries from multiple runs of the same task.
         is_eval (`bool`): Whether this is an evaluation task.
-        eval_at_k (`List[int]`): A list of k values to evaluate at.
 
     Returns:
         `Dict[str, float]`: A dictionary of aggregated metrics, where each metric is averaged over all runs.
@@ -54,9 +51,13 @@ def calculate_task_level_metrics(
     if is_eval:
         result = {}
         for key, values in aggregated_metrics.items():
-            for k in eval_at_k:
-                if k > len(values):
-                    continue
+            k_list = []
+            k = 2
+            while k < len(values):
+                k_list.append(k)
+                k *= 2
+            k_list.append(len(values))
+            for k in k_list:
                 result[f"{key}/mean@{k}"] = sum(values[:k]) / k
                 result[f"{key}/best@{k}"] = max(values[:k])
                 result[f"{key}/worst@{k}"] = min(values[:k])
@@ -356,11 +357,7 @@ def task_done_callback(self, async_task: asyncio.Task):
         # calculate task level metrics
         task_status = Status(
             ok=all_success,
-            metrics=[
-                calculate_task_level_metrics(
-                    task_metrics, task.task.is_eval, task.task.eval_at_k
-                )
-            ],
+            metrics=[calculate_task_level_metrics(task_metrics, task.task.is_eval)],
         )
         self.completed_tasks[task.batch_id].appendleft((task_status, task_experiences))
         self.logger.debug(f"Task completed (batch_id {task.batch_id}).")
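
To make the new aggregation concrete, here is a standalone reproduction of the eval branch above (plain Python, not an import from trinity.explorer.scheduler), using 5 runs of a single task to show that the last reported k is always the total run count, even when it is not a power of two:

    values = [0.0, 1.0, 1.0, 0.0, 1.0]  # e.g. the "score" metric from 5 runs

    k_list, k = [], 2
    while k < len(values):        # collect powers of two below the run count
        k_list.append(k)
        k *= 2
    k_list.append(len(values))    # always include the full run count
    # k_list == [2, 4, 5]

    result = {}
    for k in k_list:
        result[f"score/mean@{k}"] = sum(values[:k]) / k
        result[f"score/best@{k}"] = max(values[:k])
        result[f"score/worst@{k}"] = min(values[:k])
    # result keys: score/{mean,best,worst}@{2,4,5}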
