
Commit b609626

makes requested changes

1 parent d96e395

File tree: 1 file changed

scripts/evaluate_best_checkpoint.py
Lines changed: 52 additions & 77 deletions
@@ -20,14 +20,42 @@
 app = typer.Typer()
 
 
+def print_metrics(result: dict, checkpoint_name: str = None, prefix: str = ""):
+    """
+    Print formatted metrics for a checkpoint result.
+
+    Args:
+        result: The evaluation result dictionary
+        checkpoint_name: Optional checkpoint name to display
+        prefix: Optional prefix for each line
+    """
+    if checkpoint_name:
+        print(f"{prefix}[bold]Leaderboard results[/bold]: {checkpoint_name}")
+    print(f"{prefix}Overall: {result['overall_score'] * 100:.2f}%")
+    if "leaderboard_bbh" in result:
+        print(f"{prefix}BBH: {result['leaderboard_bbh']['score'] * 100:.2f}%")
+    if "leaderboard_gpqa" in result:
+        print(f"{prefix}GPQA: {result['leaderboard_gpqa']['score'] * 100:.2f}%")
+    if "leaderboard_ifeval" in result:
+        print(f"{prefix}IFEval: {result['leaderboard_ifeval']['score'] * 100:.2f}%")
+    if "leaderboard_math_hard" in result:
+        print(f"{prefix}MATH-Hard: {result['leaderboard_math_hard']['score'] * 100:.2f}%")
+    if "leaderboard_mmlu_pro" in result:
+        print(f"{prefix}MMLU-Pro: {result['leaderboard_mmlu_pro']['score'] * 100:.2f}%")
+    if "leaderboard_musr" in result:
+        print(f"{prefix}MUSR: {result['leaderboard_musr']['score'] * 100:.2f}%")
+
+
 @app.command()
 def best_checkpoint(
     input_dir: Path = typer.Argument(..., help="Input directory to process"),
     output_file: Optional[Path] = typer.Option(None, help="Optional output file path"),
-    tasks: Annotated[Optional[list[str]], typer.Option()] = None,
+    tasks: Annotated[Optional[list[str]], typer.Option(help="Specific tasks to evaluate (e.g., 'leaderboard_bbh', 'leaderboard_gpqa')")] = None,
+    num_gpus: int = typer.Option(8, help="Number of GPUs to use for evaluation"),
 ):
     """
-    Process files in the input directory and optionally save results to an output file.
+    Find the best checkpoint by evaluating all checkpoints in the input directory.
+    Processes all checkpoint subdirectories and ranks them by overall score.
     """
     if not input_dir.exists():
         typer.echo(f"Error: Input directory '{input_dir}' does not exist")
@@ -55,7 +83,7 @@ def best_checkpoint(
         typer.echo(f"Processing checkpoint: {checkpoint}")
         ckpt_output_file = checkpoint / "leaderboard_results.json"
         evaluator = LeaderboardV2Evaluator(
-            model_path=str(checkpoint), output_file=ckpt_output_file, num_gpus=8
+            model_path=str(checkpoint), output_file=ckpt_output_file, num_gpus=num_gpus
         )
         if tasks:
             evaluator.tasks = tasks
@@ -72,28 +100,10 @@ def best_checkpoint(
         typer.echo(f"{'=' * 100}")
         # Add [BEST CHECKPOINT] label for the first checkpoint
         if i == 0:
-            typer.echo(
-                f"[bold]Leaderboard results[/bold]: {checkpoint_name} [bold green][BEST CHECKPOINT][/bold green]"
-            )
+            checkpoint_display = f"{checkpoint_name} [bold green][BEST CHECKPOINT][/bold green]"
         else:
-            typer.echo(f"[bold]Leaderboard results[/bold]: {checkpoint_name}")
-        typer.echo(f"Overall: {result['overall_score'] * 100:.2f}%")
-        if "leaderboard_bbh" in result:
-            typer.echo(f"BBH: {result['leaderboard_bbh']['score'] * 100:.2f}%")
-        if "leaderboard_gpqa" in result:
-            typer.echo(f"GPQA: {result['leaderboard_gpqa']['score'] * 100:.2f}%")
-        if "leaderboard_ifeval" in result:
-            typer.echo(f"IFEval: {result['leaderboard_ifeval']['score'] * 100:.2f}%")
-        if "leaderboard_math_hard" in result:
-            typer.echo(
-                f"MATH-Hard: {result['leaderboard_math_hard']['score'] * 100:.2f}%"
-            )
-        if "leaderboard_mmlu_pro" in result:
-            typer.echo(
-                f"MMLU-Pro: {result['leaderboard_mmlu_pro']['score'] * 100:.2f}%"
-            )
-        if "leaderboard_musr" in result:
-            typer.echo(f"MUSR: {result['leaderboard_musr']['score'] * 100:.2f}%")
+            checkpoint_display = checkpoint_name
+        print_metrics(result, checkpoint_display)
 
         typer.echo(f"{'=' * 100}")
     typer.echo(
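With num_gpus now threaded through to LeaderboardV2Evaluator, the GPU count is configurable per run. Typer should expose the parameter as --num-gpus and accept a repeated --tasks option for the list argument, so an invocation might look like the line below (the path and task selection are illustrative, and the subcommand may surface as best_checkpoint or best-checkpoint depending on the Typer version):

python scripts/evaluate_best_checkpoint.py best-checkpoint /path/to/checkpoints --num-gpus 4 --tasks leaderboard_bbh --tasks leaderboard_gpqa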
@@ -113,10 +123,12 @@ def best_checkpoint(
 @app.command()
 def evaluate(
     input_dir: Path = typer.Argument(..., help="Input directory to process"),
-    tasks: Annotated[Optional[list[str]], typer.Option()] = None,
+    tasks: Annotated[Optional[list[str]], typer.Option(help="Specific tasks to evaluate (e.g., 'leaderboard_bbh', 'leaderboard_gpqa')")] = None,
+    num_gpus: int = typer.Option(8, help="Number of GPUs to use for evaluation"),
+    output_file: Optional[Path] = typer.Option(None, help="Custom output file path (default: input_dir/leaderboard_results.json)"),
 ):
     """
-    Process files in the input directory and optionally save results to an output file.
+    Evaluate a single checkpoint directory and save results to JSON file.
     """
     if not input_dir.exists():
         typer.echo(f"Error: Input directory '{input_dir}' does not exist")
@@ -133,30 +145,25 @@ def evaluate(
     typer.echo("done")
 
     evaluator = LeaderboardV2Evaluator(
-        model_path=str(input_dir), num_gpus=8, eval_config={"batch_size": "auto"}
+        model_path=str(input_dir), num_gpus=num_gpus, eval_config={"batch_size": "auto"}
     )
     if tasks:
         evaluator.tasks = tasks
     result = evaluator.run()
 
     # now just print out the checkpoint results
-    print(f"[bold]Leaderboard results[/bold]: {input_dir}")
-    print(f"Overall: {result['overall_score'] * 100:.2f}%")
-    if "leaderboard_bbh" in result:
-        print(f"BBH: {result['leaderboard_bbh']['score'] * 100:.2f}%")
-    if "leaderboard_gpqa" in result:
-        print(f"GPQA: {result['leaderboard_gpqa']['score'] * 100:.2f}%")
-    if "leaderboard_ifeval" in result:
-        print(f"IFEval: {result['leaderboard_ifeval']['score'] * 100:.2f}%")
-    if "leaderboard_math_hard" in result:
-        print(f"MATH-Hard: {result['leaderboard_math_hard']['score'] * 100:.2f}%")
-    if "leaderboard_mmlu_pro" in result:
-        print(f"MMLU-Pro: {result['leaderboard_mmlu_pro']['score'] * 100:.2f}%")
-    if "leaderboard_musr" in result:
-        print(f"MUSR: {result['leaderboard_musr']['score'] * 100:.2f}%")
-
-    output_file = input_dir / "leaderboard_results.json"
+    print_metrics(result, str(input_dir))
+
+    # Determine output file path
+    if output_file is None:
+        output_file = input_dir / "leaderboard_results.json"
+
+    # Check if file exists and warn user
+    if output_file.exists():
+        typer.echo(f"Warning: Output file '{output_file}' already exists and will be overwritten")
+
     output_file.write_text(json.dumps(result, indent=2))
+    typer.echo(f"Results saved to: {output_file}")
 
 
 @app.command()
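Since evaluate now takes num_gpus and an explicit output_file, a single-checkpoint run might look like the line below (the directory and file names are illustrative). When --output-file is omitted, results still default to input_dir/leaderboard_results.json, now with an overwrite warning and a confirmation message:

python scripts/evaluate_best_checkpoint.py evaluate /path/to/checkpoint-2000 --num-gpus 4 --output-file /tmp/checkpoint-2000-leaderboard.json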
@@ -218,43 +225,11 @@ def find_best(
             is_best = checkpoint == best_checkpoint
             prefix = "→ " if is_best else " "
             print(f"\n{prefix}Checkpoint: {checkpoint}")
-            print(f" Overall score: {score * 100:.2f}%")
-            if "leaderboard_bbh" in results:
-                print(f" BBH: {results['leaderboard_bbh']['score'] * 100:.2f}%")
-            if "leaderboard_gpqa" in results:
-                print(f" GPQA: {results['leaderboard_gpqa']['score'] * 100:.2f}%")
-            if "leaderboard_ifeval" in results:
-                print(f" IFEval: {results['leaderboard_ifeval']['score'] * 100:.2f}%")
-            if "leaderboard_math_hard" in results:
-                print(
-                    f" MATH-Hard: {results['leaderboard_math_hard']['score'] * 100:.2f}%"
-                )
-            if "leaderboard_mmlu_pro" in results:
-                print(
-                    f" MMLU-Pro: {results['leaderboard_mmlu_pro']['score'] * 100:.2f}%"
-                )
-            if "leaderboard_musr" in results:
-                print(f" MUSR: {results['leaderboard_musr']['score'] * 100:.2f}%")
+            print_metrics(results, prefix=" ")
     else:
         # Print only best results
         print(f"\n[bold]Best checkpoint found[/bold]: {best_checkpoint}")
-        print(f"Overall score: {best_score * 100:.2f}%")
-        if "leaderboard_bbh" in best_results:
-            print(f"BBH: {best_results['leaderboard_bbh']['score'] * 100:.2f}%")
-        if "leaderboard_gpqa" in best_results:
-            print(f"GPQA: {best_results['leaderboard_gpqa']['score'] * 100:.2f}%")
-        if "leaderboard_ifeval" in best_results:
-            print(f"IFEval: {best_results['leaderboard_ifeval']['score'] * 100:.2f}%")
-        if "leaderboard_math_hard" in best_results:
-            print(
-                f"MATH-Hard: {best_results['leaderboard_math_hard']['score'] * 100:.2f}%"
-            )
-        if "leaderboard_mmlu_pro" in best_results:
-            print(
-                f"MMLU-Pro: {best_results['leaderboard_mmlu_pro']['score'] * 100:.2f}%"
-            )
-        if "leaderboard_musr" in best_results:
-            print(f"MUSR: {best_results['leaderboard_musr']['score'] * 100:.2f}%")
+        print_metrics(best_results)
 
 
 if __name__ == "__main__":
