Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions skythought/evals/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from pathlib import Path
from typing import Tuple
from typing import List, Tuple

import click
import typer
Expand All @@ -20,7 +20,11 @@
)
from skythought.evals.models import ModelConfig, get_system_prompt_keys
from skythought.evals.tasks import TASK_HANDLER_MAP, TASK_NAMES_TO_YAML, TaskConfig
from skythought.evals.util.cli_util import get_deterministic_hash, parse_multi_args
from skythought.evals.util.cli_util import (
comma_separated_to_list,
get_deterministic_hash,
parse_multi_args,
)
from skythought.evals.util.common import set_seed
from skythought.evals.util.results import SummaryResults

Expand Down Expand Up @@ -530,12 +534,24 @@ def score(
case_sensitive=False,
),
],
ids: Annotated[
str,
typer.Option(
help="Comma-separated list of indices in the results JSON to re-score."
"If provided, only the scores for these samples are computed/re-computed. If None, we compute scores for all samples",
),
] = None,
):
if not os.path.exists(run_dir):
raise ValueError(f"Run directory {run_dir} does not exist.")

run_dir = Path(run_dir)

if ids:
ids: List[str] = comma_separated_to_list(ids)
# make them unique
ids = list(set(ids))

if task not in TASK_NAMES_TO_YAML:
raise ValueError(
f"Task {task} not found. Should be one of {TASK_NAMES_TO_YAML.keys()}"
Expand All @@ -556,7 +572,7 @@ def score(

run_summary = SummaryResults(**run_summary)

score_results(handler, run_dir, run_summary)
score_results(handler, run_dir, run_summary, ids)


def main():
Expand Down
64 changes: 58 additions & 6 deletions skythought/evals/inference_and_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def generate_responses_for_dataset(
def score_responses(
handler: TaskHandler,
id_to_results: Dict[str, Dict[str, Any]],
*,
max_workers: int = 32,
) -> Tuple[float, Dict[str, List[int]], int]:
"""Computes correctness for model responses for the given task
Expand Down Expand Up @@ -341,7 +342,7 @@ def score_responses(
# TODO (sumanthrh): this can be improved
if unique_id not in id_to_scores:
id_to_scores[unique_id] = [0 for _ in range(N)]
id_to_scores[unique_id][i] = new_response_entry["correctness"]
id_to_scores[unique_id][i] = int(new_response_entry["correctness"])

total_correct += new_response_entry["correctness"]
total_finish += 1
Expand All @@ -350,6 +351,41 @@ def score_responses(
return accuracy, id_to_scores, total_finish


def score_responses_for_indices(
    handler: TaskHandler,
    id_to_results: Dict[str, Dict[str, Any]],
    *,
    indices: List[str],
) -> List[int]:
    """Computes correctness for model responses for the given unique indices.

    The `id_to_results` dictionary is assumed to be a mapping between problem
    ID -> { responses: [...], ... }. It is updated in-place: each scored
    response entry gets its "correctness" and "reason" fields set from
    `handler.update_results`.

    Args:
        handler: Task handler used to grade each response.
        id_to_results: Mapping of problem ID to its result record.
        indices: Problem IDs to (re-)score; each must be a key of
            `id_to_results`.

    Returns:
        Flat list of correctness scores (as ints), one per scored response,
        across all requested indices in order.

    Raises:
        KeyError: If an index is not present in `id_to_results`.
    """
    if not id_to_results:
        return []
    import logging

    logging.getLogger(__name__).info(f"Computing scores for {len(indices)} samples")
    # Number of generations per problem. This is loop-invariant, so compute
    # it once instead of once per index.
    num_generations = len(next(iter(id_to_results.values()))["responses"])
    # Accumulate across all indices; previously the list was reset per index
    # and only the last index's scores were returned.
    scores: List[int] = []
    for idx in indices:
        record = id_to_results[idx]
        for i in range(num_generations):
            content = record["responses"][i]["content"]
            response_entry = handler.update_results(record, content)

            # Update correctness and reason in the original results dict
            record["responses"][i]["correctness"] = response_entry["correctness"]
            record["responses"][i]["reason"] = response_entry["reason"]
            # int() for consistency with score_responses' id_to_scores values
            scores.append(int(response_entry["correctness"]))
    return scores


def generate_and_score(
handler: TaskHandler,
model_config: ModelConfig,
Expand Down Expand Up @@ -480,17 +516,31 @@ def generate_and_save(


def score_results(
handler: TaskHandler, run_dir: Path, run_summary: SummaryResults
handler: TaskHandler,
run_dir: Path,
run_summary: SummaryResults,
indices: Optional[List[str]] = None,
) -> None:
# load existing results
result_file = run_dir / RESULTS_FILENAME
summary_file = run_dir / SUMMARY_FILENAME
id_to_results = load_existing_results(result_file)
logger.info(f"Loaded {len(id_to_results)} existing results for scoring.")

accuracy, id_to_scores, total_finish = score_responses(handler, id_to_results)

logger.info(f"Accuracy: {accuracy}")
if not indices:
accuracy, id_to_scores, total_finish = score_responses(handler, id_to_results)
else:
N = len(next(iter(id_to_results.values()))["responses"])
score_responses_for_indices(handler, id_to_results, indices=indices)
id_to_scores = {
index: [
id_to_results[index]["responses"][i]["correctness"] for i in range(N)
]
for index in id_to_results
}
accuracy = round(
sum(map(sum, id_to_scores.values())) / (len(id_to_scores) * N), 4
)

sample_count = 0
if id_to_results:
Expand All @@ -501,7 +551,9 @@ def score_results(

run_summary.accuracy = accuracy
run_summary.pass_at_k = pass_at_k_metrics

logger.info(f"Accuracy: {accuracy}")
save_summary(summary_file, run_summary)

save_results(result_file, id_to_results)
logger.info(f"Re-scored results saved to {result_file}")
logger.info(f"Scored results saved to {result_file}")
2 changes: 1 addition & 1 deletion skythought/evals/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def check_correctness(
pass

@abstractmethod
def update_results(self, problem: Dict[str, Any], response: str):
def update_results(self, problem: Dict[str, Any], response: str) -> Dict[str, Any]:
pass

def make_conversations(
Expand Down
7 changes: 6 additions & 1 deletion skythought/evals/util/cli_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ast import literal_eval
from typing import Any
from typing import Any, List

import msgpack
import xxhash
Expand Down Expand Up @@ -40,6 +40,11 @@ def parse_multi_args(vals: str) -> dict:
) from err


def comma_separated_to_list(vals: str) -> List[str]:
    """Split a comma-separated string into a list of items.

    All whitespace is removed (so "a, b ,c" -> ["a", "b", "c"]), and empty
    tokens from stray or trailing commas are dropped (so "a,,b," ->
    ["a", "b"]).

    Args:
        vals: Comma-separated string, e.g. "1,2,3".

    Returns:
        List of non-empty items; empty list for an empty/whitespace-only input.
    """
    # Filter out empty tokens: "1,2," would otherwise yield a trailing ""
    # that later fails as a lookup key (e.g. id_to_results[""] -> KeyError).
    return [item for item in vals.replace(" ", "").split(",") if item]


def to_tuple(d) -> tuple:
if isinstance(d, dict):
return tuple(map(to_tuple, d.items()))
Expand Down
4 changes: 2 additions & 2 deletions skythought/evals/util/metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import math
from collections import defaultdict
from typing import Any, Dict
from typing import Dict, List

import numpy as np

Expand All @@ -17,7 +17,7 @@ def _pass_at_k(n, c, k):
return float(1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)))


def pass_at_k(N: int, id_to_scores: Dict[str, Dict[str, Any]]):
def pass_at_k(N: int, id_to_scores: Dict[str, List[int]]):
final_passk_scores = {}
k_to_passk_scores = defaultdict(list) # k -> list of scores
for _, sample_scores in id_to_scores.items():
Expand Down