Skip to content

Commit 273d421

Browse files
committed
use an eval.yaml from the hub
1 parent c0e6004 commit 273d421

File tree

3 files changed

+27
-16
lines changed

3 files changed

+27
-16
lines changed

src/lighteval/__main__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171
app.command(rich_help_panel="Evaluation Backends")(lighteval.main_custom.custom)
7272
app.command(rich_help_panel="Evaluation Backends")(lighteval.main_sglang.sglang)
7373
app.command(rich_help_panel="Evaluation Backends")(lighteval.main_inspect.eval)
74-
app.command(rich_help_panel="Evaluation Backends")(lighteval.main_inspect.from_hub)
7574
app.command(rich_help_panel="EvaluationUtils")(lighteval.main_inspect.bundle)
7675
app.add_typer(
7776
lighteval.main_endpoint.app,

src/lighteval/from_hub.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,22 @@ def record_to_sample(record, field_spec: dict):
1414
"""
1515
input_text = record[field_spec["input"]]
1616

17-
target = record[field_spec["target"]]
17+
target = field_spec["target"]
18+
19+
if target in ascii_uppercase:
20+
target = target
21+
else:
22+
target = record[field_spec["target"]]
1823

1924
if isinstance(target, int):
2025
target = ascii_uppercase[target]
2126

22-
choices_list = record[field_spec["choices"]]
27+
choices = field_spec["choices"]
28+
29+
if isinstance(choices, list):
30+
choices_list = [record[choice] for choice in choices]
31+
else:
32+
choices_list = record[choices]
2333

2434
metadata = field_spec.get("metadata", None)
2535

src/lighteval/main_inspect.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import requests
2929
from huggingface_hub import HfApi
3030
from inspect_ai import Epochs, Task, task
31-
from inspect_ai import eval as inspect_ai_eval
3231
from inspect_ai import eval_set as inspect_ai_eval_set
3332
from inspect_ai.dataset import hf_dataset
3433
from inspect_ai.log import bundle_log_dir
@@ -215,6 +214,7 @@ def eval( # noqa C901
215214
models: Annotated[list[str], Argument(help="Models to evaluate")],
216215
tasks: Annotated[str, Argument(help="Tasks to evaluate")],
217216
# model arguments
217+
revision: Annotated[str, Option(help="Revision of the benchmark repo on the hub")] = "main",
218218
model_base_url: Annotated[
219219
str | None,
220220
Option(
@@ -430,15 +430,23 @@ def eval( # noqa C901
430430
),
431431
] = False,
432432
):
433+
from huggingface_hub import HfApi
434+
433435
from lighteval.tasks.registry import Registry
434436

435-
registry = Registry(tasks=tasks, custom_tasks=None, load_multilingual=False)
436-
task_configs = registry.task_to_configs
437-
inspect_ai_tasks = []
437+
if "/" in tasks:
438+
api = HfApi()
439+
print(f"Loading tasks from dataset repository {tasks}...")
440+
api.repo_info(repo_id=tasks, repo_type="dataset", revision=revision)
441+
inspect_ai_tasks = create_task_function(tasks, revision)
442+
else:
443+
registry = Registry(tasks=tasks, custom_tasks=None, load_multilingual=False)
444+
task_configs = registry.task_to_configs
445+
inspect_ai_tasks = []
438446

439-
for task_name, task_configs in task_configs.items():
440-
for task_config in task_configs:
441-
inspect_ai_tasks.append(get_inspect_ai_task(task_config, epochs=epochs, epochs_reducer=epochs_reducer))
447+
for task_name, task_configs in task_configs.items():
448+
for task_config in task_configs:
449+
inspect_ai_tasks.append(get_inspect_ai_task(task_config, epochs=epochs, epochs_reducer=epochs_reducer))
442450

443451
if model_args is not None:
444452
model_args = InspectAIModelConfig._parse_args(model_args)
@@ -522,12 +530,6 @@ def eval( # noqa C901
522530
print("run 'inspect view' to view the results")
523531

524532

525-
def from_hub(repo_id: str, models: list[str], limit: int = 100, revision: str = "main"):
526-
task = create_task_function(repo_id, revision)
527-
528-
inspect_ai_eval(tasks=task, model=models, limit=limit)
529-
530-
531533
def bundle(log_dir: str, output_dir: str, overwrite: bool = True, repo_id: str | None = None, public: bool = False):
532534
bundle_log_dir(log_dir=log_dir, output_dir=output_dir, overwrite=overwrite)
533535

0 commit comments

Comments
 (0)