2 changes: 1 addition & 1 deletion examples/grpo_frozen_lake/frozen_lake.yaml
@@ -43,8 +43,8 @@ buffer:
env_max_steps: 8
agent_max_steps: 10
is_slippery: false
repeat_times: 4
rollout_args:
n: 4
top_p: 0.8
top_k: 20
default_workflow_type: 'frozen_lake_workflow'
16 changes: 13 additions & 3 deletions tests/explorer/explorer_test.py
@@ -48,13 +48,16 @@ def setUp(self):
class TestExplorerCountdownEval(BaseExplorerCase):
def test_explorer(self):
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
self.config.buffer.explorer_input.eval_tasksets.extend(
eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
eval_tasksets.extend(
[
get_unittest_dataset_config("countdown", "test"),
get_unittest_dataset_config("eval_short"),
get_unittest_dataset_config("eval_long"),
]
)
eval_tasksets[1].repeat_times = 6
eval_tasksets[2].repeat_times = 10
self.config.name = f"explore-eval-{datetime.now().strftime('%Y%m%d%H%M%S')}"
self.config.check_and_update()
explore(self.config)
@@ -65,8 +68,15 @@ def test_explorer(self):
self.assertTrue(len(eval_metrics) > 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
self.assertTrue("eval/eval_short/accuracy/max" in eval_metrics)
self.assertTrue("eval/eval_long/accuracy/max" in eval_metrics)
for eval_taskset, k_list in zip(eval_tasksets, [[1], [2, 4, 6], [2, 4, 8, 10]]):
for eval_stats in ["mean", "best", "worst"]:
for k in k_list:
for stats in ["mean", "std"]:
metric_name = "score" if eval_taskset.name == "countdown" else "accuracy"
self.assertIn(
f"eval/{eval_taskset.name}/{metric_name}/{eval_stats}@{k}/{stats}",
eval_metrics,
)


class TestExplorerGSM8KRULERNoEval(BaseExplorerCase):
37 changes: 21 additions & 16 deletions tests/trainer/trainer_test.py
@@ -84,12 +84,11 @@ def test_trainer(self):
self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig(
selector_type="shuffle", seed=42
)
self.config.buffer.explorer_input.eval_tasksets.append(
get_unittest_dataset_config("countdown", "test")
)
self.config.buffer.explorer_input.eval_tasksets.append(
get_unittest_dataset_config("copy_countdown", "test")
)
eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
eval_tasksets.append(get_unittest_dataset_config("countdown", "test"))
eval_tasksets.append(get_unittest_dataset_config("copy_countdown", "test"))
eval_tasksets[0].repeat_times = 4
eval_tasksets[1].repeat_times = 4
self.config.trainer.save_interval = 4
self.config.check_and_update()
_trainer_config = self.config.trainer.trainer_config
@@ -148,14 +147,15 @@ def test_trainer(self):
bench(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
for prefix in ["eval", "bench"]:
countdown_metrics = parser.metric_list(f"{prefix}/countdown")
copy_countdown_metrics = parser.metric_list(f"{prefix}/copy_countdown")
self.assertTrue(len(countdown_metrics) > 0)
self.assertTrue(len(copy_countdown_metrics) > 0)
countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
self.assertEqual([0, 4, 8], countdown_metric_steps)
self.assertEqual([0, 4, 8], countdown_copy_metric_steps)
for taskset_name in ["countdown", "copy_countdown"]:
metrics = parser.metric_list(f"{prefix}/{taskset_name}")
self.assertTrue(len(metrics) > 0)
for eval_stats in ["mean", "best", "worst"]:
for k in [2, 4]:
for stats in ["mean", "std"]:
metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}/{stats}"
metric_steps = parser.metric_steps(metric_name)
self.assertEqual(metric_steps, [0, 4, 8])

def tearDown(self):
# remove dir only when the test passed
Expand Down Expand Up @@ -969,6 +969,7 @@ def test_trainer(self):
self.config.buffer.explorer_input.eval_tasksets.append(
get_unittest_dataset_config("gsm8k", "test")
)
self.config.buffer.explorer_input.eval_tasksets[0].repeat_times = 8
self.config.model.model_path = get_model_path()
self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.advantage_fn = "grpo"
@@ -1019,8 +1020,12 @@ def test_trainer(self):
for prefix in ["eval", "bench"]:
gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k")
self.assertTrue(len(gsm8k_metrics) > 0)
gsm8k_metric_steps = parser.metric_steps(gsm8k_metrics[0])
self.assertEqual([0, 2], gsm8k_metric_steps)
for eval_stats in ["mean", "best", "worst"]:
for k in [2, 4, 8]:
for stats in ["mean", "std"]:
metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}/{stats}"
metric_steps = parser.metric_steps(metric_name)
self.assertEqual(metric_steps, [0, 2])

def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir)
2 changes: 1 addition & 1 deletion trinity/buffer/reader/file_reader.py
@@ -176,7 +176,7 @@ def __init__(self, config: StorageConfig):
total_epochs=self.config.total_epochs if not self.config.is_eval else 1,
offset=self.config.index,
drop_last=not self.config.is_eval,
total_steps=self.config.total_steps,
total_steps=self.config.total_steps if not self.config.is_eval else None,
enable_progress_bar=self.config.enable_progress_bar,
)
self.formatter = FORMATTER.get("task")(config)
8 changes: 4 additions & 4 deletions trinity/common/config.py
@@ -84,7 +84,7 @@ class GenerationConfig:
logprobs: Optional[int] = None # 0 # vLLM return `logprobs + 1` elements
max_tokens: Optional[int] = None # if None, use model.max_response_tokens
# repeat each task for `n` times
# ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args`
# ! DO NOT SET, it will be set by `algorithm.repeat_times` or `buffer.explorer_input.eval_tasksets[i].repeat_times`
n: int = 1


@@ -249,10 +249,10 @@ class TasksetConfig:

enable_progress_bar: bool = False

# ! This setting is only valid for `eval_taskset`; for other tasksets, it will be overridden by `algorithm.repeat_times`.
repeat_times: int = 1
# ! DO NOT SET, automatically load from checkpoint
index: int = 0
# ! DO NOT SET, automatically set from algorithm.repeat_times
repeat_times: int = 1
# ! DO NOT SET, automatically set based on train/eval
is_eval: bool = False
# ! DO NOT SET, automatically set from buffer.batch_size
@@ -927,7 +927,7 @@ def _check_explorer_input(self) -> None:
dataset.batch_size = self.buffer.batch_size
if not dataset.name:
dataset.name = f"eval_taskset_{idx}"
set_if_none(dataset, "repeat_times", 1)

# eval_workflow has higher priority than workflow in eval tasksets, so we set it first
set_if_none(dataset, "default_workflow_type", explorer_input.default_eval_workflow_type)
set_if_none(dataset, "default_workflow_type", explorer_input.default_workflow_type)
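Reading note on the comment changes above: `rollout_args.n` is no longer set by hand, and a taskset's `repeat_times` only takes effect for eval tasksets, while training tasksets follow `algorithm.repeat_times`. The snippet below is a minimal sketch of that precedence using stand-in dataclasses, not the repository's actual resolution code:

```python
from dataclasses import dataclass


@dataclass
class FakeTaskset:
    """Stand-in for TasksetConfig; only the fields needed for the sketch."""

    repeat_times: int = 1
    is_eval: bool = False


def resolve_rollout_n(taskset: FakeTaskset, algorithm_repeat_times: int) -> int:
    """Sketch of the documented precedence for GenerationConfig.n."""
    # Eval tasksets keep their own repeat_times; training tasksets are
    # overridden by algorithm.repeat_times.
    return taskset.repeat_times if taskset.is_eval else algorithm_repeat_times


train_set = FakeTaskset(repeat_times=2)               # the 2 is ignored
eval_set = FakeTaskset(repeat_times=8, is_eval=True)  # the 8 is honored
print(resolve_rollout_n(train_set, algorithm_repeat_times=4))  # -> 4
print(resolve_rollout_n(eval_set, algorithm_repeat_times=4))   # -> 8
```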
4 changes: 3 additions & 1 deletion trinity/explorer/explorer.py
@@ -397,7 +397,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses)
metric.update(
gather_metrics(
[status.metrics[0] for status in statuses], f"{prefix}/{eval_task_name}"
[status.metrics[0] for status in statuses],
f"{prefix}/{eval_task_name}",
output_stats=["mean", "std"],
)
)
if self.eval_start_time is not None:
30 changes: 25 additions & 5 deletions trinity/explorer/scheduler.py
@@ -31,16 +31,15 @@ class TaskWrapper:
results: List[Tuple[Status, List[Experience]]] = field(default_factory=list)


def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]:
def calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str, float]:
"""Calculate task level metrics (mean) from multiple runs of the same task.

Args:
metrics (`List[Dict]`): A list of metric dictionaries from multiple runs of the same task.
is_eval (`bool`): Whether this is an evaluation task.

Returns:
`Dict[str, float]`: A dictionary of aggregated metrics, where each metric is averaged over all runs.

TODO: support more aggregation methods like max, min.
"""
if not metrics:
return {}
@@ -49,7 +48,27 @@ def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]:
for key, value in m.items():
if isinstance(value, (int, float)):
aggregated_metrics[key].append(value)
return {key: sum(values) / len(values) for key, values in aggregated_metrics.items() if values}
if is_eval:
result = {}
for key, values in aggregated_metrics.items():
if "time/task_execution" in key or "time/run_execution" in key:
result[key] = sum(values) / len(values)
continue
k_list = []
k = 2
while k < len(values):
k_list.append(k)
k *= 2
k_list.append(len(values))
for k in k_list:
result[f"{key}/mean@{k}"] = sum(values[:k]) / k
result[f"{key}/best@{k}"] = max(values[:k])
result[f"{key}/worst@{k}"] = min(values[:k])
return result
else:
return {
key: sum(values) / len(values) for key, values in aggregated_metrics.items() if values
}


class RunnerWrapper:
@@ -340,7 +359,8 @@ def task_done_callback(self, async_task: asyncio.Task):
all_success = False
# calculate task level metrics
task_status = Status(
ok=all_success, metrics=[calculate_task_level_metrics(task_metrics)]
ok=all_success,
metrics=[calculate_task_level_metrics(task_metrics, task.task.is_eval)],
)
self.completed_tasks[task.batch_id].appendleft((task_status, task_experiences))
self.logger.debug(f"Task completed (batch_id {task.batch_id}).")
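The eval branch added to `calculate_task_level_metrics` reports each metric at k = 2, 4, 8, ... (powers of two strictly below the number of runs) plus the full run count, with mean/best/worst taken over the first k runs. Below is a small standalone sketch of that schedule (not the repository code itself), matching the k lists asserted in the tests above, e.g. [2, 4, 6] for 6 runs and [2, 4, 8, 10] for 10:

```python
def eval_k_list(num_runs: int) -> list:
    # Powers of two strictly below num_runs, then num_runs itself,
    # mirroring the while-loop in the eval branch above.
    k_list, k = [], 2
    while k < num_runs:
        k_list.append(k)
        k *= 2
    k_list.append(num_runs)
    return k_list


accuracies = [0.0, 1.0, 1.0, 0.0, 1.0, 1.0]  # e.g. repeat_times = 6 eval runs of one task
for k in eval_k_list(len(accuracies)):       # -> [2, 4, 6]
    window = accuracies[:k]
    print(f"accuracy/mean@{k} = {sum(window) / k:.3f}")
    print(f"accuracy/best@{k} = {max(window)}")
    print(f"accuracy/worst@{k} = {min(window)}")
```

With repeat_times = 1 the schedule collapses to a single k = 1 entry, which is why the countdown eval taskset in explorer_test.py is only checked at @1.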
6 changes: 3 additions & 3 deletions trinity/explorer/workflow_runner.py
@@ -141,7 +141,7 @@ async def _run_task(
# repeatable workflow cannot calculate run level metrics, we use experience level metrics directly
run_metrics = [exp.metrics for exp in exps if exp.metrics]
for metric in run_metrics:
metric["time/task_execution"] = et - st
metric["time/run_execution"] = et - st
else:
exps = []
run_metrics = []
@@ -155,7 +155,7 @@ async def _run_task(
et = time.time()
self.runner_state["terminate_time"] = et
run_metric = calculate_run_level_metrics(new_exps)
run_metric["time/task_execution"] = et - st
run_metric["time/run_execution"] = et - st
run_metrics.append(run_metric)
for exp in new_exps:
exp.eid.run = run_id_base + i
@@ -209,7 +209,7 @@ async def run_task(
error_trace_back = traceback.format_exc()
self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}")
return (
Status(False, metrics=[{"time/task_execution": time.time() - st}], message=str(e)),
Status(False, metrics=[{"time/run_execution": time.time() - st}], message=str(e)),
[],
)

11 changes: 6 additions & 5 deletions trinity/utils/monitor.py
@@ -25,18 +25,19 @@
MONITOR = Registry("monitor")


def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict:
def gather_metrics(
metric_list: List[Dict], prefix: str, output_stats: List[str] = ["mean", "max", "min"]
) -> Dict:
if not metric_list:
return {}
try:
df = pd.DataFrame(metric_list)
numeric_df = df.select_dtypes(include=[np.number])
stats_df = numeric_df.agg(["mean", "max", "min"])
stats_df = numeric_df.agg(output_stats)
metric = {}
for col in stats_df.columns:
metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col].item()
metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col].item()
metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col].item()
for stats in output_stats:
metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item()
return metric
except Exception as e:
raise ValueError(f"Failed to gather metrics: {e}") from e
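For reference, a rough usage sketch of the extended `gather_metrics`; the input dictionaries and values are made up. The eval path (as in the explorer's `_finish_eval_step`) now passes `output_stats=["mean", "std"]`, so per-task eval metrics end up under keys such as `eval/gsm8k/accuracy/mean@4/std`, while callers that omit `output_stats` keep the previous `mean`/`max`/`min` behaviour:

```python
from trinity.utils.monitor import gather_metrics

# Per-task eval metrics for two tasks (illustrative values only).
task_metrics = [
    {"accuracy/mean@4": 0.75, "accuracy/best@4": 1.0},
    {"accuracy/mean@4": 0.25, "accuracy/best@4": 1.0},
]

# Eval path: mean/std across tasks, e.g.
# "eval/gsm8k/accuracy/mean@4/mean" -> 0.5 and
# "eval/gsm8k/accuracy/mean@4/std"  -> ~0.354 (pandas sample std).
print(gather_metrics(task_metrics, "eval/gsm8k", output_stats=["mean", "std"]))

# Default call keeps the old ["mean", "max", "min"] statistics.
print(gather_metrics(task_metrics, "rollout/countdown"))
```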