Skip to content

Commit 9077331

Browse files
authored
Add option to re-score a specific index for score command (#82)
Adds a CLI flag to allow re-scoring a particular set of indices for `score` Usage: `skythought score --task <task> --run-dir <path/to/run-dir> --ids id1,id2,id3`
1 parent 2ff6858 commit 9077331

File tree

5 files changed

+86
-13
lines changed

5 files changed

+86
-13
lines changed

skythought/evals/cli.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import os
44
from pathlib import Path
5-
from typing import Tuple
5+
from typing import List, Tuple
66

77
import click
88
import typer
@@ -20,7 +20,11 @@
2020
)
2121
from skythought.evals.models import ModelConfig, get_system_prompt_keys
2222
from skythought.evals.tasks import TASK_HANDLER_MAP, TASK_NAMES_TO_YAML, TaskConfig
23-
from skythought.evals.util.cli_util import get_deterministic_hash, parse_multi_args
23+
from skythought.evals.util.cli_util import (
24+
comma_separated_to_list,
25+
get_deterministic_hash,
26+
parse_multi_args,
27+
)
2428
from skythought.evals.util.common import set_seed
2529
from skythought.evals.util.results import SummaryResults
2630

@@ -537,12 +541,24 @@ def score(
537541
case_sensitive=False,
538542
),
539543
],
544+
ids: Annotated[
545+
str,
546+
typer.Option(
547+
help="Comma-separated list of indices in the results JSON to re-score. "
548+
"If provided, only the scores for these samples are computed/re-computed. If None, we compute scores for all samples.",
549+
),
550+
] = None,
540551
):
541552
if not os.path.exists(run_dir):
542553
raise ValueError(f"Run directory {run_dir} does not exist.")
543554

544555
run_dir = Path(run_dir)
545556

