|
28 | 28 | import requests |
29 | 29 | from huggingface_hub import HfApi |
30 | 30 | from inspect_ai import Epochs, Task, task |
31 | | -from inspect_ai import eval as inspect_ai_eval |
32 | 31 | from inspect_ai import eval_set as inspect_ai_eval_set |
33 | 32 | from inspect_ai.dataset import hf_dataset |
34 | 33 | from inspect_ai.log import bundle_log_dir |
@@ -215,6 +214,7 @@ def eval( # noqa C901 |
215 | 214 | models: Annotated[list[str], Argument(help="Models to evaluate")], |
216 | 215 | tasks: Annotated[str, Argument(help="Tasks to evaluate")], |
217 | 216 | # model arguments |
| 217 | + revision: Annotated[str, Option(help="Revision of the benchmark repo on the hub")] = "main", |
218 | 218 | model_base_url: Annotated[ |
219 | 219 | str | None, |
220 | 220 | Option( |
@@ -430,15 +430,23 @@ def eval( # noqa C901 |
430 | 430 | ), |
431 | 431 | ] = False, |
432 | 432 | ): |
| 433 | + from huggingface_hub import HfApi |
| 434 | + |
433 | 435 | from lighteval.tasks.registry import Registry |
434 | 436 |
|
435 | | - registry = Registry(tasks=tasks, custom_tasks=None, load_multilingual=False) |
436 | | - task_configs = registry.task_to_configs |
437 | | - inspect_ai_tasks = [] |
| 437 | + if "/" in tasks: |
| 438 | + api = HfApi() |
| 439 | + print(f"Loading tasks from dataset repository {tasks}...") |
| 440 | + api.repo_info(repo_id=tasks, repo_type="dataset", revision=revision) |
| 441 | + inspect_ai_tasks = create_task_function(tasks, revision) |
| 442 | + else: |
| 443 | + registry = Registry(tasks=tasks, custom_tasks=None, load_multilingual=False) |
| 444 | + task_configs = registry.task_to_configs |
| 445 | + inspect_ai_tasks = [] |
438 | 446 |
|
439 | | - for task_name, task_configs in task_configs.items(): |
440 | | - for task_config in task_configs: |
441 | | - inspect_ai_tasks.append(get_inspect_ai_task(task_config, epochs=epochs, epochs_reducer=epochs_reducer)) |
| 447 | + for task_name, task_configs in task_configs.items(): |
| 448 | + for task_config in task_configs: |
| 449 | + inspect_ai_tasks.append(get_inspect_ai_task(task_config, epochs=epochs, epochs_reducer=epochs_reducer)) |
442 | 450 |
|
443 | 451 | if model_args is not None: |
444 | 452 | model_args = InspectAIModelConfig._parse_args(model_args) |
@@ -522,12 +530,6 @@ def eval( # noqa C901 |
522 | 530 | print("run 'inspect view' to view the results") |
523 | 531 |
|
524 | 532 |
|
525 | | -def from_hub(repo_id: str, models: list[str], limit: int = 100, revision: str = "main"): |
526 | | - task = create_task_function(repo_id, revision) |
527 | | - |
528 | | - inspect_ai_eval(tasks=task, model=models, limit=limit) |
529 | | - |
530 | | - |
531 | 533 | def bundle(log_dir: str, output_dir: str, overwrite: bool = True, repo_id: str | None = None, public: bool = False): |
532 | 534 | bundle_log_dir(log_dir=log_dir, output_dir=output_dir, overwrite=overwrite) |
533 | 535 |
|
|
0 commit comments