
Commit 02dbe30

add detailed_stats
1 parent 7037ce6 commit 02dbe30

File tree (5 files changed, +66 -42 lines changed)

- pyproject.toml
- tests/explorer/explorer_test.py
- tests/trainer/trainer_test.py
- trinity/explorer/explorer.py
- trinity/utils/monitor.py

pyproject.toml (1 addition, 1 deletion)

@@ -56,7 +56,7 @@ data = [
     "py-data-juicer>=1.4.3"
 ]
 agent = [
-    "agentscope>=1.0.9"
+    "agentscope>=1.0.12"
 ]
 rm_gallery = [
     "rm-gallery>=0.1.5"

tests/explorer/explorer_test.py (40 additions, 12 deletions)

@@ -43,6 +43,7 @@ def setUp(self):
         self.config.checkpoint_root_dir = get_checkpoint_path()
         self.config.synchronizer.sync_interval = 2
         self.config.explorer.eval_interval = 4
+        self.config.monitor.detailed_stats = False


 class TestExplorerCountdownEval(BaseExplorerCase):
@@ -70,21 +71,48 @@ def test_explorer(self):
         self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
         for eval_taskset, k_list in zip(eval_tasksets, [[1], [2, 4, 6], [2, 4, 8, 10]]):
             metric_name = "score" if eval_taskset.name == "countdown" else "accuracy"
-            for eval_stats in ["mean", "std"]:
-                k = k_list[-1]
+            repeat_times = k_list[-1]
+            expected_stat_suffixes = [f"mean@{repeat_times}", f"std@{repeat_times}"]
+            for k in k_list:
+                if k == 1:
+                    continue
+                expected_stat_suffixes.extend([f"best@{k}", f"worst@{k}"])
+            # only return the mean of the column
+            for stat_suffix in expected_stat_suffixes:
                 self.assertIn(
-                    f"eval/{eval_taskset.name}/{metric_name}/{eval_stats}@{k}",
+                    f"eval/{eval_taskset.name}/{metric_name}/{stat_suffix}",
+                    eval_metrics,
+                )
+
+
+class TestExplorerEvalDetailedStats(BaseExplorerCase):
+    def test_explorer(self):
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
+        self.config.monitor.detailed_stats = True
+        eval_taskset = get_unittest_dataset_config("eval_short")
+        eval_taskset.repeat_times = 6
+        self.config.buffer.explorer_input.eval_tasksets = [eval_taskset]
+        self.config.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
+        self.config.check_and_update()
+        explore(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        eval_metrics = parser.metric_list("eval")
+        self.assertTrue(len(eval_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
+        self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
+        metric_name, repeat_times, k_list = "accuracy", 6, [2, 4, 6]
+        expected_stat_suffixes = [f"mean@{repeat_times}", f"std@{repeat_times}"]
+        for k in k_list:  # k_list does not include 1
+            expected_stat_suffixes.extend([f"best@{k}", f"worst@{k}"])
+        # test detailed stats
+        for stat_suffix in expected_stat_suffixes:
+            for stats in ["mean", "std", "max", "min"]:
+                self.assertIn(
+                    f"eval/{eval_taskset.name}/{metric_name}/{stat_suffix}/{stats}",
                     eval_metrics,
                 )
-            for eval_stats in ["best", "worst"]:
-                for k in k_list:
-                    if k == 1:
-                        continue
-                    for stats in ["mean", "std"]:
-                        self.assertIn(
-                            f"eval/{eval_taskset.name}/{metric_name}/{eval_stats}@{k}/{stats}",
-                            eval_metrics,
-                        )


 class TestExplorerGSM8KRULERNoEval(BaseExplorerCase):
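
For reference, a small sketch (not part of the commit) of the eval metric keys the two tests above expect, using the detailed-stats test's settings (taskset eval_short, metric accuracy, repeat_times 6, k_list [2, 4, 6]):

# Sketch only: key naming scheme asserted by the tests above.
repeat_times, k_list = 6, [2, 4, 6]
suffixes = [f"mean@{repeat_times}", f"std@{repeat_times}"]
for k in k_list:
    suffixes.extend([f"best@{k}", f"worst@{k}"])

# monitor.detailed_stats = False (default): one key per suffix.
flat_keys = [f"eval/eval_short/accuracy/{s}" for s in suffixes]

# monitor.detailed_stats = True: each suffix is further reduced across eval
# tasks into mean/std/max/min, as TestExplorerEvalDetailedStats asserts.
detailed_keys = [
    f"eval/eval_short/accuracy/{s}/{stat}"
    for s in suffixes
    for stat in ["mean", "std", "max", "min"]
]
print("\n".join(flat_keys + detailed_keys))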

tests/trainer/trainer_test.py (12 additions, 22 deletions)

@@ -172,19 +172,14 @@ def test_trainer(self):
        for taskset_name in ["countdown", "copy_countdown"]:
            metrics = parser.metric_list(f"{prefix}/{taskset_name}")
            self.assertGreater(len(metrics), 0, f"{prefix}/{taskset_name} metrics not found")
-            # mean@k, std@k
-            for eval_stats in ["mean", "std"]:
-                k = 4
-                metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}"
+            repeat_times, k_list = 4, [2, 4]
+            expected_stat_suffixes = [f"mean@{repeat_times}", f"std@{repeat_times}"]
+            for k in k_list:
+                expected_stat_suffixes.extend([f"best@{k}", f"worst@{k}"])
+            for stat_suffix in expected_stat_suffixes:
+                metric_name = f"{prefix}/{taskset_name}/score/{stat_suffix}"
                metric_steps = parser.metric_steps(metric_name)
                self.assertEqual(metric_steps, [0, 4, 8])
-            # best@k/mean, best@k/std, worst@k/mean, worst@k/std
-            for eval_stats in ["best", "worst"]:
-                for k in [2, 4]:
-                    for stats in ["mean", "std"]:
-                        metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}/{stats}"
-                        metric_steps = parser.metric_steps(metric_name)
-                        self.assertEqual(metric_steps, [0, 4, 8])

    def tearDown(self):
        # remove dir only when the test passed
@@ -1345,19 +1340,14 @@ def test_trainer(self):
        for prefix in ["eval", "bench"]:
            gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k")
            self.assertGreater(len(gsm8k_metrics), 0, f"{prefix}/gsm8k metrics not found")
-            # mean@k, std@k
-            for eval_stats in ["mean", "std"]:
-                k = 8
-                metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}"
+            repeat_times, k_list = 8, [2, 4, 8]
+            expected_stat_suffixes = [f"mean@{repeat_times}", f"std@{repeat_times}"]
+            for k in k_list:
+                expected_stat_suffixes.extend([f"best@{k}", f"worst@{k}"])
+            for stat_suffix in expected_stat_suffixes:
+                metric_name = f"{prefix}/gsm8k/accuracy/{stat_suffix}"
                metric_steps = parser.metric_steps(metric_name)
                self.assertEqual(metric_steps, [0, 2])
-            # best@k/mean, best@k/std, worst@k/mean, worst@k/std
-            for eval_stats in ["best", "worst"]:
-                for k in [2, 4, 8]:
-                    for stats in ["mean", "std"]:
-                        metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}/{stats}"
-                        metric_steps = parser.metric_steps(metric_name)
-                        self.assertEqual(metric_steps, [0, 2])

    def tearDown(self):
        shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)

trinity/explorer/explorer.py (4 additions, 1 deletion)

@@ -66,6 +66,7 @@ def __init__(self, config: Config):
            role=self.config.explorer.name,
            config=config,
        )
+        self.detailed_stats = config.monitor.detailed_stats
        if config.explorer.over_rollout.ratio > 0.0:
            self.min_wait_num = math.ceil(
                config.buffer.batch_size * (1 - config.explorer.over_rollout.ratio)
@@ -432,7 +433,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
            metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses)
            metric.update(
                gather_eval_metrics(
-                    [status.metrics[0] for status in statuses], f"{prefix}/{eval_task_name}"
+                    [status.metrics[0] for status in statuses],
+                    f"{prefix}/{eval_task_name}",
+                    detailed_stats=self.detailed_stats,
                )
            )
            if self.eval_start_time is not None:

trinity/utils/monitor.py (9 additions, 6 deletions)

@@ -58,7 +58,10 @@ def gather_metrics(


 def gather_eval_metrics(
-    metric_list: List[Dict], prefix: str, output_stats: List[str] = ["mean", "max", "min"]
+    metric_list: List[Dict],
+    prefix: str,
+    output_stats: List[str] = ["mean", "max", "min", "std"],
+    detailed_stats: bool = False,
 ) -> Dict:
     if not metric_list:
         return {}
@@ -67,14 +70,14 @@ def gather_eval_metrics(
        numeric_df = df.select_dtypes(include=[np.number])
        metric = {}
        for col in numeric_df.columns:
-            # Skip the columns that are already aggregated
-            key_words = ["std", "mean", "min", "max"]
-            if any(key_word in col.lower() for key_word in key_words):
-                metric[f"{prefix}/{col}"] = numeric_df[col].mean()
-            else:
+            if detailed_stats:
                stats_df = numeric_df[[col]].agg(output_stats)
                for stats in output_stats:
                    metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item()
+            else:
+                # only return the mean of the column
+                metric[f"{prefix}/{col}"] = numeric_df[col].mean()
+
        return metric
    except Exception as e:
        raise ValueError(f"Failed to gather eval metrics: {e}") from e
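
A rough usage sketch of the updated helper (assuming the repository's trinity package is importable; the input metrics below are invented, shaped roughly like what the explorer passes in):

# Rough usage sketch, not part of the commit; input values are invented.
from trinity.utils.monitor import gather_eval_metrics

# One dict of per-task eval results, one entry per finished eval task.
task_metrics = [
    {"accuracy/mean@6": 0.50, "accuracy/best@4": 0.75},
    {"accuracy/mean@6": 0.70, "accuracy/best@4": 1.00},
]

# Default (detailed_stats=False): one value per column, its mean across tasks,
# e.g. {"eval/gsm8k/accuracy/mean@6": 0.6, "eval/gsm8k/accuracy/best@4": 0.875}.
print(gather_eval_metrics(task_metrics, "eval/gsm8k"))

# detailed_stats=True: mean/max/min/std across tasks for each column, producing
# keys such as "eval/gsm8k/accuracy/mean@6/mean", ".../max", ".../min", ".../std".
print(gather_eval_metrics(task_metrics, "eval/gsm8k", detailed_stats=True))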