557+
if ids:
558+
ids: List[str] = comma_separated_to_list(ids)
559+
# make them unique
560+
ids = list(set(ids))
561+
546562
if task not in TASK_NAMES_TO_YAML:
547563
raise ValueError(
548564
f"Task {task} not found. Should be one of {TASK_NAMES_TO_YAML.keys()}"
@@ -563,7 +579,7 @@ def score(
563579

564580
run_summary = SummaryResults(**run_summary)
565581

566-
score_results(handler, run_dir, run_summary)
582+
score_results(handler, run_dir, run_summary, ids)
567583

568584

569585
def main():

skythought/evals/inference_and_check.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def generate_responses_for_dataset(
294294
def score_responses(
295295
handler: TaskHandler,
296296
id_to_results: Dict[str, Dict[str, Any]],
297+
*,
297298
max_workers: int = 32,
298299
) -> Tuple[float, Dict[str, List[int]], int]:
299300
"""Computes correctness for model responses for the given task
@@ -341,7 +342,7 @@ def score_responses(
341342
# TODO (sumanthrh): this can be improved
342343
if unique_id not in id_to_scores:
343344
id_to_scores[unique_id] = [0 for _ in range(N)]
344-
id_to_scores[unique_id][i] = new_response_entry["correctness"]
345+
id_to_scores[unique_id][i] = int(new_response_entry["correctness"])
345346

346347
total_correct += new_response_entry["correctness"]
347348
total_finish += 1
@@ -350,6 +351,41 @@ def score_responses(
350351
return accuracy, id_to_scores, total_finish
351352

352353

354+
def score_responses_for_indices(
355+
handler: TaskHandler,
356+
id_to_results: Dict[str, Dict[str, Any]],
357+
*,
358+
indices: List[str],
359+
) -> List[int]:
360+
"""Computes correctness for model responses for the given task for the given unique `indices`.
361+
362+
The 'id_to_results' dictionary is assumed to be a mapping between problem ID -> { responses: [...], ... },
363+
This is updated in-place.
364+
365+
Returns:
366+
- list of scores
367+
"""
368+
if not id_to_results:
369+
return []
370+
logger.info(f"Computing scores for {len(indices)} samples")
371+
for idx in indices:
372+
# Figure out how many generations per problem
373+
N = len(next(iter(id_to_results.values()))["responses"])
374+
record = id_to_results[idx]
375+
scores = []
376+
for i in range(N):
377+
content = record["responses"][i]["content"]
378+
response_entry = handler.update_results(record, content)
379+
380+
# Update correctness and reason in the original results dict
381+
id_to_results[idx]["responses"][i]["correctness"] = response_entry[
382+
"correctness"
383+
]
384+
id_to_results[idx]["responses"][i]["reason"] = response_entry["reason"]
385+
scores.append(response_entry["correctness"])
386+
return scores
387+
388+
353389
def generate_and_score(
354390
handler: TaskHandler,
355391
model_config: ModelConfig,
@@ -480,17 +516,31 @@ def generate_and_save(
480516

481517

482518
def score_results(
483-
handler: TaskHandler, run_dir: Path, run_summary: SummaryResults
519+
handler: TaskHandler,
520+
run_dir: Path,
521+
run_summary: SummaryResults,
522+
indices: Optional[List[str]] = None,
484523
) -> None:
485524
# load existing results
486525
result_file = run_dir / RESULTS_FILENAME
487526
summary_file = run_dir / SUMMARY_FILENAME
488527
id_to_results = load_existing_results(result_file)
489528
logger.info(f"Loaded {len(id_to_results)} existing results for scoring.")
490529

491-
accuracy, id_to_scores, total_finish = score_responses(handler, id_to_results)
492-
493-
logger.info(f"Accuracy: {accuracy}")
530+
if not indices:
531+
accuracy, id_to_scores, total_finish = score_responses(handler, id_to_results)
532+
else:
533+
N = len(next(iter(id_to_results.values()))["responses"])
534+
score_responses_for_indices(handler, id_to_results, indices=indices)
535+
id_to_scores = {
536+
index: [
537+
id_to_results[index]["responses"][i]["correctness"] for i in range(N)
538+
]
539+
for index in id_to_results
540+
}
541+
accuracy = round(
542+
sum(map(sum, id_to_scores.values())) / (len(id_to_scores) * N), 4
543+
)
494544

495545
sample_count = 0
496546
if id_to_results:
@@ -501,7 +551,9 @@ def score_results(
501551

502552
run_summary.accuracy = accuracy
503553
run_summary.pass_at_k = pass_at_k_metrics
554+
555+
logger.info(f"Accuracy: {accuracy}")
504556
save_summary(summary_file, run_summary)
505557

506558
save_results(result_file, id_to_results)
507-
logger.info(f"Re-scored results saved to {result_file}")
559+
logger.info(f"Scored results saved to {result_file}")

skythought/evals/tasks/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def check_correctness(
5858
pass
5959

6060
@abstractmethod
61-
def update_results(self, problem: Dict[str, Any], response: str):
61+
def update_results(self, problem: Dict[str, Any], response: str) -> Dict[str, Any]:
6262
pass
6363

6464
def make_conversations(

skythought/evals/util/cli_util.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ast import literal_eval
2-
from typing import Any
2+
from typing import Any, List
33

44
import msgpack
55
import xxhash
@@ -40,6 +40,11 @@ def parse_multi_args(vals: str) -> dict:
4040
) from err
4141

4242

43+
def comma_separated_to_list(vals: str) -> List[str]:
44+
vals = vals.replace(" ", "")
45+
return vals.split(",")
46+
47+
4348
def to_tuple(d) -> tuple:
4449
if isinstance(d, dict):
4550
return tuple(map(to_tuple, d.items()))

skythought/evals/util/metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import math
33
from collections import defaultdict
4-
from typing import Any, Dict
4+
from typing import Dict, List
55

66
import numpy as np
77

@@ -17,7 +17,7 @@ def _pass_at_k(n, c, k):
1717
return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))
1818

1919

20-
def pass_at_k(N: int, id_to_scores: Dict[str, Dict[str, Any]]):
20+
def pass_at_k(N: int, id_to_scores: Dict[str, List[int]]):
2121
final_passk_scores = {}
2222
k_to_passk_scores = defaultdict(list) # k -> list of scores
2323
for _, sample_scores in id_to_scores.items():

0 commit comments

Comments (0)