
Commit d298458

use a eval.yaml from the hub

1 parent 7d2809b

File tree

  src/lighteval/from_hub.py
  src/lighteval/main_inspect.py

2 files changed: +83 -75 lines


src/lighteval/from_hub.py
Lines changed: 80 additions & 72 deletions
@@ -1,32 +1,10 @@
-import os
-from pathlib import Path
+from importlib import import_module
 from string import ascii_uppercase
 
 import yaml
 from huggingface_hub import hf_hub_download
 from inspect_ai import Epochs, Task, task
 from inspect_ai.dataset import FieldSpec, Sample, hf_dataset
-from inspect_ai.scorer import choice, exact, match, model_graded_fact
-from inspect_ai.solver import (
-    chain_of_thought,
-    generate,
-    multiple_choice,
-    prompt_template,
-    system_message,
-)
-
-
-def load_config(yaml_path: str = None) -> dict:
-    """Load and parse the YAML configuration file."""
-    if yaml_path is None:
-        yaml_path = os.getenv("EVAL_YAML", "eval.yaml")
-
-    yaml_path = Path(yaml_path)
-    if not yaml_path.is_absolute():
-        yaml_path = Path(__file__).parent / yaml_path
-
-    with open(yaml_path, "r") as f:
-        return yaml.safe_load(f)
 
 
 def record_to_sample(record, field_spec: dict):
@@ -51,9 +29,9 @@ def record_to_sample(record, field_spec: dict):
     return Sample(**sample_kwargs)
 
 
-def load_dataset(repo_id: str, revision: str = "main", task_config: dict = None, global_config: dict = None):
+def load_dataset(repo_id: str, revision: str = "main", task_config: dict = None):
     """Load dataset based on task configuration."""
-    subset = task_config.get("subset")
+    subset = task_config.get("subset", "default")
     split = task_config.get("splits", "test")
     field_spec = task_config["field_spec"]
 
@@ -76,85 +54,115 @@ def load_dataset(repo_id: str, revision: str = "main", task_config: dict = None,
         sample_fields=FieldSpec(
             input=field_spec["input"],
             target=field_spec["target"],
-            **({k: v for k, v in field_spec.items() if k not in ["input", "target"]}),
+            metadata=field_spec.get("metadata", []),
         ),
     )
 
     return dataset
 
 
 def build_solvers(task_config: dict):
-    """Build solvers list from task configuration."""
+    """
+    Build a list of solvers from the task configuration.
+
+    task_config example:
+
+    ```yaml
+    solvers:
+      - name: prompt_template
+        args:
+          template: >
+            You are a helpful assistant.
+            {prompt}
+      - name: generate
+        args:
+          cache: true
+    ```
+
+
+    """
     solvers = []
-    solver_names = task_config.get("solvers", [])
-
-    for solver_name in solver_names:
-        if solver_name == "prompt_template":
-            if "prompt_template" in task_config and task_config["prompt_template"]:
-                template = task_config["prompt_template"].strip().strip('"')
-                template = template.replace("{{prompt}}", "{prompt}")
-                solvers.append(prompt_template(template))
-        elif solver_name == "system_message":
-            if "system_message" in task_config and task_config["system_message"]:
-                sys_msg = task_config["system_message"].strip().strip('"')
-                solvers.append(system_message(sys_msg))
-        elif solver_name == "chain_of_thought":
-            solvers.append(chain_of_thought())
-        elif solver_name == "multiple_choice":
-            solvers.append(multiple_choice())
-        elif solver_name == "generate":
-            solvers.append(generate())
+    solver_configs = task_config.get("solvers", [])
+    solver_module = import_module("inspect_ai.solver")
 
-    return solvers
+    for solver_config in solver_configs:
+        solver_name = solver_config["name"]
 
+        if not hasattr(solver_module, solver_name):
+            raise ValueError(f"Unknown solver: {solver_name}")
 
-def build_scorer(task_config: dict):
-    """Build scorer from task configuration."""
-    scorer_name = task_config.get("scorers", ["choice"])[0]
-
-    if scorer_name == "choice":
-        return choice()
-    elif scorer_name == "exact":
-        return exact()
-    elif scorer_name == "match":
-        return match()
-    elif scorer_name == "model_graded_fact":
-        return model_graded_fact()
-    else:
-        raise ValueError(f"Unknown scorer: {scorer_name}")
+        solver_fn = getattr(solver_module, solver_name)
+        solvers.append(solver_fn(**solver_config.get("args", {})))
+
+    return solvers
 
 
-def create_task_from_config(
-    repo_id: str, revision: str = "main", task_config: dict = None, global_config: dict = None
-):
+def build_scorer(task_config: dict):
+    """
+    Build a scorer from the task configuration.
+    task_config example:
+
+    ```yaml
+    scorers:
+      - name: model_graded_fact
+        args:
+          template: |
+            grade this,
+
+            question:
+            {question}
+            criterion:
+            {criterion}
+            answer:
+            {answer}
+    ```
+    """
+    scorers = []
+    scorer_configs = task_config.get("scorers", [])
+    scorer_module = import_module("inspect_ai.scorer")
+
+    for scorer_config in scorer_configs:
+        scorer_name = scorer_config["name"]
+
+        if not hasattr(scorer_module, scorer_name):
+            raise ValueError(f"Unknown scorer: {scorer_name}")
+
+        scorer_fn = getattr(scorer_module, scorer_name)
+        scorers.append(scorer_fn(**scorer_config.get("args", {})))
+
+    return scorers
+
+
+@task
+def create_task_from_config(repo_id: str, revision: str = "main", task_config: dict = None):
     """Create an inspect.ai Task from a task configuration."""
-    dataset = load_dataset(repo_id, revision, task_config, global_config)
+    dataset = load_dataset(repo_id, revision, task_config)
     solvers = build_solvers(task_config)
-    scorer = build_scorer(task_config)
+    scorers = build_scorer(task_config)
     epochs = task_config.get("epochs", 1)
     epochs_reducer = task_config.get("epochs_reducer", "mean")
 
     return Task(
         dataset=dataset,
         solver=solvers,
-        scorer=scorer,
+        scorer=scorers,
         name=task_config["name"],
         epochs=Epochs(epochs, epochs_reducer),
     )
 
 
-def create_task_function(repo_id: str, revision: str = "main"):
+def create_task_function(repo_id: str, revision: str = "main") -> list:
     """Factory function to create a task function with proper closure."""
     # read yaml from hf filesystem
     yaml_path = hf_hub_download(repo_id=repo_id, filename="eval.yaml", repo_type="dataset", revision=revision)
 
     with open(yaml_path, "r") as f:
         global_config = yaml.safe_load(f)
 
-    task_config = global_config["tasks"][0]
+    task_configs = global_config["tasks"]
 
-    @task
-    def task_func():
-        return create_task_from_config(repo_id, revision, task_config, global_config)
+    tasks = []
+    for task_config in task_configs:
+        tasks.append(create_task_from_config(repo_id, revision, task_config))
 
-    return task_func
+    return tasks
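For reference, a minimal `eval.yaml` that this loader would accept might look like the sketch below. The repo and field names (`my_eval`, `question`, `answer`, `category`) are illustrative placeholders; the keys and their defaults (`subset: "default"`, `splits: "test"`, `metadata: []`, `epochs: 1`, `epochs_reducer: "mean"`) follow the code above, and solver/scorer names are resolved dynamically against `inspect_ai.solver` and `inspect_ai.scorer`.

```yaml
tasks:
  - name: my_eval                 # used as the inspect_ai Task name
    subset: default               # optional; defaults to "default"
    splits: test                  # optional; defaults to "test"
    field_spec:
      input: question             # dataset column mapped to the sample input
      target: answer              # dataset column mapped to the sample target
      metadata: [category]        # optional; defaults to []
    solvers:
      - name: prompt_template     # looked up on inspect_ai.solver
        args:
          template: >
            You are a helpful assistant.
            {prompt}
      - name: generate
        args:
          cache: true
    scorers:
      - name: model_graded_fact   # looked up on inspect_ai.scorer
    epochs: 1                     # optional; defaults to 1
    epochs_reducer: mean          # optional; defaults to "mean"
```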

src/lighteval/main_inspect.py
Lines changed: 3 additions & 3 deletions
@@ -522,10 +522,10 @@ def eval( # noqa C901
     print("run 'inspect view' to view the results")
 
 
-def from_hub(model: str, repo_id: str, limit: int = 100, revision: str = "main"):
+def from_hub(repo_id: str, models: list[str], limit: int = 100, revision: str = "main"):
     task = create_task_function(repo_id, revision)
-    model = "hf-inference-providers/meta-llama/Llama-3.1-8B-Instruct"
-    inspect_ai_eval(tasks=[task], model=model, limit=100)
+
+    inspect_ai_eval(tasks=task, model=models, limit=limit)
 
 
 def bundle(log_dir: str, output_dir: str, overwrite: bool = True, repo_id: str | None = None, public: bool = False):
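With the new signature, a call might look like the sketch below. The repo id is a placeholder, the model name is the one previously hard-coded in the removed line, and the import path is assumed from the file location `src/lighteval/main_inspect.py`.

```python
# Hypothetical usage sketch; the repo id is a placeholder and the
# import path assumes the package layout under src/lighteval/.
from lighteval.main_inspect import from_hub

from_hub(
    repo_id="your-org/your-eval",  # Hub dataset repo containing eval.yaml
    models=["hf-inference-providers/meta-llama/Llama-3.1-8B-Instruct"],
    limit=100,        # now forwarded instead of the hard-coded 100
    revision="main",  # dataset revision to resolve eval.yaml against
)
```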
