
Commit 7d2809b

use a eval.yaml from the hub
1 parent 66ce47e commit 7d2809b

File tree

3 files changed: +169 -0 lines changed


src/lighteval/__main__.py

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@
 app.command(rich_help_panel="Evaluation Backends")(lighteval.main_custom.custom)
 app.command(rich_help_panel="Evaluation Backends")(lighteval.main_sglang.sglang)
 app.command(rich_help_panel="Evaluation Backends")(lighteval.main_inspect.eval)
+app.command(rich_help_panel="Evaluation Backends")(lighteval.main_inspect.from_hub)
 app.command(rich_help_panel="Evaluation Utils")(lighteval.main_inspect.bundle)
 app.add_typer(
     lighteval.main_endpoint.app,
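
Registering lighteval.main_inspect.from_hub here exposes it as another subcommand in the "Evaluation Backends" help panel of the lighteval CLI; with Typer's default command naming (underscores become hyphens) it would presumably be invoked as lighteval from-hub <model> <repo_id>.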

src/lighteval/from_hub.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
+import os
+from pathlib import Path
+from string import ascii_uppercase
+
+import yaml
+from huggingface_hub import hf_hub_download
+from inspect_ai import Epochs, Task, task
+from inspect_ai.dataset import FieldSpec, Sample, hf_dataset
+from inspect_ai.scorer import choice, exact, match, model_graded_fact
+from inspect_ai.solver import (
+    chain_of_thought,
+    generate,
+    multiple_choice,
+    prompt_template,
+    system_message,
+)
+
+
+def load_config(yaml_path: str = None) -> dict:
+    """Load and parse the YAML configuration file."""
+    if yaml_path is None:
+        yaml_path = os.getenv("EVAL_YAML", "eval.yaml")
+
+    yaml_path = Path(yaml_path)
+    if not yaml_path.is_absolute():
+        yaml_path = Path(__file__).parent / yaml_path
+
+    with open(yaml_path, "r") as f:
+        return yaml.safe_load(f)
+
+
+def record_to_sample(record, field_spec: dict):
+    """Convert a dataset record to a Sample based on field_spec."""
+    input_text = record[field_spec["input"]]
+
+    # Handle target - convert numeric labels to letters for multiple choice
+    target_letter = ascii_uppercase[record[field_spec["target"]]]
+
+    # Get choices if specified
+    choices_list = None
+    if "choices" in field_spec:
+        choices_list = [record[choice_field] for choice_field in field_spec["choices"]]
+
+    sample_kwargs = {
+        "input": input_text,
+        "target": target_letter,
+    }
+    if choices_list:
+        sample_kwargs["choices"] = choices_list
+
+    return Sample(**sample_kwargs)
+
+
+def load_dataset(repo_id: str, revision: str = "main", task_config: dict = None, global_config: dict = None):
+    """Load dataset based on task configuration."""
+    subset = task_config.get("subset")
+    split = task_config.get("splits", "test")
+    field_spec = task_config["field_spec"]
+
+    # Use custom function if choices are specified (for multiple choice with label conversion)
+    if "choices" in field_spec:
+        dataset = hf_dataset(
+            path=repo_id,
+            revision=revision,
+            name=subset,
+            split=split,
+            sample_fields=lambda record: record_to_sample(record, field_spec),
+        )
+    else:
+        # For non-multiple-choice, use FieldSpec
+        dataset = hf_dataset(
+            path=repo_id,
+            revision=revision,
+            name=subset,
+            split=split,
+            sample_fields=FieldSpec(
+                input=field_spec["input"],
+                target=field_spec["target"],
+                **({k: v for k, v in field_spec.items() if k not in ["input", "target"]}),
+            ),
+        )
+
+    return dataset
+
+
+def build_solvers(task_config: dict):
+    """Build solvers list from task configuration."""
+    solvers = []
+    solver_names = task_config.get("solvers", [])
+
+    for solver_name in solver_names:
+        if solver_name == "prompt_template":
+            if "prompt_template" in task_config and task_config["prompt_template"]:
+                template = task_config["prompt_template"].strip().strip('"')
+                template = template.replace("{{prompt}}", "{prompt}")
+                solvers.append(prompt_template(template))
+        elif solver_name == "system_message":
+            if "system_message" in task_config and task_config["system_message"]:
+                sys_msg = task_config["system_message"].strip().strip('"')
+                solvers.append(system_message(sys_msg))
+        elif solver_name == "chain_of_thought":
+            solvers.append(chain_of_thought())
+        elif solver_name == "multiple_choice":
+            solvers.append(multiple_choice())
+        elif solver_name == "generate":
+            solvers.append(generate())
+
+    return solvers
+
+
+def build_scorer(task_config: dict):
+    """Build scorer from task configuration."""
+    scorer_name = task_config.get("scorers", ["choice"])[0]
+
+    if scorer_name == "choice":
+        return choice()
+    elif scorer_name == "exact":
+        return exact()
+    elif scorer_name == "match":
+        return match()
+    elif scorer_name == "model_graded_fact":
+        return model_graded_fact()
+    else:
+        raise ValueError(f"Unknown scorer: {scorer_name}")
+
+
+def create_task_from_config(
+    repo_id: str, revision: str = "main", task_config: dict = None, global_config: dict = None
+):
+    """Create an inspect.ai Task from a task configuration."""
+    dataset = load_dataset(repo_id, revision, task_config, global_config)
+    solvers = build_solvers(task_config)
+    scorer = build_scorer(task_config)
+    epochs = task_config.get("epochs", 1)
+    epochs_reducer = task_config.get("epochs_reducer", "mean")
+
+    return Task(
+        dataset=dataset,
+        solver=solvers,
+        scorer=scorer,
+        name=task_config["name"],
+        epochs=Epochs(epochs, epochs_reducer),
+    )
+
+
+def create_task_function(repo_id: str, revision: str = "main"):
+    """Factory function to create a task function with proper closure."""
+    # read yaml from hf filesystem
+    yaml_path = hf_hub_download(repo_id=repo_id, filename="eval.yaml", repo_type="dataset", revision=revision)
+
+    with open(yaml_path, "r") as f:
+        global_config = yaml.safe_load(f)
+
+    task_config = global_config["tasks"][0]
+
+    @task
+    def task_func():
+        return create_task_from_config(repo_id, revision, task_config, global_config)
+
+    return task_func
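
The loader only reads a handful of keys from the hub-hosted eval.yaml: a tasks list whose entries carry name, optional subset and splits, a field_spec mapping, solvers (plus optional prompt_template and system_message text), scorers, and optional epochs settings. Below is a minimal sketch of a config that would satisfy it; the key names are taken from the code above, while the task name, dataset columns, and values are illustrative assumptions, parsed here the same way create_task_function does:

import yaml

# Minimal sketch of an eval.yaml that create_task_from_config could consume.
# Key names mirror what the loader reads; the task name and dataset columns
# below are hypothetical, not part of the commit.
example_eval_yaml = """
tasks:
  - name: demo_mcq                 # hypothetical task name
    subset: default                # optional; forwarded to hf_dataset as `name`
    splits: test                   # optional; defaults to "test"
    field_spec:
      input: question              # column holding the prompt text
      target: answer               # integer label, mapped to A/B/C/...
      choices: [option_a, option_b, option_c, option_d]
    solvers: [multiple_choice]
    scorers: [choice]
    epochs: 1
    epochs_reducer: mean
"""

global_config = yaml.safe_load(example_eval_yaml)
task_config = global_config["tasks"][0]  # same access pattern as create_task_function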

src/lighteval/main_inspect.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import requests
2929
from huggingface_hub import HfApi
3030
from inspect_ai import Epochs, Task, task
31+
from inspect_ai import eval as inspect_ai_eval
3132
from inspect_ai import eval_set as inspect_ai_eval_set
3233
from inspect_ai.dataset import hf_dataset
3334
from inspect_ai.log import bundle_log_dir
@@ -37,6 +38,7 @@
3738
from typer import Argument, Option
3839
from typing_extensions import Annotated
3940

41+
from lighteval.from_hub import create_task_function
4042
from lighteval.models.abstract_model import InspectAIModelConfig
4143
from lighteval.tasks.lighteval_task import LightevalTaskConfig
4244

@@ -520,6 +522,12 @@ def eval( # noqa C901
     print("run 'inspect view' to view the results")


+def from_hub(model: str, repo_id: str, limit: int = 100, revision: str = "main"):
+    task = create_task_function(repo_id, revision)
+    inspect_ai_eval(tasks=[task], model=model, limit=limit)
+
+
 def bundle(log_dir: str, output_dir: str, overwrite: bool = True, repo_id: str | None = None, public: bool = False):
     bundle_log_dir(log_dir=log_dir, output_dir=output_dir, overwrite=overwrite)
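
A minimal usage sketch of the new entry point (the repo id below is hypothetical): it assumes a dataset repo that ships an eval.yaml at its root, which create_task_function downloads and turns into an Inspect AI task.

from lighteval.main_inspect import from_hub

# Evaluate the first task declared in the repo's eval.yaml; the model string is
# the Inspect AI model identifier already used as an example in this commit.
from_hub(
    model="hf-inference-providers/meta-llama/Llama-3.1-8B-Instruct",
    repo_id="my-org/my-eval-dataset",  # hypothetical dataset repo containing eval.yaml
    limit=100,
    revision="main",
)

Only the first entry of the tasks list is evaluated, since create_task_function reads global_config["tasks"][0].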
