diff --git a/pyproject.toml b/pyproject.toml index 8a5a9cf90..322ed726b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,6 +107,7 @@ extended_tasks = [ ] s3 = ["s3fs"] multilingual = [ + "langcodes", "stanza", "spacy[ja,ko,th]", "jieba", # for chinese tokenizer diff --git a/src/lighteval/tasks/registry.py b/src/lighteval/tasks/registry.py index 01c43e942..ce2070d48 100644 --- a/src/lighteval/tasks/registry.py +++ b/src/lighteval/tasks/registry.py @@ -25,6 +25,7 @@ import importlib import logging import os +import sys from functools import lru_cache from itertools import groupby from pathlib import Path @@ -37,6 +38,38 @@ from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig from lighteval.utils.imports import CANNOT_USE_EXTENDED_TASKS_MSG, can_load_extended_tasks +# Import community tasks +AVAILABLE_COMMUNITY_TASKS_MODULES = [] +def load_community_tasks(): + """Dynamically load community tasks, handling errors gracefully.""" + modules = [] + try: + # Community tasks are in the lighteval directory, not under src + community_path = Path(__file__).parent.parent.parent / "community_tasks" + if not community_path.exists(): + return modules + + # Ensure the parent directory is on sys.path so we can import `community_tasks.*` + parent_dir = str(community_path.parent) + if parent_dir not in sys.path: + sys.path.insert(0, parent_dir) + + # List all python files in community_tasks + community_files = [p.stem for p in community_path.glob("*.py") if not p.name.startswith('_')] + + for module_name in community_files: + try: + module = importlib.import_module(f"community_tasks.{module_name}") + if hasattr(module, 'TASKS_TABLE'): + modules.append(module) + logger.info(f"Successfully loaded community tasks from {module_name}") + except Exception as e: + logger.warning(f"Failed to load community tasks from {module_name}: {e}") + except Exception as e: + logger.warning(f"Error loading community tasks directory: {e}") + + return modules + logger = 
logging.getLogger(__name__) @@ -137,6 +170,23 @@ def task_registry(self) -> dict[str, LightevalTaskConfig]: else: logger.warning(CANNOT_USE_EXTENDED_TASKS_MSG) + # Load community tasks + community_modules = load_community_tasks() + for community_task_module in community_modules: + custom_tasks_module.append(community_task_module) + + # Load multilingual tasks + MULTILINGUAL_TASKS_AVAILABLE = False + multilingual_tasks = None + try: + import lighteval.tasks.multilingual.tasks as multilingual_tasks + MULTILINGUAL_TASKS_AVAILABLE = True + except ImportError as e: + logger.warning(f"Could not load multilingual tasks: {e}. You may need to install additional dependencies.") + + if MULTILINGUAL_TASKS_AVAILABLE and multilingual_tasks is not None: + custom_tasks_module.append(multilingual_tasks) + for module in custom_tasks_module: custom_task_configs.extend(module.TASKS_TABLE) logger.info(f"Found {len(module.TASKS_TABLE)} custom tasks in {module.__file__}") @@ -171,7 +221,9 @@ def taskinfo_selector(self, tasks: str) -> dict[str, list[dict]]: Returns: - dict[str, list[dict]]: A dictionary mapping each task name to a list of tuples representing the few_shot and truncate_few_shots values. + tuple[list[str], dict[str, list[tuple[int, bool]]]]: A tuple containing: + - A sorted list of unique task names in the format "suite|task". + - A dictionary mapping each task name to a list of tuples representing the few_shot and truncate_few_shots values. """ few_shot_dict = collections.defaultdict(list) @@ -283,13 +335,26 @@ def print_all_tasks(self): Print all the tasks in the task registry. 
def print_all_tasks(self):
    """
    Print every task of the registry, grouped under its suite.

    Each suite listed in DEFAULT_SUITES gets a heading even when the registry
    holds no task for it; such a suite shows a "(no tasks in this suite)" line.
    """
    registered = list(self.task_registry.keys())
    known_suites = {entry.split("|")[0] for entry in registered}

    # A placeholder "suite|" entry (empty task part) forces a heading for every
    # default suite that has no registered task.
    placeholders = [f"{suite}|" for suite in DEFAULT_SUITES if suite not in known_suites]
    ordered = sorted(registered + placeholders)

    def suite_of(entry):
        return entry.split("|")[0]

    # groupby only merges adjacent items, hence the sort above.
    for suite, members in groupby(ordered, suite_of):
        # Placeholders have an empty task part and are dropped here.
        real_tasks = sorted(entry for entry in members if entry.split("|")[1])

        print(f"\n- {suite}:")
        if real_tasks:
            for entry in real_tasks:
                print(f" - {entry}")
            continue
        print(" (no tasks in this suite)")