37 changes: 21 additions & 16 deletions tests/trainer/trainer_test.py
@@ -84,12 +84,11 @@ def test_trainer(self):
self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig(
selector_type="shuffle", seed=42
)
self.config.buffer.explorer_input.eval_tasksets.append(
get_unittest_dataset_config("countdown", "test")
)
self.config.buffer.explorer_input.eval_tasksets.append(
get_unittest_dataset_config("copy_countdown", "test")
)
eval_tasksets = self.config.buffer.explorer_input.eval_tasksets
eval_tasksets.append(get_unittest_dataset_config("countdown", "test"))
eval_tasksets.append(get_unittest_dataset_config("copy_countdown", "test"))
eval_tasksets[0].repeat_times = 4
eval_tasksets[1].repeat_times = 4
self.config.trainer.save_interval = 4
self.config.check_and_update()
_trainer_config = self.config.trainer.trainer_config
@@ -148,14 +147,15 @@ def test_trainer(self):
bench(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
for prefix in ["eval", "bench"]:
countdown_metrics = parser.metric_list(f"{prefix}/countdown")
copy_countdown_metrics = parser.metric_list(f"{prefix}/copy_countdown")
self.assertTrue(len(countdown_metrics) > 0)
self.assertTrue(len(copy_countdown_metrics) > 0)
countdown_metric_steps = parser.metric_steps(countdown_metrics[0])
countdown_copy_metric_steps = parser.metric_steps(copy_countdown_metrics[0])
self.assertEqual([0, 4, 8], countdown_metric_steps)
self.assertEqual([0, 4, 8], countdown_copy_metric_steps)
for taskset_name in ["countdown", "copy_countdown"]:
metrics = parser.metric_list(f"{prefix}/{taskset_name}")
self.assertTrue(len(metrics) > 0)
for eval_stats in ["mean", "best", "worst"]:
for k in [2, 4]:
for stats in ["mean", "std"]:
metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}/{stats}"
metric_steps = parser.metric_steps(metric_name)
self.assertEqual(metric_steps, [0, 4, 8])

def tearDown(self):
# remove dir only when the test passed
@@ -969,6 +969,7 @@ def test_trainer(self):
self.config.buffer.explorer_input.eval_tasksets.append(
get_unittest_dataset_config("gsm8k", "test")
)
self.config.buffer.explorer_input.eval_tasksets[0].repeat_times = 8
self.config.model.model_path = get_model_path()
self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.advantage_fn = "grpo"
@@ -1019,8 +1020,12 @@ def test_trainer(self):
for prefix in ["eval", "bench"]:
gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k")
self.assertTrue(len(gsm8k_metrics) > 0)
gsm8k_metric_steps = parser.metric_steps(gsm8k_metrics[0])
self.assertEqual([0, 2], gsm8k_metric_steps)
for eval_stats in ["mean", "best", "worst"]:
for k in [2, 4, 8]:
for stats in ["mean", "std"]:
metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}/{stats}"
metric_steps = parser.metric_steps(metric_name)
self.assertEqual(metric_steps, [0, 2])

def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir)
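The rewritten assertions follow the new eval metric naming scheme `{prefix}/{taskset}/{metric}/{agg}@{k}/{stat}`. The `@k` values come from the taskset's `repeat_times` via the doubling sequence implemented in `calculate_task_level_metrics` further down in this diff, so `repeat_times = 4` yields `k ∈ {2, 4}` and `repeat_times = 8` yields `k ∈ {2, 4, 8}`. A minimal standalone sketch of how the expected tag names are built (the helper below is illustrative, not part of the PR):

```python
from typing import List

def k_values(repeat_times: int) -> List[int]:
    # doubling sequence 2, 4, 8, ... capped at (and including) repeat_times
    ks, k = [], 2
    while k < repeat_times:
        ks.append(k)
        k *= 2
    ks.append(repeat_times)
    return ks  # k_values(4) -> [2, 4]; k_values(8) -> [2, 4, 8]

expected = [
    f"eval/countdown/score/{agg}@{k}/{stat}"
    for agg in ("mean", "best", "worst")
    for k in k_values(4)
    for stat in ("mean", "std")
]
print(expected[0], expected[-1])
# eval/countdown/score/mean@2/mean eval/countdown/score/worst@4/std
```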
2 changes: 1 addition & 1 deletion trinity/buffer/reader/file_reader.py
@@ -176,7 +176,7 @@ def __init__(self, config: StorageConfig):
total_epochs=self.config.total_epochs if not self.config.is_eval else 1,
offset=self.config.index,
drop_last=not self.config.is_eval,
total_steps=self.config.total_steps,
total_steps=self.config.total_steps if not self.config.is_eval else None,
enable_progress_bar=self.config.enable_progress_bar,
)
self.formatter = FORMATTER.get("task")(config)
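The `total_steps` limit now applies only to training reads; for eval readers it is forced to `None`, so each evaluation consumes the full taskset exactly once regardless of the configured step budget. A condensed sketch of the eval-specific sampler arguments shown above (names are illustrative, not the real reader class):

```python
from typing import Optional

def sampler_args(is_eval: bool, total_epochs: int, total_steps: Optional[int], index: int) -> dict:
    # Eval reads one full, untruncated pass; training honours epoch/step limits.
    return {
        "total_epochs": total_epochs if not is_eval else 1,
        "offset": index,
        "drop_last": not is_eval,
        "total_steps": total_steps if not is_eval else None,
    }
```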
6 changes: 3 additions & 3 deletions trinity/common/config.py
@@ -84,7 +84,7 @@ class GenerationConfig:
logprobs: Optional[int] = None # 0 # vLLM return `logprobs + 1` elements
max_tokens: Optional[int] = None # if None, use model.max_response_tokens
# repeat each task for `n` times
# ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args`
# ! DO NOT SET, it will be set by `algorithm.repeat_times` or `max(buffer.explorer_input.eval_tasksets[i].repeat_times)`
n: int = 1


@@ -251,7 +251,7 @@ class TasksetConfig:

# ! DO NOT SET, automatically load from checkpoint
index: int = 0
# ! DO NOT SET, automatically set from algorithm.repeat_times
# ! DO NOT SET in trainer_input, automatically set from `algorithm.repeat_times`
repeat_times: int = 1
# ! DO NOT SET, automatically set based on train/eval
is_eval: bool = False
@@ -927,7 +927,7 @@ def _check_explorer_input(self) -> None:
dataset.batch_size = self.buffer.batch_size
if not dataset.name:
dataset.name = f"eval_taskset_{idx}"
set_if_none(dataset, "repeat_times", 1)

# eval_workflow has higher priority than workflow in eval tasksets, so we set it first
set_if_none(dataset, "default_workflow_type", explorer_input.default_eval_workflow_type)
set_if_none(dataset, "default_workflow_type", explorer_input.default_workflow_type)
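The updated comments say that `GenerationConfig.n` is derived either from `algorithm.repeat_times` (training rollouts) or from the eval tasksets' per-taskset `repeat_times` (evaluation rollouts), which users may now set themselves since the visible `set_if_none(dataset, "repeat_times", 1)` default is removed. A hedged sketch of that derivation rule; the function below is illustrative and does not reproduce the real `check_and_update()` logic:

```python
from typing import List

def derive_generation_n(algorithm_repeat_times: int, eval_repeat_times: List[int], is_eval: bool) -> int:
    # Illustrative only: eval rollouts need enough samples for the largest
    # eval repeat count; training rollouts follow algorithm.repeat_times.
    if is_eval and eval_repeat_times:
        return max(eval_repeat_times)
    return algorithm_repeat_times

print(derive_generation_n(2, [4, 8], is_eval=True))   # 8
print(derive_generation_n(2, [4, 8], is_eval=False))  # 2
```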
4 changes: 3 additions & 1 deletion trinity/explorer/explorer.py
@@ -397,7 +397,9 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses)
metric.update(
gather_metrics(
[status.metrics[0] for status in statuses], f"{prefix}/{eval_task_name}"
[status.metrics[0] for status in statuses],
f"{prefix}/{eval_task_name}",
output_stats=["mean", "std"],
)
)
if self.eval_start_time is not None:
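With `output_stats=["mean", "std"]`, each per-task eval key produced by the scheduler (for example `score/mean@4`) is reduced across all tasks of an eval taskset into a mean and a sample standard deviation, which is what yields TensorBoard tags such as `eval/countdown/score/mean@4/std` in the updated tests. A self-contained sketch of that reduction using the standard library instead of the real pandas-based `gather_metrics` (values are made up):

```python
from statistics import mean, stdev

# Per-task dicts as produced by calculate_task_level_metrics(..., is_eval=True)
task_metrics = [
    {"score/mean@2": 0.5, "score/best@2": 1.0, "score/worst@2": 0.0},
    {"score/mean@2": 1.0, "score/best@2": 1.0, "score/worst@2": 1.0},
]

prefix = "eval/countdown"
gathered = {}
for key in task_metrics[0]:
    values = [m[key] for m in task_metrics]
    gathered[f"{prefix}/{key}/mean"] = mean(values)
    gathered[f"{prefix}/{key}/std"] = stdev(values)  # ddof=1, matching pandas

print(gathered["eval/countdown/score/mean@2/mean"])  # 0.75
```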
27 changes: 22 additions & 5 deletions trinity/explorer/scheduler.py
@@ -31,16 +31,15 @@ class TaskWrapper:
results: List[Tuple[Status, List[Experience]]] = field(default_factory=list)


def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]:
def calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str, float]:
"""Calculate task level metrics from multiple runs of the same task.

Args:
metrics (`List[Dict]`): A list of metric dictionaries from multiple runs of the same task.
is_eval (`bool`): Whether this is an evaluation task.

Returns:
`Dict[str, float]`: A dictionary of aggregated metrics. For training tasks each metric is averaged over all runs; for eval tasks, `mean@k`, `best@k` and `worst@k` statistics are reported for a doubling sequence of `k` values capped at the number of runs.

TODO: support more aggregation methods like max, min.
"""
if not metrics:
return {}
@@ -49,7 +48,24 @@ def calculate_task_level_metrics(metrics: List[Dict]) -> Dict[str, float]:
for key, value in m.items():
if isinstance(value, (int, float)):
aggregated_metrics[key].append(value)
return {key: sum(values) / len(values) for key, values in aggregated_metrics.items() if values}
if is_eval:
result = {}
for key, values in aggregated_metrics.items():
k_list = []
k = 2
while k < len(values):
k_list.append(k)
k *= 2
k_list.append(len(values))
for k in k_list:
result[f"{key}/mean@{k}"] = sum(values[:k]) / k
result[f"{key}/best@{k}"] = max(values[:k])
result[f"{key}/worst@{k}"] = min(values[:k])
return result
else:
return {
key: sum(values) / len(values) for key, values in aggregated_metrics.items() if values
}


class RunnerWrapper:
@@ -340,7 +356,8 @@ def task_done_callback(self, async_task: asyncio.Task):
all_success = False
# calculate task level metrics
task_status = Status(
ok=all_success, metrics=[calculate_task_level_metrics(task_metrics)]
ok=all_success,
metrics=[calculate_task_level_metrics(task_metrics, task.task.is_eval)],
)
self.completed_tasks[task.batch_id].appendleft((task_status, task_experiences))
self.logger.debug(f"Task completed (batch_id {task.batch_id}).")
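A worked example of the new eval branch, assuming a Trinity checkout is importable; the eight rollout scores below are made up. Note that `mean@k`, `best@k` and `worst@k` are taken over the first `k` runs, and the full run count is always included as the last `k`:

```python
from trinity.explorer.scheduler import calculate_task_level_metrics

runs = [{"score": s} for s in [0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0]]  # 8 rollouts of one eval task

print(calculate_task_level_metrics(runs, is_eval=False))
# {'score': 0.625}
print(calculate_task_level_metrics(runs, is_eval=True))
# {'score/mean@2': 0.5, 'score/best@2': 1.0, 'score/worst@2': 0.0,
#  'score/mean@4': 0.5, 'score/best@4': 1.0, 'score/worst@4': 0.0,
#  'score/mean@8': 0.625, 'score/best@8': 1.0, 'score/worst@8': 0.0}
```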
11 changes: 6 additions & 5 deletions trinity/utils/monitor.py
@@ -25,18 +25,19 @@
MONITOR = Registry("monitor")


def gather_metrics(metric_list: List[Dict], prefix: str) -> Dict:
def gather_metrics(
metric_list: List[Dict], prefix: str, output_stats: List[str] = ["mean", "max", "min"]
) -> Dict:
if not metric_list:
return {}
try:
df = pd.DataFrame(metric_list)
numeric_df = df.select_dtypes(include=[np.number])
stats_df = numeric_df.agg(["mean", "max", "min"])
stats_df = numeric_df.agg(output_stats)
metric = {}
for col in stats_df.columns:
metric[f"{prefix}/{col}/mean"] = stats_df.loc["mean", col].item()
metric[f"{prefix}/{col}/max"] = stats_df.loc["max", col].item()
metric[f"{prefix}/{col}/min"] = stats_df.loc["min", col].item()
for stats in output_stats:
metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item()
return metric
except Exception as e:
raise ValueError(f"Failed to gather metrics: {e}") from e
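Because `output_stats` defaults to `["mean", "max", "min"]`, existing training-side callers keep their previous behaviour; only the eval path above opts into `["mean", "std"]`. A small usage sketch, again assuming Trinity is importable and using made-up numbers:

```python
from trinity.utils.monitor import gather_metrics

metrics = [{"reward": 1.0}, {"reward": 3.0}]

print(gather_metrics(metrics, "rollout"))
# {'rollout/reward/mean': 2.0, 'rollout/reward/max': 3.0, 'rollout/reward/min': 1.0}

print(gather_metrics(metrics, "eval/gsm8k", output_stats=["mean", "std"]))
# {'eval/gsm8k/reward/mean': 2.0, 'eval/gsm8k/reward/std': 1.4142135623730951}
```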