Skip to content

Commit a0248bb

Browse files
Update registry.py
update to Nathan's suggestions
1 parent 82a3c25 commit a0248bb

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

src/lighteval/tasks/registry.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import importlib
2626
import logging
2727
import os
28+
import sys
2829
from functools import lru_cache
2930
from itertools import groupby
3031
from pathlib import Path
@@ -44,24 +45,26 @@ def load_community_tasks():
4445
modules = []
4546
try:
4647
# Community tasks are in the lighteval directory, not under src
47-
import sys
48-
import os
49-
community_path = os.path.join(os.path.dirname(__file__), "..", "..", "..", "community_tasks")
50-
if os.path.exists(community_path):
51-
sys.path.insert(0, os.path.dirname(community_path))
52-
53-
# List all python files in community_tasks
54-
community_files = [f[:-3] for f in os.listdir(community_path)
55-
if f.endswith('.py') and not f.startswith('_')]
56-
57-
for module_name in community_files:
58-
try:
59-
module = importlib.import_module(f"community_tasks.{module_name}")
60-
if hasattr(module, 'TASKS_TABLE'):
61-
modules.append(module)
62-
logger.info(f"Successfully loaded community tasks from {module_name}")
63-
except Exception as e:
64-
logger.warning(f"Failed to load community tasks from {module_name}: {e}")
48+
community_path = Path(__file__).parent.parent.parent / "community_tasks"
49+
if not community_path.exists():
50+
return modules
51+
52+
# Ensure the parent directory is on sys.path so we can import `community_tasks.*`
53+
parent_dir = str(community_path.parent)
54+
if parent_dir not in sys.path:
55+
sys.path.insert(0, parent_dir)
56+
57+
# List all python files in community_tasks
58+
community_files = [p.stem for p in community_path.glob("*.py") if not p.name.startswith('_')]
59+
60+
for module_name in community_files:
61+
try:
62+
module = importlib.import_module(f"community_tasks.{module_name}")
63+
if hasattr(module, 'TASKS_TABLE'):
64+
modules.append(module)
65+
logger.info(f"Successfully loaded community tasks from {module_name}")
66+
except Exception as e:
67+
logger.warning(f"Failed to load community tasks from {module_name}: {e}")
6568
except Exception as e:
6669
logger.warning(f"Error loading community tasks directory: {e}")
6770

@@ -218,7 +221,9 @@ def taskinfo_selector(self, tasks: str) -> dict[str, list[dict]]:
218221
219222
220223
Returns:
221-
dict[str, list[dict]]: A dictionary mapping each task name to a list of tuples representing the few_shot and truncate_few_shots values.
224+
tuple[list[str], dict[str, list[tuple[int, bool]]]]: A tuple containing:
225+
- A sorted list of unique task names in the format "suite|task".
226+
- A dictionary mapping each task name to a list of tuples representing the few_shot and truncate_few_shots values.
222227
"""
223228
few_shot_dict = collections.defaultdict(list)
224229

0 commit comments

Comments
 (0)