diff --git a/lm_eval/tasks/__init__.py b/lm_eval/tasks/__init__.py index f9251386078..71d65df1200 100644 --- a/lm_eval/tasks/__init__.py +++ b/lm_eval/tasks/__init__.py @@ -1,555 +1,27 @@ -import collections -import inspect -import logging -import os -from collections.abc import Mapping -from functools import partial -from pathlib import Path +"""Task management for lm-evaluation-harness. + +This module provides: +- TaskManager: Main class for discovering and loading evaluation tasks +- get_task_dict: Function to create a dictionary of task objects +- Helper functions for task name resolution +""" + from typing import Dict, List, Optional, Union -from lm_eval import utils -from lm_eval.api.group import ConfigurableGroup, GroupConfig from lm_eval.api.task import ConfigurableTask, Task from lm_eval.evaluator_utils import get_subtask_list +# Import TaskManager from the refactored module +from lm_eval.tasks.manager import TaskManager -GROUP_ONLY_KEYS = list(GroupConfig().to_dict().keys()) - -eval_logger = logging.getLogger(__name__) - - -class TaskManager: - """TaskManager indexes all tasks from the default `lm_eval/tasks/` - and an optional directory if provided. - - """ - - def __init__( - self, - verbosity: str | None = None, - include_path: str | list | None = None, - include_defaults: bool = True, - metadata: dict | None = None, - ) -> None: - if verbosity is not None: - utils.setup_logging(verbosity) - self.include_path = include_path - self.metadata = metadata - self._task_index = self.initialize_tasks( - include_path=include_path, include_defaults=include_defaults - ) - self._all_tasks = sorted(list(self._task_index.keys())) - - self._all_groups = sorted( - [x for x in self._all_tasks if self._task_index[x]["type"] == "group"] - ) - self._all_subtasks = sorted( - [ - x - for x in self._all_tasks - if self._task_index[x]["type"] in ["task", "python_task"] - ] - ) - self._all_tags = sorted( - [x for x in self._all_tasks if self._task_index[x]["type"] == "tag"] - ) - - self.task_group_map = collections.defaultdict(list) - - def initialize_tasks( - self, - include_path: str | list | None = None, - include_defaults: bool = True, - ) -> dict[str, dict]: - """Creates a dictionary of tasks indexes. - - :param include_path: Union[str, List] = None - An additional path to be searched for tasks recursively. - Can provide more than one such path as a list. - :param include_defaults: bool = True - If set to false, default tasks (those in lm_eval/tasks/) are not indexed. - return - Dictionary of task names as key and task metadata - """ - if include_defaults: - all_paths = [os.path.dirname(os.path.abspath(__file__)) + "/"] - else: - all_paths = [] - if include_path is not None: - if isinstance(include_path, str): - include_path = [include_path] - all_paths.extend(include_path) - - task_index = {} - for task_dir in all_paths: - tasks = self._get_task_and_group(task_dir) - task_index = {**task_index, **tasks} - - return task_index - - @property - def all_tasks(self): - return self._all_tasks - - @property - def all_groups(self): - return self._all_groups - - @property - def all_subtasks(self): - return self._all_subtasks - - @property - def all_tags(self): - return self._all_tags - - @property - def task_index(self): - return self._task_index - - def list_all_tasks( - self, list_groups=True, list_tags=True, list_subtasks=True - ) -> str: - from pytablewriter import MarkdownTableWriter - - def sanitize_path(path): - # don't print full path if we are within the lm_eval/tasks dir ! 
- # if we aren't though, provide the full path. - if "lm_eval/tasks/" in path: - return "lm_eval/tasks/" + path.split("lm_eval/tasks/")[-1] - else: - return path - - group_table = MarkdownTableWriter() - group_table.headers = ["Group", "Config Location"] - gt_values = [] - for g in self.all_groups: - path = self.task_index[g]["yaml_path"] - if path == -1: - path = "---" - else: - path = sanitize_path(path) - gt_values.append([g, path]) - group_table.value_matrix = gt_values - - tag_table = MarkdownTableWriter() - tag_table.headers = ["Tag"] - tag_table.value_matrix = [[t] for t in self.all_tags] - - subtask_table = MarkdownTableWriter() - subtask_table.headers = ["Task", "Config Location", "Output Type"] - st_values = [] - for t in self.all_subtasks: - path = self.task_index[t]["yaml_path"] - - output_type = "" - - # read the yaml file to determine the output type - if path != -1: - config = utils.load_yaml_config(path, mode="simple") - if "output_type" in config: - output_type = config["output_type"] - elif ( - "include" in config - ): # if no output type, check if there is an include with an output type - include_path = path.split("/")[:-1] + config["include"] - include_config = utils.load_yaml_config(include_path, mode="simple") - if "output_type" in include_config: - output_type = include_config["output_type"] - - if path == -1: - path = "---" - else: - path = sanitize_path(path) - st_values.append([t, path, output_type]) - subtask_table.value_matrix = st_values - - result = "\n" - if list_groups: - result += group_table.dumps() + "\n\n" - if list_tags: - result += tag_table.dumps() + "\n\n" - if list_subtasks: - result += subtask_table.dumps() + "\n\n" - return result - - def match_tasks(self, task_list: list[str]) -> list[str]: - return utils.pattern_match(task_list, self.all_tasks) - - def _name_is_registered(self, name: str) -> bool: - if name in self.all_tasks: - return True - return False - - def _name_is_task(self, name: str) -> bool: - if self._name_is_registered(name) and (self.task_index[name]["type"] == "task"): - return True - return False - - def _name_is_tag(self, name: str) -> bool: - if self._name_is_registered(name) and (self.task_index[name]["type"] == "tag"): - return True - return False - - def _name_is_group(self, name: str) -> bool: - if self._name_is_registered(name) and ( - self.task_index[name]["type"] == "group" - ): - return True - return False - - def _name_is_python_task(self, name: str) -> bool: - if self._name_is_registered(name) and ( - self.task_index[name]["type"] == "python_task" - ): - return True - return False - - def _config_is_task(self, config: dict) -> bool: - if ("task" in config) and isinstance(config["task"], str): - return True - return False - - def _config_is_group(self, config: dict) -> bool: - if ("task" in config) and isinstance(config["task"], list): - return True - return False - - def _config_is_python_task(self, config: dict) -> bool: - if "class" in config: - return True - return False - - def _get_yaml_path(self, name: str): - if name not in self.task_index: - raise ValueError - return self.task_index[name]["yaml_path"] - - def _get_config(self, name): - if name not in self.task_index: - raise ValueError - yaml_path = self._get_yaml_path(name) - if yaml_path == -1: - return {} - else: - return utils.load_yaml_config(yaml_path, mode="full") - - def _get_tasklist(self, name): - if self._name_is_task(name): - raise ValueError - return self.task_index[name]["task"] - - def _process_alias(self, config, group=None): - # If the group is 
not the same as the original - # group which the group alias was intended for, - # Set the group_alias to None instead. - if ("group_alias" in config) and ("group" in config) and group is not None: - if config["group"] != group: - config["group_alias"] = None - return config - - def _class_has_config_in_constructor(self, cls): - constructor = getattr(cls, "__init__", None) - return ( - "config" in inspect.signature(constructor).parameters - if constructor - else False - ) - def _load_individual_task_or_group( - self, - name_or_config: str | dict | None = None, - parent_name: str | None = None, - update_config: dict | None = None, - ) -> Mapping: - def _load_task(config, task): - if "include" in config: - config = { - **utils.load_yaml_config( - yaml_path=None, - yaml_config={"include": config.pop("include")}, - mode="full", - ), - **config, - } - if self._config_is_python_task(config): - if self._class_has_config_in_constructor(config["class"]): - task_object = config["class"](config=config) - else: - task_object = config["class"]() - if isinstance(task_object, ConfigurableTask): - # very scuffed: set task name here. TODO: fixme? - task_object.config.task = task - else: - if self.metadata is not None: - config["metadata"] = config.get("metadata", {}) | self.metadata - else: - config["metadata"] = config.get("metadata", {}) - task_object = ConfigurableTask(config=config) - - return {task: task_object} - - def _get_group_and_subtask_from_config( - config: dict, - ) -> tuple[ConfigurableGroup, list[str]]: - if self.metadata is not None: - config["metadata"] = config.get("metadata", {}) | self.metadata - group_name = ConfigurableGroup(config=config) - subtask_list = [] - for task in group_name.config["task"]: - if isinstance(task, str) and self._name_is_tag(task): - subtask_list.extend(self._get_tasklist(task)) - else: - subtask_list.append(task) - return group_name, subtask_list - - def _process_group_config( - config: dict, update_config: dict = None - ) -> tuple[dict, dict]: - if update_config is not None: - config = {**config, **update_config} - _update_config = { - k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS - } - if not bool(_update_config): - _update_config = None - - group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS} - return group_config, _update_config - - if isinstance(name_or_config, str): - if update_config is not None: - # Process name_or_config as a dict instead - name_or_config = {"task": name_or_config, **update_config} - elif self._name_is_task(name_or_config) or self._name_is_python_task( - name_or_config - ): - task_config = self._get_config(name_or_config) - return _load_task(task_config, task=name_or_config) - else: - subtask_list = self._get_tasklist(name_or_config) - if subtask_list == -1: - group_config = self._get_config(name_or_config) - group_config, update_config = _process_group_config(group_config) - group_name, subtask_list = _get_group_and_subtask_from_config( - group_config - ) - else: - if self._name_is_tag(name_or_config): - fn = partial( - self._load_individual_task_or_group, - update_config=name_or_config - if isinstance(name_or_config, dict) - else None, - ) - return dict( - collections.ChainMap(*map(fn, reversed(subtask_list))) - ) - else: - group_name = ConfigurableGroup( - config={"group": name_or_config, "task": subtask_list} - ) - - if isinstance(name_or_config, dict): - if self._config_is_task(name_or_config): - name = name_or_config.pop("task") - if update_config is not None: - name_or_config = 
{**name_or_config, **update_config} - # If the name is registered as a group - if self._name_is_group(name): - group_config = self._get_config(name) - - group_config, update_config = _process_group_config( - group_config, name_or_config - ) - group_name, subtask_list = _get_group_and_subtask_from_config( - group_config - ) - elif self._name_is_tag(name): - subtask_list = self._get_tasklist(name) - fn = partial( - self._load_individual_task_or_group, - update_config=name_or_config, - ) - return dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) - else: - if self._name_is_registered(name): - base_task_config = self._get_config(name) - - # Check if this is a duplicate. - if parent_name is not None: - num_duplicate = len( - list( - filter( - lambda x: x.startswith(name), - self.task_group_map[parent_name], - ) - ) - ) - if num_duplicate > 0: - name = f"{name}-{num_duplicate}" - self.task_group_map[parent_name].append(name) - - task_config = { - **base_task_config, - **name_or_config, - } - else: - task_config = name_or_config - return _load_task(task_config, task=name) - else: - group_config, update_config = _process_group_config(name_or_config) - group_name, subtask_list = _get_group_and_subtask_from_config( - group_config - ) - - fn = partial( - self._load_individual_task_or_group, - parent_name=group_name, - update_config=update_config, - ) - return { - group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list)))) - } - - def load_task_or_group(self, task_list: str | list | None = None) -> dict: - """Loads a dictionary of task objects from a list - - :param task_list: Union[str, list] = None - Single string or list of string of task names to be loaded - - :return - Dictionary of task objects - """ - if isinstance(task_list, str): - task_list = [task_list] - - all_loaded_tasks = dict( - collections.ChainMap( - *map( - lambda task: self._load_individual_task_or_group(task), - task_list, - ) - ) - ) - return all_loaded_tasks - - def load_config(self, config: dict): - return self._load_individual_task_or_group(config) - - def _get_task_and_group(self, task_dir: str): - """Creates a dictionary of tasks index with the following metadata, - - `type`, that can be either `task`, `python_task`, `group` or `tags`. - `task` refer to regular task configs, `python_task` are special - yaml files that only consists of `task` and `class` parameters. - `group` are group configs. `tags` are labels that can be assigned - to tasks to assist in sorting and calling tasks of certain themes. - - `yaml_path`, path to the yaml file. If the entry is a `group` that - was configured through a task config, the yaml_path will be -1 - and all subtasks will be listed in `task` (see below) - - `task`, reserved for entries with `type` as `group`. This will list - all subtasks. When a group config is created (as opposed to task - config having `group` parameter set), this will be set to -1 to - avoid recursive indexing. The whole list of subtasks will be loaded - at evaluation. 
- - :param task_dir: str - A directory to check for tasks - - :return - Dictionary of task names as key and task metadata - """ - - def _populate_tags_and_groups(config, task, tasks_and_groups, print_info): - # TODO: remove group in next release - if "tag" in config: - attr_list = config["tag"] - if isinstance(attr_list, str): - attr_list = [attr_list] - - for tag in attr_list: - if tag not in tasks_and_groups: - tasks_and_groups[tag] = { - "type": "tag", - "task": [task], - "yaml_path": -1, - } - elif tasks_and_groups[tag]["type"] != "tag": - eval_logger.info( - f"The tag '{tag}' is already registered as a group, this tag will not be registered. " - "This may affect tasks you want to call." - ) - break - else: - tasks_and_groups[tag]["task"].append(task) - - # TODO: remove group in next release - print_info = True - ignore_dirs = [ - "__pycache__", - ".ipynb_checkpoints", - ] - tasks_and_groups = collections.defaultdict() - for root, dirs, file_list in os.walk(task_dir): - dirs[:] = [d for d in dirs if d not in ignore_dirs] - dirs.sort() # Sort directories for deterministic traversal order - file_list.sort() # Sort files for consistent processing order - for f in file_list: - if f.endswith(".yaml"): - yaml_path = os.path.join(root, f) - config = utils.load_yaml_config(yaml_path, mode="simple") - if self._config_is_python_task(config): - # This is a python class config - task = config["task"] - tasks_and_groups[task] = { - "type": "python_task", - "yaml_path": yaml_path, - } - _populate_tags_and_groups( - config, task, tasks_and_groups, print_info - ) - elif self._config_is_group(config): - # This is a group config - tasks_and_groups[config["group"]] = { - "type": "group", - "task": -1, # This signals that - # we don't need to know - # the task list for indexing - # as it can be loaded - # when called. - "yaml_path": yaml_path, - } - - # # Registered the level 1 tasks from a group config - # for config in config["task"]: - # if isinstance(config, dict) and self._config_is_task(config): - # task = config["task"] - # tasks_and_groups[task] = { - # "type": "task", - # "yaml_path": yaml_path, - # } - - elif self._config_is_task(config): - # This is a task config - task = config["task"] - if task in tasks_and_groups: - eval_logger.warning( - f"Duplicate task name '{task}' found. " - f"Already registered from: {tasks_and_groups[task]['yaml_path']}. 
" - f"Skipping duplicate from: {yaml_path}" - ) - continue - tasks_and_groups[task] = { - "type": "task", - "yaml_path": yaml_path, - } - _populate_tags_and_groups( - config, task, tasks_and_groups, print_info - ) - else: - eval_logger.debug(f"File {f} in {root} could not be loaded") - - return tasks_and_groups +__all__ = [ + "ConfigurableTask", + "TaskManager", + "get_task_dict", + "get_task_name_from_config", + "get_task_name_from_object", +] def get_task_name_from_config(task_config: dict[str, str]) -> str: diff --git a/lm_eval/tasks/_config_loader.py b/lm_eval/tasks/_config_loader.py new file mode 100644 index 00000000000..ee0335b3500 --- /dev/null +++ b/lm_eval/tasks/_config_loader.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from typing import Any + +import yaml + + +_Base = ( + yaml.CSafeLoader if getattr(yaml, "__with_libyaml__", False) else yaml.FullLoader +) +_IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"} + + +def _mk_function_ctor(base_dir: Path, resolve: bool): + def ctor(loader: yaml.Loader, node: yaml.Node): + spec = loader.construct_scalar(node) # type: ignore[arg-type] + if not resolve: + return str(base_dir.expanduser() / spec) + return _import_func_in_yml(spec, base_dir) + + return ctor + + +def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]: + class Loader(_Base): ... # type: ignore[no-redef] + + yaml.add_constructor( + "!function", + _mk_function_ctor(base_dir, resolve_funcs), + Loader=Loader, + ) + return Loader + + +def _load_module_with_cache(module_path: Path) -> Any: + """Load a module from a file path with caching and hot-reload support. + + Args: + module_path: Path to the Python file to load + + Returns: + The loaded module + """ + # Determine module name based on location + path_str = str(module_path) + + # Check if this is a built-in task module + if "/lm_eval/tasks/" in path_str: + # Find the position of lm_eval/tasks/ in the path + tasks_idx = path_str.find("/lm_eval/tasks/") + if tasks_idx != -1: + # Extract path starting from lm_eval/tasks/ + # e.g., /path/to/lm_eval/tasks/hellaswag/utils.py → hellaswag/utils.py + relative_path = path_str[tasks_idx + len("/lm_eval/tasks/") :] + # Remove .py and convert to module name + # e.g., hellaswag/utils.py → lm_eval.tasks.hellaswag.utils + module_parts = relative_path.replace(".py", "").replace("/", ".") + module_name = f"lm_eval.tasks.{module_parts}" + else: + # Fallback to a full path if a pattern not found + module_name = str(module_path.with_suffix("")) + else: + # External module - use a full path without extension + module_name = str(module_path.with_suffix("")) + + # Check if we need to reload the module + if module_name in sys.modules: + existing_module = sys.modules[module_name] + # Check if it was modified + current_mtime = module_path.stat().st_mtime_ns + if ( + hasattr(existing_module, "__mtime__") + and existing_module.__mtime__ == current_mtime + ): + # Module hasn't changed, reuse it + return existing_module + + # Load or reload the module + spec = importlib.util.spec_from_file_location(module_name, module_path) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot load module from {module_path}") from None + module = importlib.util.module_from_spec(spec) + # Store mtime for future checks + module.__mtime__ = module_path.stat().st_mtime_ns # type: ignore + spec.loader.exec_module(module) # type: ignore[arg-type] + sys.modules[module_name] = module + return module + + +def 
_import_func_in_yml(qual: str, base_dir: Path):
+    """Import a function from a qualified name (e.g. 'utils.process_doc'),
+    checking files next to the YAML first, then falling back to standard imports.
+
+    Args:
+        qual: Qualified function name (e.g., 'utils.process_doc')
+        base_dir: Directory to search for local modules
+    """
+    mod_path, _, fn_name = qual.rpartition(".")
+    # 1) relative "utils.py" next to YAML
+    rel = (base_dir / f"{mod_path.replace('.', '/')}.py").resolve()
+    if rel.exists():
+        module = _load_module_with_cache(rel)
+        return getattr(module, fn_name)
+
+    # 2) already-importable module
+    module = __import__(mod_path, fromlist=[fn_name])
+    return getattr(module, fn_name)
+
+
+def _import_fun_from_str(path_str: str) -> Any:
+    """Import a function from a string in the form '/absolute/path/to/module.function_name'."""
+    try:
+        # Split off the function name from the rightmost dot
+        module_path_str, function_name = path_str.rsplit(".", 1)
+    except ValueError as e:
+        raise ValueError(
+            f"Invalid path format: {path_str}. Expected format: /path/to/module.function_name"
+        ) from e
+
+    # Convert to Path and handle .py extension
+    module_path = Path(module_path_str)
+    if not module_path.suffix:
+        module_path = module_path.with_suffix(".py")
+    elif module_path.suffix != ".py":
+        # If it has a non-.py suffix, the user may have included .py in the path,
+        # e.g., "/path/to/module.py.function_name"
+        base_path = module_path.with_suffix("")
+        if base_path.with_suffix(".py").exists():
+            module_path = base_path.with_suffix(".py")
+
+    if not module_path.exists():
+        raise ImportError(f"Module file not found: {module_path}")
+
+    module = _load_module_with_cache(module_path)
+
+    if not hasattr(module, function_name):
+        raise AttributeError(
+            f"Function '{function_name}' not found in module {module_path}"
+        )
+
+    return getattr(module, function_name)
+
+
+def load_yaml(
+    path: str | Path,
+    *,
+    resolve_func: bool = True,
+    recursive: bool = True,
+    _seen: set[Path] | None = None,
+) -> dict[str, Any]:
+    """Pure data-loading helper.
+
+    Returns a dict ready for higher-level interpretation.
+    No task/group/tag semantics here.
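+
+    A minimal usage sketch (file names are illustrative only):
+
+        # child.yaml contains "include: base.yaml" plus its own keys
+        cfg = load_yaml("child.yaml")
+        # -> base.yaml's keys merged with child.yaml's; local keys win,
+        #    and nested includes are resolved recursively.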
+ """ + path = Path(path).expanduser().resolve() + if _seen is None: + _seen = set() + if path in _seen: + raise ValueError(f"Include cycle at {path}") + _seen.add(path) + + loader_cls = _make_loader(path.parent, resolve_funcs=resolve_func) + with path.open("rb") as fh: + cfg = yaml.load(fh, Loader=loader_cls) + + if not recursive or "include" not in cfg: + return cfg + else: + includes = cfg.pop("include") + + merged = {} + for inc in includes if isinstance(includes, list) else [includes]: + inc_path = (path.parent / inc) if not Path(inc).is_absolute() else Path(inc) + inc_cfg = load_yaml( + inc_path, + resolve_func=resolve_func, + recursive=True, + _seen=_seen, + ) + # Don't inherit task_list - it defines tasks for the included file only + inc_cfg.pop("task_list", None) + merged.update(inc_cfg) + merged.update(cfg) # local keys win + return merged diff --git a/lm_eval/tasks/factory.py b/lm_eval/tasks/factory.py new file mode 100644 index 00000000000..4851cfe2610 --- /dev/null +++ b/lm_eval/tasks/factory.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +import inspect +from collections.abc import Mapping +from copy import deepcopy +from typing import Any + +from lm_eval.api.group import ConfigurableGroup, GroupConfig +from lm_eval.api.task import ConfigurableTask +from lm_eval.tasks._config_loader import load_yaml +from lm_eval.tasks.index import Entry, Kind + + +class TaskFactory: + """ + Turns a *Entry* (plus optional overrides) into a + *Task* | *ConfigurableTask* | *GroupConfig* hierarchy. + """ + + def __init__(self, *, meta: dict[str, Any] | None = None): + self._meta = meta or {} + + # ---------------------------------------------------------------- public API + def build( + self, + entry: Entry, + *, + overrides: dict[str, Any] | None = None, + registry: Mapping[str, Entry], + ): + """ + * entry.kind == TASK / PY_TASK -> returns instantiated task object + * entry.kind == GROUP -> returns (GroupConfig, mapping-of-subtasks) + * entry.kind == TAG -> returns mapping-of-tasks (tag expansion) + * entry with ref_target -> resolves reference and builds target + * entry with tag_ref -> expands tag and builds tasks + """ + # Handle external references (ref: in children) + if entry.ref_target: + if entry.ref_target not in registry: + raise KeyError( + f"Reference '{entry.ref_target}' not found for '{entry.name}'" + ) + target_entry = registry[entry.ref_target] + return self.build(target_entry, overrides=overrides, registry=registry) + + # Handle tag expansion (tag: in children) + if entry.tag_ref: + if entry.tag_ref not in registry: + raise KeyError(f"Tag '{entry.tag_ref}' not found for '{entry.name}'") + tag_entry = registry[entry.tag_ref] + return self._build_tag(tag_entry, overrides, registry) + + if entry.kind is Kind.TAG: + return self._build_tag(entry, overrides, registry) + + if entry.kind is Kind.GROUP: + return self._build_group(entry, overrides, registry) + + return self._build_task(entry, overrides) + + def _build_task(self, entry: Entry, overrides: dict[str, Any] | None): + """Build a task and return it wrapped in a dict {task_name: task_obj}.""" + cfg = self._load_full_config(entry, overrides) + + # Remove structural keys that aren't part of task config + for key in ("children", "ref", "tag", "group"): + cfg.pop(key, None) + + # Use cfg["task"] as key (may be overridden, e.g., for namespacing) + task_name = cfg["task"] + + if "class" in cfg: # PY_TASK route + cls = cfg["class"] + obj = cls(config=cfg) if _ctor_accepts_config(cls) else cls() + if hasattr(obj, "config") 
and hasattr(obj.config, "task"):
+                obj.config.task = task_name
+            return {task_name: obj}
+
+        # Regular YAML task - use ConfigurableTask
+        task_obj = ConfigurableTask(config=cfg)
+        return {task_name: task_obj}
+
+    def _build_group(
+        self,
+        entry: Entry,
+        overrides: dict[str, Any] | None,
+        registry: Mapping[str, Entry],
+    ):
+        raw_cfg = self._load_full_config(entry, None)
+        grp_cfg = {k: v for k, v in raw_cfg.items() if k in GroupConfig.__annotations__}
+        grp_cfg["metadata"] = grp_cfg.get("metadata", {}) | self._meta
+        group_obj = ConfigurableGroup(config=grp_cfg)
+        group_name = entry.name
+
+        children: dict[str, Any] = {}
+
+        # Handle new-style children: dict (hierarchical)
+        if "children" in raw_cfg:
+            children.update(
+                self._build_children(
+                    raw_cfg["children"], group_name, overrides, registry
+                )
+            )
+
+        # Handle old-style task: list (backward compatibility)
+        if "task" in grp_cfg and isinstance(grp_cfg["task"], list):
+            children.update(
+                self._build_task_list(grp_cfg["task"], group_name, overrides, registry)
+            )
+
+        return {group_obj: children}
+
+    def _build_children(
+        self,
+        children_cfg: dict[str, Any],
+        group_name: str,
+        overrides: dict[str, Any] | None,
+        registry: Mapping[str, Entry],
+    ) -> dict[str, Any]:
+        """Build children defined via children: dict."""
+        result: dict[str, Any] = {}
+
+        for child_name, child_cfg in children_cfg.items():
+            child_path = f"{group_name}::{child_name}"
+
+            # Look up pre-registered entry from index
+            if child_path in registry:
+                child_entry = registry[child_path]
+                child_overrides = overrides or {}
+
+                # Merge any inline overrides from child_cfg (excluding structural keys)
+                inline_overrides = {
+                    k: v
+                    for k, v in child_cfg.items()
+                    if k not in ("ref", "tag", "children")
+                }
+                if inline_overrides:
+                    child_overrides = {**child_overrides, **inline_overrides}
+
+                child = self.build(
+                    child_entry, overrides=child_overrides, registry=registry
+                )
+                result.update(child)
+            else:
+                # Fallback: inline task not pre-registered (shouldn't normally happen)
+                task_cfg = {**child_cfg, "task": child_path}
+                task_cfg["metadata"] = task_cfg.get("metadata", {}) | self._meta
+                result[child_path] = ConfigurableTask(config=task_cfg)
+
+        return result
+
+    def _build_task_list(
+        self,
+        task_list: list,
+        group_name: str,
+        overrides: dict[str, Any] | None,
+        registry: Mapping[str, Entry],
+    ) -> dict[str, Any]:
+        """Build children defined via task: list (backward compatibility)."""
+        result: dict[str, Any] = {}
+
+        for item in task_list:
+            # Step 1: Normalize - extract base_name and item_overrides
+            if isinstance(item, str):
+                base_name = item
+                item_overrides = overrides or {}
+            elif isinstance(item, dict):
+                base_name = item["task"]
+                # overrides may be None here (e.g. when a group is loaded by bare name)
+                item_overrides = {**(overrides or {}), **item}
+            else:
+                raise TypeError(
+                    f"Unsupported sub-entry {item!r} in group '{group_name}'"
+                )
+
+            # Step 2: Handle inline task (not in registry)
+            if base_name not in registry:
+                namespaced = f"{group_name}::{base_name}"
+                task_cfg = {**item_overrides, "task": namespaced}
+                task_cfg["metadata"] = task_cfg.get("metadata", {}) | self._meta
+                result[namespaced] = ConfigurableTask(config=task_cfg)
+                continue
+
+            # Step 3: Build based on entry kind
+            child_entry = registry[base_name]
+
+            if child_entry.kind is Kind.GROUP:
+                child = self.build(
+                    child_entry, overrides=item_overrides, registry=registry
+                )
+            elif child_entry.kind is Kind.TAG:
+                child = {}
+                for task_name in child_entry.tags:
+                    namespaced = f"{group_name}::{task_name}"
+                    child.update(
+                        self.build(
+                            registry[task_name],
+                            
overrides={"task": namespaced, **item_overrides}, + registry=registry, + ) + ) + else: # TASK or PY_TASK + namespaced = f"{group_name}::{base_name}" + child = self.build( + child_entry, + overrides={"task": namespaced, **item_overrides}, + registry=registry, + ) + + result.update(child) + + return result + + def _build_tag( + self, + entry: Entry, + overrides: dict[str, Any] | None, + registry: Mapping[str, Entry], + ): + """Build all tasks in a tag and return merged dict.""" + result = {} + for name in entry.tags: + result.update(self._build_task(registry[name], overrides)) + return result + + def _load_full_config( + self, entry: Entry, overrides: dict[str, Any] | None + ) -> dict[str, Any]: + # For inline children (have parent), use the stored cfg directly + # instead of loading from YAML (which would load the parent's full config) + if entry.parent and entry.cfg: + cfg = deepcopy(entry.cfg) + elif entry.yaml_path: + cfg = deepcopy(load_yaml(entry.yaml_path, resolve_func=True)) + else: + cfg: dict[str, Any] = { + "metadata": {"config": "unknown"} + } # python task without YAML + + # Handle task_list configs - merge base config with per-task overrides + if "task_list" in cfg: + task_list = cfg.pop("task_list") + # Find the entry for this task in task_list + for item in task_list: + if isinstance(item, dict) and item.get("task") == entry.name: + # Merge per-task overrides + cfg = {**cfg, **item} + break + + if overrides: + cfg = {**cfg, **overrides} + cfg["metadata"] = ( + m if isinstance(m := cfg.get("metadata", {}), dict) else {"_metadata": m} + ) | self._meta + cfg.setdefault("task", entry.name) + return cfg + + +def _ctor_accepts_config(cls) -> bool: + init = getattr(cls, "__init__", None) + return bool(init and "config" in inspect.signature(init).parameters) diff --git a/lm_eval/tasks/index.py b/lm_eval/tasks/index.py new file mode 100644 index 00000000000..9adaeb3839a --- /dev/null +++ b/lm_eval/tasks/index.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import TYPE_CHECKING, Any + +from lm_eval.tasks._config_loader import load_yaml as load_cfg + + +if TYPE_CHECKING: + from collections.abc import Iterable + from pathlib import Path + + +class Kind(Enum): + TASK = auto() # YAML task, or task_list entry + PY_TASK = auto() # Python-defined, via "class" + GROUP = auto() + TAG = auto() + TASK_LIST = auto() + + +@dataclass +class Entry: + name: str + kind: Kind + yaml_path: Path | None # None for generated / py-only entries + cfg: dict[str, str] | None = None + tags: set[str] = field(default_factory=set) + task_list_path: Path | None = None + # Hierarchical task support + parent: str | None = ( + None # parent path for inline children (e.g., "mmlu" for "mmlu::stem") + ) + ref_target: str | None = None # for children with ref: points to external entry + tag_ref: str | None = None # for children with tag: expands to tagged tasks + + +log = logging.getLogger(__name__) +_IGNORE_DIRS = {"__pycache__", ".ipynb_checkpoints"} + + +class TaskIndex: + """Walks one or more directories, parses YAML quickly (functions unresolved), + and produces a mapping {task_name: Entry}. 
+ """ + + def __init__(self, *, meta: dict[str, str] | None = None) -> None: + self._metadata = meta or {} + + def build( + self, + paths: Iterable[Path], + *, + resolve_includes=True, + ) -> dict[str, Entry]: + index: dict[str, Entry] = {} + log.debug("Building task index from %s", paths) + for root in paths: + for yaml_path in self._iter_yaml_files(root): + try: + cfg = load_cfg( + yaml_path, + resolve_func=False, + recursive=resolve_includes, + ) + self.process_cfg(cfg, yaml_path, index) + except Exception as err: + log.debug("Skip %s (%s)", yaml_path, err) + continue + + # self._process_cfg(cfg, yaml_path, index) + log.debug("Built task index with %d entries", len(index)) + return index + + @staticmethod + def _iter_yaml_files(root: Path): + # Sort for deterministic traversal order across filesystems + yield from sorted( + ( + p + for p in root.glob("**/*.yaml") + if not any(part in _IGNORE_DIRS for part in p.parts) + ), + key=lambda p: p.as_posix(), + ) + + @staticmethod + def process_cfg( + cfg: dict[str, Any], + path: Path, + index: dict[str, Entry], + parent_path: str | None = None, + ) -> None: + kind = TaskIndex._kind_of(cfg) + if kind is Kind.GROUP: + grp_name = cfg["group"] + # Build full path for hierarchical addressing + full_path = f"{parent_path}::{grp_name}" if parent_path else grp_name + + if full_path in index: + log.debug( + f"Duplicate group name '{full_path}' found. " + f"Already registered from: {index[full_path].yaml_path}. " + f"Skipping duplicate from: {path}" + ) + return + index[full_path] = Entry( + name=full_path, + kind=Kind.GROUP, + yaml_path=path, + tags=TaskIndex._str_to_set(cfg.get("tag")), + cfg=cfg, + parent=parent_path, + ) + + # Process inline children if present + if "children" in cfg: + TaskIndex._process_children(cfg["children"], full_path, path, index) + return + + if kind is Kind.TASK or kind is Kind.PY_TASK: + name = cfg["task"] + if name in index: + log.warning( + f"Duplicate task name '{name}' found. " + f"Already registered from: {index[name].yaml_path}. " + f"Skipping duplicate from: {path}" + ) + return + index[name] = Entry( + name=name, + kind=Kind.TASK, + yaml_path=path, + tags=TaskIndex._str_to_set(cfg.get("tag")), + cfg=cfg, + ) + TaskIndex._register_tags(name, cfg.get("tag"), index) + return + + if kind is Kind.TASK_LIST: + # If config also has a top-level "task", register it as the base task + if "task" in cfg and isinstance(cfg["task"], str): + base_name = cfg["task"] + if base_name not in index: + index[base_name] = Entry( + name=base_name, + kind=Kind.TASK, + yaml_path=path, + tags=TaskIndex._str_to_set(cfg.get("tag")), + cfg=cfg, + ) + TaskIndex._register_tags(base_name, cfg.get("tag"), index) + + # Register each task in task_list + base_tag = cfg.get("tag") + for entry in cfg["task_list"]: + task_name = entry["task"] if isinstance(entry, dict) else entry + if task_name in index: + log.warning( + f"Duplicate task name '{task_name}' found. " + f"Already registered from: {index[task_name].yaml_path}. 
" + f"Skipping duplicate from: {path}" + ) + continue + # Combine base tag with per-entry tag + entry_tag = entry.get("tag") if isinstance(entry, dict) else None + combined_tags = TaskIndex._str_to_set(base_tag) | TaskIndex._str_to_set( + entry_tag + ) + index[task_name] = Entry( + name=task_name, + kind=Kind.TASK, + yaml_path=path, + tags=combined_tags, + cfg=cfg, + ) + # Register both base config's tag and per-entry tag + TaskIndex._register_tags(task_name, base_tag, index) + TaskIndex._register_tags(task_name, entry_tag, index) + return + + @staticmethod + def _register_tags( + task: str, + tags: str | list[str] | None, + index: dict[str, Entry], + ) -> None: + if not tags: + return + for tag in tags if isinstance(tags, list) else [tags]: + entry = index.setdefault( + tag, + Entry(name=tag, kind=Kind.TAG, yaml_path=None, tags=set()), + ) + entry.tags.add(task) + + @staticmethod + def _kind_of(cfg: dict) -> Kind: + if "class" in cfg: + return Kind.PY_TASK + if "group" in cfg: + return Kind.GROUP + if "task_list" in cfg: + return Kind.TASK_LIST + if "task" in cfg: + return Kind.GROUP if isinstance(cfg["task"], list) else Kind.TASK + msg = "Unknown config shape" + raise ValueError(msg) from None + + @staticmethod + def _str_to_set(tags: str | list[str] | None = None) -> set[str]: + """Convert a string or list of strings to a set of strings.""" + return ( + set(tags) + if isinstance(tags, list) + else {tags} + if isinstance(tags, str) + else set() + ) + + @staticmethod + def _process_children( + children: dict[str, Any], + parent_path: str, + yaml_path: Path, + index: dict[str, Entry], + ) -> None: + """Process inline children definitions within a group. + + Children can be: + - Inline task: dict with task config fields (dataset_path, doc_to_text, etc.) 
+ - Inline subgroup: dict with 'children' key + - External ref: dict with 'ref' key pointing to existing entry + - Tag expansion: dict with 'tag' key to expand tagged tasks + """ + for child_name, child_cfg in children.items(): + if not isinstance(child_cfg, dict): + log.warning( + f"Invalid child config for '{child_name}' in '{parent_path}': " + f"expected dict, got {type(child_cfg).__name__}" + ) + continue + + child_path = f"{parent_path}::{child_name}" + + if child_path in index: + log.debug(f"Duplicate child '{child_path}' found, skipping.") + continue + + if "ref" in child_cfg: + # External reference - register with ref_target for build-time resolution + index[child_path] = Entry( + name=child_path, + kind=Kind.GROUP, # Assume group, will resolve at build time + yaml_path=yaml_path, + parent=parent_path, + ref_target=child_cfg["ref"], + cfg=child_cfg, + tags=TaskIndex._str_to_set(child_cfg.get("tag")), + ) + + elif "tag" in child_cfg: + # Tag expansion - register with tag_ref for build-time expansion + index[child_path] = Entry( + name=child_path, + kind=Kind.TAG, + yaml_path=yaml_path, + parent=parent_path, + tag_ref=child_cfg["tag"], + cfg=child_cfg, + tags=TaskIndex._str_to_set(child_cfg.get("tag")), + ) + + elif "children" in child_cfg: + # Nested inline group - recurse + nested_cfg = {**child_cfg, "group": child_name} + TaskIndex.process_cfg( + nested_cfg, yaml_path, index, parent_path=parent_path + ) + + else: + # Inline task definition + task_cfg = {**child_cfg, "task": child_path} + index[child_path] = Entry( + name=child_path, + kind=Kind.TASK, + yaml_path=yaml_path, + parent=parent_path, + cfg=task_cfg, + tags=TaskIndex._str_to_set(child_cfg.get("tag")), + ) + # Register tags for inline tasks + TaskIndex._register_tags(child_path, child_cfg.get("tag"), index) diff --git a/lm_eval/tasks/manager.py b/lm_eval/tasks/manager.py new file mode 100644 index 00000000000..80f84517f52 --- /dev/null +++ b/lm_eval/tasks/manager.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +from collections import defaultdict +from itertools import chain +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from lm_eval import utils +from lm_eval.tasks.factory import TaskFactory +from lm_eval.tasks.index import Entry, Kind, TaskIndex +from lm_eval.utils import setup_logging + + +if TYPE_CHECKING: + from lm_eval.api.task import Task + + +class TaskManager: + """Discovers, indexes, and loads evaluation tasks from YAML configs. + + Scans directories for task definitions and provides methods to load them + by name, glob pattern, or inline config. Handles groups, tags, and task + namespacing (e.g., "mmlu_humanities::formal_logic"). 
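+
+    A minimal usage sketch (paths and task names are illustrative):
+
+        tm = TaskManager(include_path="/path/to/custom/tasks")
+        tm.match_tasks(["mmlu*"])            # glob-style name matching
+        tasks = tm.load_task_or_group(["arc_easy", "mmlu"])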
+ """ + + def __init__( + self, + verbosity: str | None = None, + include_path: str | Path | list[str | Path] | None = None, + include_defaults: bool = True, + metadata: dict[str, dict[str, Any]] | None = None, + ) -> None: + """ + Args: + verbosity: Logging level (e.g., "INFO", "DEBUG") + include_path: Custom paths to scan for task configs (takes precedence) + include_defaults: Whether to include built-in tasks from lm_eval/tasks/ + metadata: Extra metadata to attach to all loaded tasks + """ + if verbosity: + setup_logging(verbosity) + + self.include_path = include_path + self.metadata = metadata + + index = TaskIndex() + self._factory = TaskFactory(meta=metadata) + + all_paths: list[Path] = [] + # Process include_path FIRST so user tasks take precedence over defaults + if include_path: + all_paths += [ + Path(p) + for p in ( + include_path + if isinstance(include_path, (list, tuple)) + else [include_path] + ) + ] + if include_defaults: + all_paths.append(Path(__file__).parent) + + self._index = index.build(all_paths) + + buckets = defaultdict(list) + for k, e in self._index.items(): + buckets[e.kind].append(k) + + self._all_tasks = sorted(self._index.keys()) + self._all_subtasks = sorted( + chain.from_iterable(buckets[k] for k in {Kind.TASK, Kind.PY_TASK}) + ) + self._all_groups = sorted(buckets[Kind.GROUP]) + self._all_tags = sorted(buckets[Kind.TAG]) + + # ---------------------------------------------------------------- properties + @property + def all_tasks(self) -> list[str]: + """All registered names (tasks, groups, tags).""" + return self._all_tasks + + @property + def all_groups(self) -> list[str]: + """All group names (e.g., "mmlu", "arc").""" + return self._all_groups + + @property + def all_subtasks(self) -> list[str]: + """All individual task names (YAML and Python tasks).""" + return self._all_subtasks + + @property + def all_tags(self) -> list[str]: + """All tag names (e.g., "ai2_arc", "mmlu_humanities_tasks").""" + return self._all_tags + + @property + def task_index(self) -> dict[str, Entry]: + """Raw index mapping names to Entry objects.""" + return self._index + + # ---------------------------------------------------------------- name checks + def _name_is_registered(self, name: str) -> bool: + return name in self._index + + def _name_is_task(self, name: str) -> bool: + return self._name_is_registered(name) and self._index[name].kind == Kind.TASK + + def _name_is_tag(self, name: str) -> bool: + return self._name_is_registered(name) and self._index[name].kind == Kind.TAG + + def _name_is_group(self, name: str) -> bool: + return self._name_is_registered(name) and self._index[name].kind == Kind.GROUP + + def _name_is_python_task(self, name: str) -> bool: + return self._name_is_registered(name) and self._index[name].kind == Kind.PY_TASK + + # ---------------------------------------------------------------- utility + def match_tasks(self, task_list: list[str]) -> list[str]: + """Match task names using glob patterns.""" + return utils.pattern_match(task_list, self.all_tasks) + + def list_all_tasks( + self, + list_groups: bool = True, + list_tags: bool = True, + list_subtasks: bool = True, + ) -> str: + """Generate a markdown table listing all available tasks.""" + from pytablewriter import MarkdownTableWriter + + def sanitize_path(path): + if path is None: + return "---" + path_str = str(path) + if "lm_eval/tasks/" in path_str: + return "lm_eval/tasks/" + path_str.split("lm_eval/tasks/")[-1] + return path_str + + group_table = MarkdownTableWriter() + group_table.headers = 
["Group", "Config Location"] + gt_values = [] + for g in self.all_groups: + entry = self._index[g] + path = sanitize_path(entry.yaml_path) + gt_values.append([g, path]) + group_table.value_matrix = gt_values + + tag_table = MarkdownTableWriter() + tag_table.headers = ["Tag"] + tag_table.value_matrix = [[t] for t in self.all_tags] + + subtask_table = MarkdownTableWriter() + subtask_table.headers = ["Task", "Config Location", "Output Type"] + st_values = [] + for t in self.all_subtasks: + entry = self._index[t] + path = entry.yaml_path + output_type = "" + + if path is not None: + config = utils.load_yaml_config(str(path), mode="simple") + if "output_type" in config: + output_type = config["output_type"] + elif "include" in config: + include_path = str(path.parent / config["include"]) + include_config = utils.load_yaml_config(include_path, mode="simple") + if "output_type" in include_config: + output_type = include_config["output_type"] + + path = sanitize_path(path) + st_values.append([t, path, output_type]) + subtask_table.value_matrix = st_values + + result = "\n" + if list_groups: + result += group_table.dumps() + "\n\n" + if list_tags: + result += tag_table.dumps() + "\n\n" + if list_subtasks: + result += subtask_table.dumps() + "\n\n" + return result + + # ---------------------------------------------------------------- core API + def _entry(self, name: str) -> Entry: + if name not in self._index: + raise KeyError(f"Unknown task/group/tag: {name}") + return self._index[name] + + def load_spec(self, spec: str | dict[str, Any]): + """Load a task/group/tag by name or with inline overrides. + + Args: + spec: Task name (str) or dict with "task" key and overrides + + Returns: + Dict mapping task names to task objects (nested for groups) + """ + if isinstance(spec, str): + entry = self._entry(spec) + return self._factory.build(entry, overrides=None, registry=self._index) + + if isinstance(spec, dict): + # inline dict => find base entry, then pass overrides + name = spec["task"] + entry = self._entry(name) + return self._factory.build(entry, overrides=spec, registry=self._index) + + raise TypeError("spec must be str or dict") + + def load_task_or_group(self, task_list: str | list[str]) -> dict: + """Load tasks/groups and return a merged dictionary. + + :param task_list: Single task name or list of task names + :return: Dictionary of task objects (possibly nested for groups) + """ + import collections + + if isinstance(task_list, str): + task_list = [task_list] + + # Each load_spec call returns a dict (possibly nested for groups) + # We merge them using ChainMap (like the original implementation) + return dict(collections.ChainMap(*[self.load_spec(s) for s in task_list])) + + def load_config(self, config: dict) -> dict: + """Load a task from an inline config dict.""" + return self.load_spec(config) + + +def get_task_dict( + task_name_list: str | list[str | dict | Task], + task_manager: TaskManager | None = None, +): + """Helper to load multiple tasks into a dict. 
Creates TaskManager if not provided.""" + if not task_manager: + task_manager = TaskManager() + else: + assert isinstance(task_manager, TaskManager) + + return { + task_name: task_manager.load_spec(task_name) + if isinstance(task_name, str) + else task_name + for task_name in task_name_list + } diff --git a/tests/test_configs/group.yaml b/tests/test_configs/group.yaml new file mode 100644 index 00000000000..aff5e5d6036 --- /dev/null +++ b/tests/test_configs/group.yaml @@ -0,0 +1,32 @@ +# Group configuration demonstrating task collections + +group: test_group +task: + - task: group_task_fs0 + dataset_path: json + dataset_kwargs: + data_files: + test: tests/test_configs/test_data.json + output_type: multiple_choice + doc_to_text: "{{question}}" + doc_to_target: "{{choices[answer]}}" + test_split: test + num_fewshot: 0 + metric_list: + - metric: acc + aggregation: mean + higher_is_better: true + - task: group_task_fs2 + dataset_path: json + dataset_kwargs: + data_files: + test: tests/test_configs/test_data.json + output_type: multiple_choice + doc_to_text: "{{question}}" + doc_to_target: "{{choices[answer]}}" + test_split: test + num_fewshot: 2 + metric_list: + - metric: acc + aggregation: mean + higher_is_better: true diff --git a/tests/test_configs/include_base.yaml b/tests/test_configs/include_base.yaml new file mode 100644 index 00000000000..37c8c9988a7 --- /dev/null +++ b/tests/test_configs/include_base.yaml @@ -0,0 +1,20 @@ +# Base configuration for include walkthrough tests +# This will be included by other configs to demonstrate inheritance + +task: base_task # This should be overridden by including configs +dataset_path: json +dataset_kwargs: + data_files: + test: tests/test_configs/test_data.json +output_type: multiple_choice +doc_to_text: "{{question}}" +doc_to_target: "{{choices[answer]}}" +test_split: test +num_fewshot: 0 # Default, can be overridden +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true +metadata: + version: 1.0 + description: "Base config for include demonstration" diff --git a/tests/test_configs/include_group.yaml b/tests/test_configs/include_group.yaml new file mode 100644 index 00000000000..4eabb380626 --- /dev/null +++ b/tests/test_configs/include_group.yaml @@ -0,0 +1,8 @@ +# Group with multiple tasks using include inheritance +# Demonstrates tasks sharing the same base config + +group: include_group +task: + - include_task_fs0 + - include_task_fs1 + - include_task_fs5 diff --git a/tests/test_configs/include_task_fs0.yaml b/tests/test_configs/include_task_fs0.yaml new file mode 100644 index 00000000000..f91d1678ad7 --- /dev/null +++ b/tests/test_configs/include_task_fs0.yaml @@ -0,0 +1,6 @@ +# Task demonstrating include inheritance + +task: include_task_fs0 +include: include_base.yaml +num_fewshot: 0 +description: "Zero-shot with inheritance" diff --git a/tests/test_configs/include_task_fs1.yaml b/tests/test_configs/include_task_fs1.yaml new file mode 100644 index 00000000000..cbddea4c05e --- /dev/null +++ b/tests/test_configs/include_task_fs1.yaml @@ -0,0 +1,6 @@ +# Task demonstrating include inheritance + +task: include_task_fs1 +include: include_base.yaml +num_fewshot: 1 +description: "One-shot with inheritance" diff --git a/tests/test_configs/include_task_fs5.yaml b/tests/test_configs/include_task_fs5.yaml new file mode 100644 index 00000000000..d9d61335ed5 --- /dev/null +++ b/tests/test_configs/include_task_fs5.yaml @@ -0,0 +1,13 @@ +# Task demonstrating include inheritance with custom metrics + +task: include_task_fs5 +include: 
include_base.yaml +num_fewshot: 5 +description: "Five-shot with custom metrics" +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true + - metric: acc_norm + aggregation: mean + higher_is_better: true diff --git a/tests/test_configs/simple_task.yaml b/tests/test_configs/simple_task.yaml new file mode 100644 index 00000000000..9895faae50f --- /dev/null +++ b/tests/test_configs/simple_task.yaml @@ -0,0 +1,20 @@ +# Simple task configuration for walkthrough tests +# Demonstrates basic task loading without any special features + +task: simple_task +dataset_path: json +dataset_kwargs: + data_files: + test: tests/test_configs/test_data.json +output_type: multiple_choice +doc_to_text: "{{question}}" +doc_to_target: "{{choices[answer]}}" +test_split: test +num_fewshot: 1 +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true +metadata: + version: 1.0 + description: "Simple task for basic walkthrough" diff --git a/tests/test_configs/tag_parent_group.yaml b/tests/test_configs/tag_parent_group.yaml new file mode 100644 index 00000000000..716c3b991b1 --- /dev/null +++ b/tests/test_configs/tag_parent_group.yaml @@ -0,0 +1,5 @@ +# Parent group containing a subgroup - simulates mmlu pattern +# Structure: tag_parent_group -> tag_subgroup -> test_tag_tasks (TAG) -> tag_task_1, tag_task_2, tag_task_3 +group: tag_parent_group +task: + - tag_subgroup diff --git a/tests/test_configs/tag_subgroup.yaml b/tests/test_configs/tag_subgroup.yaml new file mode 100644 index 00000000000..03f9dec1dfc --- /dev/null +++ b/tests/test_configs/tag_subgroup.yaml @@ -0,0 +1,5 @@ +# Subgroup that references a TAG - simulates mmlu_humanities pattern +# This group contains a tag reference (test_tag_tasks) which expands to multiple tasks +group: tag_subgroup +task: + - test_tag_tasks diff --git a/tests/test_configs/tag_task_1.yaml b/tests/test_configs/tag_task_1.yaml new file mode 100644 index 00000000000..cf61407a10f --- /dev/null +++ b/tests/test_configs/tag_task_1.yaml @@ -0,0 +1,16 @@ +# Task 1 with tag - simulates mmlu_formal_logic pattern +task: tag_task_1 +tag: test_tag_tasks +dataset_path: json +dataset_kwargs: + data_files: + test: tests/test_configs/test_data.json +output_type: multiple_choice +doc_to_text: "{{question}}" +doc_to_target: "{{choices[answer]}}" +test_split: test +num_fewshot: 0 +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true diff --git a/tests/test_configs/tag_task_2.yaml b/tests/test_configs/tag_task_2.yaml new file mode 100644 index 00000000000..19161360261 --- /dev/null +++ b/tests/test_configs/tag_task_2.yaml @@ -0,0 +1,16 @@ +# Task 2 with tag - simulates mmlu_high_school_history pattern +task: tag_task_2 +tag: test_tag_tasks +dataset_path: json +dataset_kwargs: + data_files: + test: tests/test_configs/test_data.json +output_type: multiple_choice +doc_to_text: "{{question}}" +doc_to_target: "{{choices[answer]}}" +test_split: test +num_fewshot: 0 +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true diff --git a/tests/test_configs/tag_task_3.yaml b/tests/test_configs/tag_task_3.yaml new file mode 100644 index 00000000000..79965fc28dd --- /dev/null +++ b/tests/test_configs/tag_task_3.yaml @@ -0,0 +1,16 @@ +# Task 3 with tag - simulates mmlu_philosophy pattern +task: tag_task_3 +tag: test_tag_tasks +dataset_path: json +dataset_kwargs: + data_files: + test: tests/test_configs/test_data.json +output_type: multiple_choice +doc_to_text: "{{question}}" +doc_to_target: "{{choices[answer]}}" +test_split: test +num_fewshot: 0 
+metric_list: + - metric: acc + aggregation: mean + higher_is_better: true diff --git a/tests/test_configs/task_list.yaml b/tests/test_configs/task_list.yaml new file mode 100644 index 00000000000..b9a9dd3e27c --- /dev/null +++ b/tests/test_configs/task_list.yaml @@ -0,0 +1,36 @@ +# Task list configuration for code walkthrough tests +# This demonstrates the task_list feature with shared config and task-specific overrides + +dataset_path: json +dataset_kwargs: + data_files: + test: tests/test_configs/test_data.json +output_type: multiple_choice +doc_to_text: "{{question}}" +doc_to_target: "{{choices[answer]}}" +test_split: test +metric_list: + - metric: acc + aggregation: mean + higher_is_better: true +metadata: + version: 1.0 + description: "Task list walkthrough example" + +task_list: + - task: task_list_fs0 + num_fewshot: 0 + description: "Zero-shot variant" + - task: task_list_fs1 + num_fewshot: 1 + description: "One-shot variant" + - task: task_list_fs3 + num_fewshot: 3 + description: "Three-shot variant" + metric_list: + - metric: acc + aggregation: mean + higher_is_better: true + - metric: acc_norm + aggregation: mean + higher_is_better: true diff --git a/tests/test_configs/test_data.json b/tests/test_configs/test_data.json new file mode 100644 index 00000000000..37fe084f158 --- /dev/null +++ b/tests/test_configs/test_data.json @@ -0,0 +1,15 @@ +[ + { "question": "What is 2+2?", "choices": ["1", "2", "3", "4"], "answer": 3 }, + { "question": "What is 3+3?", "choices": ["4", "5", "6", "7"], "answer": 2 }, + { "question": "What is 4+4?", "choices": ["6", "7", "8", "9"], "answer": 2 }, + { + "question": "What is 5+5?", + "choices": ["8", "9", "10", "11"], + "answer": 2 + }, + { + "question": "What is 6+6?", + "choices": ["10", "11", "12", "13"], + "answer": 2 + } +] diff --git a/tests/test_include_path.py b/tests/test_include_path.py deleted file mode 100644 index 9271a3c8bd7..00000000000 --- a/tests/test_include_path.py +++ /dev/null @@ -1,186 +0,0 @@ -import os - -from lm_eval import tasks - - -def test_include_path_precedence(): - """Test that user-specified include paths take precedence over default paths when tasks have the same name.""" - import tempfile - - # Create a temporary directory for our custom task - with tempfile.TemporaryDirectory() as custom_dir: - # Create a custom arc_easy.yaml that has a different metric - custom_task_content = """task: arc_easy -dataset_path: allenai/ai2_arc -dataset_name: ARC-Easy -output_type: multiple_choice -training_split: train -validation_split: validation -test_split: test -doc_to_text: "Custom Question: {{question}}\\nAnswer:" -doc_to_target: "{{choices.label.index(answerKey)}}" -doc_to_choice: "{{choices.text}}" -metric_list: - - metric: f1 - aggregation: mean - higher_is_better: true -metadata: - version: 2.0 - custom: true -""" - - # Write the custom task file - custom_task_path = os.path.join(custom_dir, "arc_easy.yaml") - with open(custom_task_path, "w") as f: - f.write(custom_task_content) - - # Test 1: User path should override default when include_defaults=True - task_manager = tasks.TaskManager(include_defaults=True, include_path=custom_dir) - - # Load the task - task_dict = task_manager.load_task_or_group(["arc_easy"]) - arc_easy_task = task_dict["arc_easy"] - - # Check that the custom version was loaded (has f1 metric and custom doc_to_text) - assert any( - metric["metric"] == "f1" for metric in arc_easy_task.config["metric_list"] - ), "Custom task should have f1 metric" - assert "Custom Question:" in 
arc_easy_task.config["doc_to_text"], ( - "Custom task should have custom doc_to_text" - ) - assert arc_easy_task.config["metadata"]["version"] == 2.0, ( - "Custom task should have version 2.0" - ) - - # Test 2: Verify default is used when no custom path is provided - default_task_manager = tasks.TaskManager(include_defaults=True) - default_task_dict = default_task_manager.load_task_or_group(["arc_easy"]) - default_arc_easy = default_task_dict["arc_easy"] - - # Default should not have f1 metric or custom text - assert not any( - metric["metric"] == "f1" - for metric in default_arc_easy.config.get("metric_list", []) - ), "Default task should not have f1 metric" - assert "Custom Question:" not in default_arc_easy.config["doc_to_text"], ( - "Default task should not have custom doc_to_text" - ) - - -def test_include_defaults_false_with_custom_path(): - """Test that when include_defaults=False, only custom tasks are available.""" - import tempfile - - with tempfile.TemporaryDirectory() as custom_dir: - # Create a custom task using a real dataset - custom_task_content = """task: custom_arc_task -dataset_path: allenai/ai2_arc -dataset_name: ARC-Challenge -output_type: multiple_choice -training_split: train -validation_split: validation -test_split: test -doc_to_text: "Q: {{question}}\nA:" -doc_to_target: "{{choices.label.index(answerKey)}}" -doc_to_choice: "{{choices.text}}" -metric_list: - - metric: acc - aggregation: mean - higher_is_better: true -metadata: - version: 1.0 - custom: true -""" - - # Write the custom task file - custom_task_path = os.path.join(custom_dir, "custom_arc_task.yaml") - with open(custom_task_path, "w") as f: - f.write(custom_task_content) - - # Initialize with include_defaults=False - task_manager = tasks.TaskManager( - include_defaults=False, include_path=custom_dir - ) - - # Custom task should be available - assert "custom_arc_task" in task_manager.all_tasks, ( - "Custom task should be available when include_defaults=False" - ) - - # Default tasks should NOT be available - assert "arc_easy" not in task_manager.all_tasks, ( - "Default arc_easy should not be available when include_defaults=False" - ) - assert "arc_challenge" not in task_manager.all_tasks, ( - "Default arc_challenge should not be available when include_defaults=False" - ) - - # Check that only our custom task is present - assert len(task_manager.all_tasks) == 1, ( - f"Should only have 1 task, but found {len(task_manager.all_tasks)}" - ) - - # Check task metadata is correctly loaded - task_info = task_manager.task_index["custom_arc_task"] - assert task_info["type"] == "task" - assert custom_dir in task_info["yaml_path"] - - -def test_include_defaults_true_with_new_tasks(): - """Test that new tasks from include_path are added alongside default tasks.""" - import tempfile - - with tempfile.TemporaryDirectory() as custom_dir: - # Create a completely new task (not overriding any default) - new_task_content = """task: arc_custom_generation -dataset_path: allenai/ai2_arc -dataset_name: ARC-Easy -output_type: generate_until -training_split: train -validation_split: validation -test_split: test -doc_to_text: "Question: {{question}}\nGenerate answer:" -doc_to_target: "{{choices.text[choices.label.index(answerKey)]}}" -generation_kwargs: - max_gen_toks: 50 - temperature: 0.1 - until: - - "\n" -metric_list: - - metric: exact_match - aggregation: mean - higher_is_better: true -metadata: - version: 1.0 - custom_benchmark: true -""" - - # Write the new task file - new_task_path = os.path.join(custom_dir, 
"arc_custom_generation.yaml") - with open(new_task_path, "w") as f: - f.write(new_task_content) - - # Initialize with include_defaults=True (default behavior) - task_manager = tasks.TaskManager(include_defaults=True, include_path=custom_dir) - - # Both custom and default tasks should be available - assert "arc_custom_generation" in task_manager.all_tasks, ( - "New custom task should be available" - ) - assert "arc_easy" in task_manager.all_tasks, ( - "Default arc_easy should still be available" - ) - assert "arc_challenge" in task_manager.all_tasks, ( - "Default arc_challenge should still be available" - ) - - # Check task metadata - custom_task_info = task_manager.task_index["arc_custom_generation"] - assert custom_task_info["type"] == "task" - assert custom_dir in custom_task_info["yaml_path"] - - # Verify the counts - should have more tasks than just defaults - default_only_manager = tasks.TaskManager(include_defaults=True) - assert len(task_manager.all_tasks) > len(default_only_manager.all_tasks), ( - "Should have more tasks when including custom path" - ) diff --git a/tests/test_task_manager.py b/tests/test_task_manager.py index 00cecc8fd3c..9ff94aa871d 100644 --- a/tests/test_task_manager.py +++ b/tests/test_task_manager.py @@ -1,9 +1,17 @@ +import logging import tempfile from pathlib import Path import pytest from lm_eval.tasks import TaskManager +from lm_eval.tasks._config_loader import load_yaml +from lm_eval.tasks.index import Entry, Kind, TaskIndex + + +# ============================================================================= +# Existing fixtures and tests +# ============================================================================= @pytest.fixture(scope="module") @@ -71,3 +79,927 @@ def test_python_task_inclusion( assert custom_task_tag in task_manager.all_tags # check if it can be loaded by tag (custom_task_tag) assert custom_task_name in task_manager.load_task_or_group(custom_task_tag) + + +# ============================================================================= +# Config Loader Tests +# ============================================================================= + + +class TestConfigLoader: + def test_load_simple_yaml(self, tmp_path): + """Load a basic YAML without includes or functions""" + content = """ +task: simple_test +dataset_path: test_dataset +output_type: generate_until +""" + yaml_path = tmp_path / "simple.yaml" + yaml_path.write_text(content) + + cfg = load_yaml(yaml_path) + + assert cfg["task"] == "simple_test" + assert cfg["dataset_path"] == "test_dataset" + assert cfg["output_type"] == "generate_until" + + def test_load_yaml_with_include(self, tmp_path): + """Load YAML that includes another file""" + base_content = """ +dataset_path: base_dataset +output_type: multiple_choice +num_fewshot: 5 +""" + child_content = """ +include: base.yaml +task: child_task +num_fewshot: 10 +""" + (tmp_path / "base.yaml").write_text(base_content) + (tmp_path / "child.yaml").write_text(child_content) + + cfg = load_yaml(tmp_path / "child.yaml") + + # Child overrides base + assert cfg["task"] == "child_task" + assert cfg["num_fewshot"] == 10 + # Inherited from base + assert cfg["dataset_path"] == "base_dataset" + assert cfg["output_type"] == "multiple_choice" + + def test_load_yaml_with_function_tag_resolved(self, tmp_path): + """Load YAML with !function tag, resolve_func=True""" + utils_content = """ +def my_processor(doc): + return doc +""" + yaml_content = """ +task: func_test +process_docs: !function utils.my_processor +""" + (tmp_path / 
"utils.py").write_text(utils_content) + (tmp_path / "test.yaml").write_text(yaml_content) + + cfg = load_yaml(tmp_path / "test.yaml", resolve_func=True) + + assert cfg["task"] == "func_test" + assert callable(cfg["process_docs"]) + + def test_load_yaml_without_function_resolution(self, tmp_path): + """Load YAML with !function tag, resolve_func=False (returns path string)""" + yaml_content = """ +task: func_test +process_docs: !function utils.my_processor +""" + (tmp_path / "test.yaml").write_text(yaml_content) + + cfg = load_yaml(tmp_path / "test.yaml", resolve_func=False) + + assert cfg["task"] == "func_test" + # When resolve_func=False, returns path string + assert isinstance(cfg["process_docs"], str) + assert "utils.my_processor" in cfg["process_docs"] + + def test_load_yaml_recursive_includes(self, tmp_path): + """Load YAML with nested includes""" + grandparent = """ +output_type: generate_until +metric_list: + - metric: exact_match +""" + parent = """ +include: grandparent.yaml +dataset_path: parent_dataset +""" + child = """ +include: parent.yaml +task: nested_task +""" + (tmp_path / "grandparent.yaml").write_text(grandparent) + (tmp_path / "parent.yaml").write_text(parent) + (tmp_path / "child.yaml").write_text(child) + + cfg = load_yaml(tmp_path / "child.yaml") + + assert cfg["task"] == "nested_task" + assert cfg["dataset_path"] == "parent_dataset" + assert cfg["output_type"] == "generate_until" + + def test_load_yaml_cycle_detection(self, tmp_path): + """Detect include cycles""" + a_content = """ +include: b.yaml +task: a +""" + b_content = """ +include: a.yaml +task: b +""" + (tmp_path / "a.yaml").write_text(a_content) + (tmp_path / "b.yaml").write_text(b_content) + + with pytest.raises(ValueError, match="Include cycle"): + load_yaml(tmp_path / "a.yaml") + + +# ============================================================================= +# TaskIndex Tests +# ============================================================================= + + +class TestKind: + def test_kind_enum_values(self): + """Verify Kind enum has expected values""" + assert Kind.TASK is not None + assert Kind.PY_TASK is not None + assert Kind.GROUP is not None + assert Kind.TAG is not None + assert Kind.TASK_LIST is not None + + +class TestEntry: + def test_entry_dataclass_fields(self): + """Verify Entry has expected fields""" + entry = Entry( + name="test", + kind=Kind.TASK, + yaml_path=Path("/test.yaml"), + cfg={"task": "test"}, + tags={"tag1"}, + ) + assert entry.name == "test" + assert entry.kind == Kind.TASK + assert entry.yaml_path == Path("/test.yaml") + assert entry.cfg == {"task": "test"} + assert entry.tags == {"tag1"} + + +class TestTaskIndex: + def test_build_from_directory(self, tmp_path): + """Build index from a directory with YAML files""" + task_content = """ +task: test_task +dataset_path: test +output_type: generate_until +""" + (tmp_path / "test_task.yaml").write_text(task_content) + + index = TaskIndex() + result = index.build([tmp_path]) + + assert "test_task" in result + assert result["test_task"].kind == Kind.TASK + + def test_deterministic_traversal(self, tmp_path): + """Verify files are processed in sorted order""" + # Create files that would be in different order without sorting + (tmp_path / "z_task.yaml").write_text("task: z_task\ndataset_path: z") + (tmp_path / "a_task.yaml").write_text("task: a_task\ndataset_path: a") + (tmp_path / "m_task.yaml").write_text("task: m_task\ndataset_path: m") + + index = TaskIndex() + result = index.build([tmp_path]) + + # All tasks should be indexed 
+ assert "a_task" in result + assert "m_task" in result + assert "z_task" in result + + def test_duplicate_task_detection(self, tmp_path, caplog): + """Verify warning logged for duplicate task names""" + # Create subdirectories with duplicate task names + dir1 = tmp_path / "dir1" + dir2 = tmp_path / "dir2" + dir1.mkdir() + dir2.mkdir() + + (dir1 / "task.yaml").write_text("task: duplicate_task\ndataset_path: a") + (dir2 / "task.yaml").write_text("task: duplicate_task\ndataset_path: b") + + index = TaskIndex() + with caplog.at_level(logging.WARNING): + result = index.build([tmp_path]) + + # Only one should be registered + assert "duplicate_task" in result + # Warning should be logged + assert "Duplicate task name" in caplog.text + + def test_duplicate_group_detection(self, tmp_path, caplog): + """Verify debug message logged for duplicate group names""" + dir1 = tmp_path / "dir1" + dir2 = tmp_path / "dir2" + dir1.mkdir() + dir2.mkdir() + + group_content = """ +group: duplicate_group +task: + - task1 +""" + (dir1 / "group.yaml").write_text(group_content) + (dir2 / "group.yaml").write_text(group_content) + + # Also need task1 to exist + (tmp_path / "task1.yaml").write_text("task: task1\ndataset_path: t") + + index = TaskIndex() + with caplog.at_level(logging.DEBUG): + result = index.build([tmp_path]) + + assert "duplicate_group" in result + assert "Duplicate group name" in caplog.text + + def test_kind_detection_task(self): + """Config with 'task' key (string) detected as TASK""" + cfg = {"task": "my_task", "dataset_path": "test"} + kind = TaskIndex._kind_of(cfg) + assert kind == Kind.TASK + + def test_kind_detection_group(self): + """Config with 'group' key detected as GROUP""" + cfg = {"group": "my_group", "task": ["task1", "task2"]} + kind = TaskIndex._kind_of(cfg) + assert kind == Kind.GROUP + + def test_kind_detection_py_task(self): + """Config with 'class' key detected as PY_TASK""" + cfg = {"task": "my_task", "class": "SomeClass"} + kind = TaskIndex._kind_of(cfg) + assert kind == Kind.PY_TASK + + def test_kind_detection_task_list(self): + """Config with 'task_list' key detected as TASK_LIST""" + cfg = {"task_list": [{"task": "task1"}, {"task": "task2"}]} + kind = TaskIndex._kind_of(cfg) + assert kind == Kind.TASK_LIST + + def test_tag_registration(self, tmp_path): + """Tags from tasks are registered in index""" + task_content = """ +task: tagged_task +dataset_path: test +tag: my_custom_tag +""" + (tmp_path / "task.yaml").write_text(task_content) + + index = TaskIndex() + result = index.build([tmp_path]) + + assert "tagged_task" in result + assert "my_custom_tag" in result + assert result["my_custom_tag"].kind == Kind.TAG + assert "tagged_task" in result["my_custom_tag"].tags + + def test_ignore_pycache(self, tmp_path): + """Files in __pycache__ are ignored""" + pycache = tmp_path / "__pycache__" + pycache.mkdir() + (pycache / "task.yaml").write_text("task: should_ignore\ndataset_path: t") + + index = TaskIndex() + result = index.build([tmp_path]) + + assert "should_ignore" not in result + + +# ============================================================================= +# TaskManager Integration Tests +# ============================================================================= + + +# Module-level fixture to avoid re-creating TaskManager for each test +@pytest.fixture(scope="module") +def shared_task_manager(): + """Create a TaskManager with default tasks (shared across module)""" + return TaskManager() + + +@pytest.fixture(scope="module") +def test_configs_task_manager(): + 
"""TaskManager with only test_configs tasks (fast - no default task scanning)""" + test_configs_path = Path(__file__).parent / "test_configs" + return TaskManager(include_path=str(test_configs_path), include_defaults=False) + + +class TestTaskManagerIntegration: + def test_initialization(self, shared_task_manager): + """TaskManager initializes with default tasks""" + assert len(shared_task_manager.all_tasks) > 0 + + def test_all_tasks_sorted(self, shared_task_manager): + """all_tasks returns sorted list""" + tasks = shared_task_manager.all_tasks + assert tasks == sorted(tasks) + + def test_all_groups_property(self, shared_task_manager): + """all_groups returns only groups""" + groups = shared_task_manager.all_groups + assert len(groups) > 0 + for g in groups[:5]: # Check first 5 + assert shared_task_manager._name_is_group(g) + + def test_all_subtasks_property(self, shared_task_manager): + """all_subtasks returns TASK and PY_TASK kinds""" + subtasks = shared_task_manager.all_subtasks + assert len(subtasks) > 0 + for t in subtasks[:5]: # Check first 5 + entry = shared_task_manager.task_index[t] + assert entry.kind in (Kind.TASK, Kind.PY_TASK) + + def test_all_tags_property(self, shared_task_manager): + """all_tags returns only tags""" + tags = shared_task_manager.all_tags + assert len(tags) > 0 + for t in tags[:5]: # Check first 5 + assert shared_task_manager._name_is_tag(t) + + def test_load_task_by_name(self, test_configs_task_manager): + """Load a single task by name""" + result = test_configs_task_manager.load_task_or_group(["simple_task"]) + assert "simple_task" in result + + def test_load_group_by_name(self, test_configs_task_manager): + """Load a group and get nested structure with namespaced task names""" + result = test_configs_task_manager.load_task_or_group(["test_group"]) + # Result is {ConfigurableGroup: {task_name: task_obj}} + # Get the children dict from the group + children = list(result.values())[0] + # test_group contains inline tasks, namespaced as group_name::task_name + assert "test_group::group_task_fs0" in children + assert "test_group::group_task_fs2" in children + + def test_load_tag_by_name(self, shared_task_manager): + """Load all tasks in a tag""" + result = shared_task_manager.load_task_or_group(["ai2_arc"]) + # Should load both arc_easy and arc_challenge + assert "arc_easy" in result + assert "arc_challenge" in result + + def test_include_path(self): + """Custom include_path adds tasks to index using tests/test_configs/""" + test_configs_path = Path(__file__).parent / "test_configs" + tm = TaskManager(include_path=str(test_configs_path), include_defaults=False) + # simple_task is defined in test_configs/simple_task.yaml + assert "simple_task" in tm.all_tasks + + def test_include_defaults_false(self): + """include_defaults=False excludes built-in tasks""" + test_configs_path = Path(__file__).parent / "test_configs" + tm = TaskManager(include_path=str(test_configs_path), include_defaults=False) + # Should have tasks from test_configs + assert "simple_task" in tm.all_tasks + # Built-in tasks like arc_easy should not be present + assert "arc_easy" not in tm.all_tasks + + def test_include_resolution(self): + """Test that includes are properly resolved""" + test_configs_path = Path(__file__).parent / "test_configs" + tm = TaskManager(include_path=str(test_configs_path), include_defaults=False) + # include_task_fs5 includes include_base which has the actual task config + assert "include_task_fs5" in tm.all_tasks + + def test_include_inheritance_override(self): + 
"""Test that child config overrides parent values from include""" + test_configs_path = Path(__file__).parent / "test_configs" + tm = TaskManager(include_path=str(test_configs_path), include_defaults=False) + + # Load the task to get full resolved config + result = tm.load_task_or_group(["include_task_fs5"]) + task_obj = result["include_task_fs5"] + + # include_base has num_fewshot=0, include_task_fs5 overrides to 5 + assert task_obj.config.num_fewshot == 5 + + # include_base has dataset_path=json (inherited) + assert task_obj.config.dataset_path == "json" + + def test_include_custom_metrics(self): + """Test that include_task_fs5 has custom metrics (acc + acc_norm)""" + test_configs_path = Path(__file__).parent / "test_configs" + tm = TaskManager(include_path=str(test_configs_path), include_defaults=False) + + result = tm.load_task_or_group(["include_task_fs5"]) + task_obj = result["include_task_fs5"] + + # include_task_fs5 defines both acc and acc_norm metrics + metric_names = [m["metric"] for m in task_obj.config.metric_list] + assert "acc" in metric_names + assert "acc_norm" in metric_names + + def test_group_loading(self): + """Test that groups are indexed from test_configs""" + test_configs_path = Path(__file__).parent / "test_configs" + tm = TaskManager(include_path=str(test_configs_path), include_defaults=False) + # group.yaml defines a group called 'test_group' + assert "test_group" in tm.all_groups + + def test_include_group(self): + """Test group with tasks sharing same base config via includes""" + test_configs_path = Path(__file__).parent / "test_configs" + tm = TaskManager(include_path=str(test_configs_path), include_defaults=False) + # include_group.yaml: group with include_task_fs0, fs1, fs5 + assert "include_group" in tm.all_groups + # The subtasks should also be indexed + assert "include_task_fs0" in tm.all_tasks + assert "include_task_fs1" in tm.all_tasks + assert "include_task_fs5" in tm.all_tasks + + def test_task_list_loading(self): + """Test task_list feature with shared config and task-specific overrides""" + test_configs_path = Path(__file__).parent / "test_configs" + tm = TaskManager(include_path=str(test_configs_path), include_defaults=False) + # task_list.yaml defines tasks via task_list key + assert "task_list_fs0" in tm.all_tasks + assert "task_list_fs1" in tm.all_tasks + assert "task_list_fs3" in tm.all_tasks + + def test_task_list_overrides(self): + """Test task_list task-specific overrides are applied""" + test_configs_path = Path(__file__).parent / "test_configs" + tm = TaskManager(include_path=str(test_configs_path), include_defaults=False) + + # Load tasks and verify their num_fewshot values + result = tm.load_task_or_group( + ["task_list_fs0", "task_list_fs1", "task_list_fs3"] + ) + + assert result["task_list_fs0"].config.num_fewshot == 0 + assert result["task_list_fs1"].config.num_fewshot == 1 + assert result["task_list_fs3"].config.num_fewshot == 3 + + # task_list_fs3 has custom metrics (acc + acc_norm) + metric_names = [m["metric"] for m in result["task_list_fs3"].config.metric_list] + assert "acc" in metric_names + assert "acc_norm" in metric_names + + def test_task_list_base_field_inheritance(self): + """Test that task_list tasks inherit base fields from the shared config""" + test_configs_path = Path(__file__).parent / "test_configs" + tm = TaskManager(include_path=str(test_configs_path), include_defaults=False) + + result = tm.load_task_or_group(["task_list_fs0"]) + task = result["task_list_fs0"] + + # Base fields should be inherited from the 
shared config + assert task.config.dataset_path == "json", ( + "Should inherit dataset_path from base" + ) + assert task.config.output_type == "multiple_choice", ( + "Should inherit output_type from base" + ) + assert task.config.doc_to_text == "{{question}}", ( + "Should inherit doc_to_text from base" + ) + assert task.config.test_split == "test", "Should inherit test_split from base" + + # Default metric_list should be inherited (task_list_fs0 doesn't override it) + metric_names = [m["metric"] for m in task.config.metric_list] + assert "acc" in metric_names, "Should inherit metric_list from base" + + # Per-task override should still be applied + assert task.config.num_fewshot == 0, "Should have per-task num_fewshot override" + + def test_match_tasks_glob(self, shared_task_manager): + """match_tasks handles glob patterns""" + matches = shared_task_manager.match_tasks(["arc_*"]) + assert "arc_easy" in matches + assert "arc_challenge" in matches + + def test_name_is_registered(self, shared_task_manager): + """_name_is_registered checks if name exists""" + assert shared_task_manager._name_is_registered("arc_easy") + assert not shared_task_manager._name_is_registered("nonexistent_task_xyz") + + def test_name_is_task(self, shared_task_manager): + """_name_is_task returns True for tasks""" + assert shared_task_manager._name_is_task("arc_easy") + assert not shared_task_manager._name_is_task("ai2_arc") # This is a tag + + def test_name_is_tag(self, shared_task_manager): + """_name_is_tag returns True for tags""" + assert shared_task_manager._name_is_tag("ai2_arc") + assert not shared_task_manager._name_is_tag("arc_easy") # This is a task + + def test_include_path_precedence(self, shared_task_manager): + """Test that user-specified include paths take precedence over default paths when tasks have the same name.""" + with tempfile.TemporaryDirectory() as custom_dir: + # Create a custom arc_easy.yaml that has a different metric + custom_task_content = """task: arc_easy +dataset_path: allenai/ai2_arc +dataset_name: ARC-Easy +output_type: multiple_choice +training_split: train +validation_split: validation +test_split: test +doc_to_text: "Custom Question: {{question}}\\nAnswer:" +doc_to_target: "{{choices.label.index(answerKey)}}" +doc_to_choice: "{{choices.text}}" +metric_list: + - metric: f1 + aggregation: mean + higher_is_better: true +metadata: + version: 2.0 + custom: true +""" + # Write the custom task file + custom_task_path = Path(custom_dir) / "arc_easy.yaml" + custom_task_path.write_text(custom_task_content) + + # Test 1: User path should override default when include_defaults=True + task_manager = TaskManager(include_defaults=True, include_path=custom_dir) + + # Load the task + task_dict = task_manager.load_task_or_group(["arc_easy"]) + arc_easy_task = task_dict["arc_easy"] + + # Check that the custom version was loaded (has f1 metric and custom doc_to_text) + assert any( + metric["metric"] == "f1" + for metric in arc_easy_task.config["metric_list"] + ), "Custom task should have f1 metric" + assert "Custom Question:" in arc_easy_task.config["doc_to_text"], ( + "Custom task should have custom doc_to_text" + ) + assert arc_easy_task.config["metadata"]["version"] == 2.0, ( + "Custom task should have version 2.0" + ) + + # Test 2: Verify default is used when no custom path is provided + # Use shared_task_manager instead of creating a new one (saves ~9s) + default_task_dict = shared_task_manager.load_task_or_group(["arc_easy"]) + default_arc_easy = default_task_dict["arc_easy"] + + # Default 
should not have f1 metric or custom text
+            assert not any(
+                metric["metric"] == "f1"
+                for metric in default_arc_easy.config.get("metric_list", [])
+            ), "Default task should not have f1 metric"
+            assert "Custom Question:" not in default_arc_easy.config["doc_to_text"], (
+                "Default task should not have custom doc_to_text"
+            )
+
+    def test_include_defaults_false_with_custom_path(self):
+        """Test that when include_defaults=False, only custom tasks are available."""
+        with tempfile.TemporaryDirectory() as custom_dir:
+            # Create a custom task using a real dataset
+            custom_task_content = """task: custom_arc_task
+dataset_path: allenai/ai2_arc
+dataset_name: ARC-Challenge
+output_type: multiple_choice
+training_split: train
+validation_split: validation
+test_split: test
+doc_to_text: "Q: {{question}}\\nA:"
+doc_to_target: "{{choices.label.index(answerKey)}}"
+doc_to_choice: "{{choices.text}}"
+metric_list:
+  - metric: acc
+    aggregation: mean
+    higher_is_better: true
+metadata:
+  version: 1.0
+  custom: true
+"""
+            # Write the custom task file
+            custom_task_path = Path(custom_dir) / "custom_arc_task.yaml"
+            custom_task_path.write_text(custom_task_content)
+
+            # Initialize with include_defaults=False
+            task_manager = TaskManager(include_defaults=False, include_path=custom_dir)
+
+            # Custom task should be available
+            assert "custom_arc_task" in task_manager.all_tasks, (
+                "Custom task should be available when include_defaults=False"
+            )
+
+            # Default tasks should NOT be available
+            assert "arc_easy" not in task_manager.all_tasks, (
+                "Default arc_easy should not be available when include_defaults=False"
+            )
+            assert "arc_challenge" not in task_manager.all_tasks, (
+                "Default arc_challenge should not be available when include_defaults=False"
+            )
+
+            # Check that only our custom task is present
+            assert len(task_manager.all_tasks) == 1, (
+                f"Should only have 1 task, but found {len(task_manager.all_tasks)}"
+            )
+
+            # Check task metadata using Entry object API
+            entry = task_manager.task_index["custom_arc_task"]
+            assert entry.kind == Kind.TASK
+            assert custom_dir in str(entry.yaml_path)
+
+    def test_include_defaults_true_with_new_tasks(self, shared_task_manager):
+        """Test that new tasks from include_path are added alongside default tasks."""
+        with tempfile.TemporaryDirectory() as custom_dir:
+            # Create a completely new task (not overriding any default)
+            new_task_content = """task: arc_custom_generation
+dataset_path: allenai/ai2_arc
+dataset_name: ARC-Easy
+output_type: generate_until
+training_split: train
+validation_split: validation
+test_split: test
+doc_to_text: "Question: {{question}}\\nGenerate answer:"
+doc_to_target: "{{choices.text[choices.label.index(answerKey)]}}"
+generation_kwargs:
+  max_gen_toks: 50
+  temperature: 0.1
+  until:
+    - "\\n"
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+metadata:
+  version: 1.0
+  custom_benchmark: true
+"""
+            # Write the new task file
+            new_task_path = Path(custom_dir) / "arc_custom_generation.yaml"
+            new_task_path.write_text(new_task_content)
+
+            # Initialize with include_defaults=True (default behavior)
+            task_manager = TaskManager(include_defaults=True, include_path=custom_dir)
+
+            # Both custom and default tasks should be available
+            assert "arc_custom_generation" in task_manager.all_tasks, (
+                "New custom task should be available"
+            )
+            assert "arc_easy" in task_manager.all_tasks, (
+                "Default arc_easy should still be available"
+            )
+            assert "arc_challenge" in task_manager.all_tasks, (
+                "Default arc_challenge
should still be available" + ) + + # Check task metadata using Entry object API + entry = task_manager.task_index["arc_custom_generation"] + assert entry.kind == Kind.TASK + assert custom_dir in str(entry.yaml_path) + + # Verify the counts - should have more tasks than just defaults + assert len(task_manager.all_tasks) > len(shared_task_manager.all_tasks), ( + "Should have more tasks when including custom path" + ) + + def test_tag_expansion_in_group(self, test_configs_task_manager): + """Test that TAGs inside groups are expanded and each task is namespaced individually. + + This tests the MMLU-like structure: GROUP -> TAG -> multiple tasks + Without proper TAG handling, all tasks in the tag get the same namespaced name + and collide, leaving only one task. + """ + # Load the subgroup that contains a TAG reference + result = test_configs_task_manager.load_task_or_group(["tag_subgroup"]) + + # Get the children dict from the group + group_key = list(result.keys())[0] + children = result[group_key] + + # All 3 tasks from the tag should be expanded and namespaced + assert "tag_subgroup::tag_task_1" in children, ( + "tag_task_1 should be namespaced under tag_subgroup" + ) + assert "tag_subgroup::tag_task_2" in children, ( + "tag_task_2 should be namespaced under tag_subgroup" + ) + assert "tag_subgroup::tag_task_3" in children, ( + "tag_task_3 should be namespaced under tag_subgroup" + ) + + # Verify we have exactly 3 tasks (not 1 due to collision) + assert len(children) == 3, ( + f"Should have 3 tasks from TAG expansion, got {len(children)}" + ) + + def test_nested_group_with_tag(self, test_configs_task_manager): + """Test nested groups with TAG: parent_group -> subgroup -> TAG -> tasks. + + This simulates the full MMLU structure where: + - mmlu (GROUP) contains mmlu_humanities (GROUP) + - mmlu_humanities contains mmlu_humanities_tasks (TAG) + - The TAG expands to individual tasks + """ + # Load the parent group + result = test_configs_task_manager.load_task_or_group(["tag_parent_group"]) + + # Navigate the nested structure + parent_key = list(result.keys())[0] + parent_children = result[parent_key] + + # Should contain the subgroup + assert len(parent_children) == 1, "Parent should have 1 child (the subgroup)" + + # Get the subgroup + subgroup_key = list(parent_children.keys())[0] + subgroup_children = parent_children[subgroup_key] + + # The subgroup should have all 3 tasks expanded from the TAG + # Tasks are namespaced under their immediate parent group (tag_subgroup) + assert "tag_subgroup::tag_task_1" in subgroup_children + assert "tag_subgroup::tag_task_2" in subgroup_children + assert "tag_subgroup::tag_task_3" in subgroup_children + assert len(subgroup_children) == 3, ( + f"Subgroup should have 3 tasks, got {len(subgroup_children)}" + ) + + +# ============================================================================= +# Hierarchical Task Tests (children: syntax) +# ============================================================================= + + +class TestHierarchicalTasks: + """Tests for the new hierarchical task system with children: syntax.""" + + @pytest.fixture + def hierarchical_task_manager(self): + """TaskManager with test_configs including hierarchical groups.""" + test_configs_path = Path(__file__).parent / "test_configs" + return TaskManager(include_path=str(test_configs_path), include_defaults=False) + + def test_hierarchical_group_indexed(self, hierarchical_task_manager): + """Test that hierarchical group and its inline children are indexed.""" + tm = 
hierarchical_task_manager + + # The group should be indexed + assert "hierarchical_group" in tm.all_groups + + # Inline children should be indexed with :: paths + assert "hierarchical_group::inline_task_a" in tm.task_index + assert "hierarchical_group::inline_task_b" in tm.task_index + + # They should be TASK kind + assert tm.task_index["hierarchical_group::inline_task_a"].kind == Kind.TASK + assert tm.task_index["hierarchical_group::inline_task_b"].kind == Kind.TASK + + def test_hierarchical_group_parent_tracking(self, hierarchical_task_manager): + """Test that inline children track their parent.""" + tm = hierarchical_task_manager + + entry_a = tm.task_index["hierarchical_group::inline_task_a"] + assert entry_a.parent == "hierarchical_group" + + entry_b = tm.task_index["hierarchical_group::inline_task_b"] + assert entry_b.parent == "hierarchical_group" + + def test_hierarchical_group_load(self, hierarchical_task_manager): + """Test loading a hierarchical group builds all inline children.""" + tm = hierarchical_task_manager + + result = tm.load_task_or_group(["hierarchical_group"]) + + # Should have group structure + group_key = list(result.keys())[0] + children = result[group_key] + + # Both inline tasks should be present + assert "hierarchical_group::inline_task_a" in children + assert "hierarchical_group::inline_task_b" in children + assert len(children) == 2 + + def test_hierarchical_direct_task_access(self, hierarchical_task_manager): + """Test loading an inline task directly by its :: path.""" + tm = hierarchical_task_manager + + # Should be able to load inline task directly + result = tm.load_task_or_group(["hierarchical_group::inline_task_a"]) + + assert "hierarchical_group::inline_task_a" in result + task = result["hierarchical_group::inline_task_a"] + assert task.config.task == "hierarchical_group::inline_task_a" + + def test_nested_hierarchical_group_indexed(self, hierarchical_task_manager): + """Test that nested hierarchical groups are properly indexed.""" + tm = hierarchical_task_manager + + # Top-level group + assert "nested_group" in tm.all_groups + + # Subgroups should be indexed with :: paths + assert "nested_group::subgroup_a" in tm.task_index + assert "nested_group::subgroup_b" in tm.task_index + + # Nested tasks should be indexed with full :: paths + assert "nested_group::subgroup_a::task_1" in tm.task_index + assert "nested_group::subgroup_a::task_2" in tm.task_index + assert "nested_group::subgroup_b::task_3" in tm.task_index + + def test_nested_hierarchical_parent_chain(self, hierarchical_task_manager): + """Test that nested children have correct parent chain.""" + tm = hierarchical_task_manager + + # Subgroup's parent is the top group + subgroup_a = tm.task_index["nested_group::subgroup_a"] + assert subgroup_a.parent == "nested_group" + + # Task's parent is the subgroup + task_1 = tm.task_index["nested_group::subgroup_a::task_1"] + assert task_1.parent == "nested_group::subgroup_a" + + def test_nested_hierarchical_load(self, hierarchical_task_manager): + """Test loading nested hierarchical group builds full tree.""" + tm = hierarchical_task_manager + + result = tm.load_task_or_group(["nested_group"]) + + # Navigate the structure + top_key = list(result.keys())[0] + top_children = result[top_key] + + # Should have 2 subgroups + assert len(top_children) == 2 + + # Find subgroup_a's children + subgroup_a_key = None + for key in top_children: + if hasattr(key, "config") and key.config.get("group") == "subgroup_a": + subgroup_a_key = key + break + + if subgroup_a_key: + 
subgroup_a_children = top_children[subgroup_a_key] + assert "nested_group::subgroup_a::task_1" in subgroup_a_children + assert "nested_group::subgroup_a::task_2" in subgroup_a_children + + def test_nested_direct_subgroup_access(self, hierarchical_task_manager): + """Test loading a subgroup directly by its :: path.""" + tm = hierarchical_task_manager + + result = tm.load_task_or_group(["nested_group::subgroup_a"]) + + # Should get the subgroup + group_key = list(result.keys())[0] + children = result[group_key] + + # Should have the 2 tasks from subgroup_a + assert "nested_group::subgroup_a::task_1" in children + assert "nested_group::subgroup_a::task_2" in children + assert len(children) == 2 + + def test_nested_direct_task_access(self, hierarchical_task_manager): + """Test loading a deeply nested task directly.""" + tm = hierarchical_task_manager + + result = tm.load_task_or_group(["nested_group::subgroup_a::task_1"]) + + assert "nested_group::subgroup_a::task_1" in result + task = result["nested_group::subgroup_a::task_1"] + assert task.config.task == "nested_group::subgroup_a::task_1" + + def test_hierarchical_with_ref(self, hierarchical_task_manager): + """Test hierarchical group with ref: to external task.""" + tm = hierarchical_task_manager + + # The ref entry should be indexed + assert "hierarchical_refs_group::external_ref" in tm.task_index + + # It should have ref_target set + entry = tm.task_index["hierarchical_refs_group::external_ref"] + assert entry.ref_target == "simple_task" + + def test_hierarchical_with_tag(self, hierarchical_task_manager): + """Test hierarchical group with tag: expansion.""" + tm = hierarchical_task_manager + + # The tag ref entry should be indexed + assert "hierarchical_refs_group::tagged_tasks" in tm.task_index + + # It should have tag_ref set + entry = tm.task_index["hierarchical_refs_group::tagged_tasks"] + assert entry.tag_ref == "test_tag_tasks" + + def test_hierarchical_ref_resolution(self, hierarchical_task_manager): + """Test that ref: children resolve to their targets when built.""" + tm = hierarchical_task_manager + + result = tm.load_task_or_group(["hierarchical_refs_group"]) + + group_key = list(result.keys())[0] + children = result[group_key] + + # The inline task should be present + assert "hierarchical_refs_group::inline_task" in children + + # The ref should resolve to simple_task + # Note: the resolved task keeps its original name + assert "simple_task" in children + + def test_hierarchical_tag_expansion(self, hierarchical_task_manager): + """Test that tag: children expand to tagged tasks when built.""" + tm = hierarchical_task_manager + + result = tm.load_task_or_group(["hierarchical_refs_group"]) + + group_key = list(result.keys())[0] + children = result[group_key] + + # The tag should expand to all tasks with test_tag_tasks tag + # (tag_task_1, tag_task_2, tag_task_3 from test_configs) + assert "tag_task_1" in children + assert "tag_task_2" in children + assert "tag_task_3" in children
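
A few illustrative sketches follow for readers tracing these tests. First, include resolution: TestConfigLoader pins down that child keys override included keys, that includes chain recursively, and that cycles raise a ValueError mentioning "Include cycle". Below is a minimal reimplementation consistent with those assertions; it is not the actual lm_eval.tasks._config_loader.load_yaml, which also resolves !function tags and supports other modes.

# Sketch only -- the real load_yaml also handles !function tags.
from pathlib import Path

import yaml


def load_yaml_sketch(path: Path, _seen: frozenset = frozenset()) -> dict:
    path = path.resolve()
    if path in _seen:
        # Mirrors the "Include cycle" ValueError asserted in
        # test_load_yaml_cycle_detection.
        raise ValueError(f"Include cycle detected at {path}")
    cfg = yaml.safe_load(path.read_text()) or {}
    include = cfg.pop("include", None)
    if include is None:
        return cfg
    # Includes resolve relative to the including file; the child's own
    # keys win over inherited ones (test_load_yaml_with_include).
    base = load_yaml_sketch(path.parent / include, _seen | {path})
    return {**base, **cfg}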
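
Similarly, the four test_kind_detection_* cases fix the classification rules for configs. The sketch below is a hypothetical helper that satisfies exactly those cases, reusing the Kind enum the tests import; the real TaskIndex._kind_of may order its checks differently or cover more shapes.

# Hypothetical helper matching the test_kind_detection_* cases above.
from lm_eval.tasks.index import Kind


def kind_of_sketch(cfg: dict) -> Kind:
    if "class" in cfg:  # a Python-class task wins even if "task" is set
        return Kind.PY_TASK
    if "task_list" in cfg:  # several tasks sharing one base config
        return Kind.TASK_LIST
    if "group" in cfg or isinstance(cfg.get("task"), list):
        return Kind.GROUP  # groups name their members in a list
    if isinstance(cfg.get("task"), str):
        return Kind.TASK
    raise ValueError(f"unclassifiable config keys: {sorted(cfg)}")


assert kind_of_sketch({"task": "t", "dataset_path": "d"}) is Kind.TASK
assert kind_of_sketch({"task": "t", "class": "SomeClass"}) is Kind.PY_TASK
assert kind_of_sketch({"group": "g", "task": ["t1", "t2"]}) is Kind.GROUP
assert kind_of_sketch({"task_list": [{"task": "t1"}]}) is Kind.TASK_LIST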
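
The task_list.yaml fixture together with test_task_list_overrides and test_task_list_base_field_inheritance implies a simple merge rule: every entry inherits the shared top-level fields, and any key an entry redefines wins. A sketch of that expansion under those assumptions (expand_task_list is an illustrative helper, not the harness's actual expansion code):

# Illustrative expansion of a task_list config.
def expand_task_list(cfg: dict) -> dict:
    base = {k: v for k, v in cfg.items() if k != "task_list"}
    expanded = {}
    for entry in cfg["task_list"]:
        merged = {**base, **entry}  # per-task keys win over the base
        expanded[entry["task"]] = merged
    return expanded


cfg = {
    "dataset_path": "json",
    "output_type": "multiple_choice",
    "metric_list": [{"metric": "acc"}],
    "task_list": [
        {"task": "task_list_fs0", "num_fewshot": 0},
        {"task": "task_list_fs3", "num_fewshot": 3,
         "metric_list": [{"metric": "acc"}, {"metric": "acc_norm"}]},
    ],
}
tasks = expand_task_list(cfg)
assert tasks["task_list_fs0"]["num_fewshot"] == 0
assert tasks["task_list_fs0"]["metric_list"] == [{"metric": "acc"}]  # inherited
assert len(tasks["task_list_fs3"]["metric_list"]) == 2  # overridden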
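
Finally, the hierarchical tests rely on a "::" path convention: each segment names a child of the preceding group, and every indexed Entry records its immediate parent. A tiny illustration of that convention (parent_of is a hypothetical helper, not harness code):

# Hypothetical helper showing the "::" parent-chain convention.
def parent_of(qualified_name: str) -> str | None:
    head, sep, _ = qualified_name.rpartition("::")
    return head if sep else None


assert parent_of("nested_group::subgroup_a::task_1") == "nested_group::subgroup_a"
assert parent_of("nested_group::subgroup_a") == "nested_group"
assert parent_of("nested_group") is None  # top-level groups have no parent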