Skip to content

Commit 65fc387

Browse files
committed
unroll tags in groups; add tests
1 parent ff39c24 commit 65fc387

File tree

8 files changed

+172
-64
lines changed

8 files changed

+172
-64
lines changed

lm_eval/tasks/factory.py

Lines changed: 51 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,14 @@
77

88
from lm_eval.api.group import ConfigurableGroup, GroupConfig
99
from lm_eval.api.task import ConfigurableTask
10-
from lm_eval.tasks._config_loader import load_yaml as load_cfg
10+
from lm_eval.tasks._config_loader import load_yaml
1111
from lm_eval.tasks.index import Entry, Kind
1212

1313

14-
load_cfg_cached = load_cfg # type: ignore[no-redef]
15-
16-
1714
class TaskFactory:
1815
"""
1916
Turns an *Entry* (plus optional overrides) into a
20-
*Task* (from task_v3) | *ConfigurableTask* | *GroupConfig* hierarchy.
21-
22-
For YAML tasks, uses the task_v3.Task builder pattern to automatically
23-
select the appropriate Task subclass based on output_type.
17+
*Task* | *ConfigurableTask* | *GroupConfig* hierarchy.
2418
"""
2519

2620
def __init__(self, *, meta: dict[str, Any] | None = None):
@@ -35,9 +29,9 @@ def build(
3529
registry: Mapping[str, Entry],
3630
):
3731
"""
38-
entry.kind == TASK / PY_TASK returns instantiated task object
39-
entry.kind == GROUP returns (GroupConfig, mapping-of-subtasks)
40-
entry.kind == TAG returns mapping-of-tasks (tag expansion)
32+
* entry.kind == TASK / PY_TASK -> returns instantiated task object
33+
* entry.kind == GROUP -> returns (GroupConfig, mapping-of-subtasks)
34+
* entry.kind == TAG -> returns mapping-of-tasks (tag expansion)
4135
"""
4236
if entry.kind is Kind.TAG:
4337
return self._build_tag(entry, overrides, registry)
@@ -47,20 +41,22 @@ def build(
4741

4842
return self._build_task(entry, overrides)
4943

50-
def _build_task(self, entry: Entry, overrides: dict[str, Any] | None) -> dict:
44+
def _build_task(self, entry: Entry, overrides: dict[str, Any] | None):
5145
"""Build a task and return it wrapped in a dict {task_name: task_obj}."""
5246
cfg = self._load_full_config(entry, overrides)
47+
# Use cfg["task"] as key (may be overridden, e.g., for namespacing)
48+
task_name = cfg["task"]
5349

5450
if "class" in cfg: # PY_TASK route
5551
cls = cfg["class"]
5652
obj = cls(config=cfg) if _ctor_accepts_config(cls) else cls()
5753
if hasattr(obj, "config") and hasattr(obj.config, "task"):
58-
obj.config.task = entry.name
59-
return {entry.name: obj}
54+
obj.config.task = task_name
55+
return {task_name: obj}
6056

6157
# Regular YAML task - use ConfigurableTask
6258
task_obj = ConfigurableTask(config=cfg)
63-
return {entry.name: task_obj}
59+
return {task_name: task_obj}
6460

6561
def _build_group(
6662
self,
@@ -71,46 +67,58 @@ def _build_group(
7167
raw_cfg = self._load_full_config(entry, None)
7268
grp_cfg = {k: v for k, v in raw_cfg.items() if k in GroupConfig.__annotations__}
7369
grp_cfg["metadata"] = grp_cfg.get("metadata", {}) | self._meta
74-
# Use ConfigurableGroup (hashable) instead of GroupConfig (dict, unhashable)
7570
group_obj = ConfigurableGroup(config=grp_cfg)
7671
group_name = entry.name
7772

7873
children: dict[str, Any] = {}
7974
for item in group_obj.config["task"]:
75+
# Step 1: Normalize - extract base_name and item_overrides
8076
if isinstance(item, str):
81-
# Case 1: String reference - look up in registry
8277
base_name = item
83-
child = self.build(
84-
registry[item],
85-
overrides=overrides, # group-level overrides propagate
86-
registry=registry,
87-
)
78+
item_overrides = overrides or {}
8879
elif isinstance(item, dict):
8980
base_name = item["task"]
90-
if base_name in registry:
91-
# Case 2: Modify existing indexed task
92-
child = self.build(
93-
registry[base_name],
94-
overrides=item, # per-item override
95-
registry=registry,
96-
)
97-
else:
98-
# Case 3: Create new task inline (not indexed)
99-
task_cfg = {**item}
100-
task_cfg["metadata"] = task_cfg.get("metadata", {}) | self._meta
101-
task_obj = ConfigurableTask(config=task_cfg)
102-
child = {base_name: task_obj}
81+
item_overrides = item
10382
else:
10483
raise TypeError(
10584
f"Unsupported sub-entry {item!r} in group '{entry.name}'"
10685
)
10786

108-
# Namespace ALL child tasks with group_name::task_name
109-
namespaced_child = {}
110-
for task_name, task_obj in child.items():
111-
namespaced_name = f"{group_name}::{task_name}"
112-
namespaced_child[namespaced_name] = task_obj
113-
children.update(namespaced_child)
87+
# Step 2: Handle inline task (not in registry)
88+
if base_name not in registry:
89+
namespaced = f"{group_name}::{base_name}"
90+
task_cfg = {**item_overrides, "task": namespaced}
91+
task_cfg["metadata"] = task_cfg.get("metadata", {}) | self._meta
92+
children[namespaced] = ConfigurableTask(config=task_cfg)
93+
continue
94+
95+
# Step 3: Build based on entry kind
96+
child_entry = registry[base_name]
97+
98+
if child_entry.kind is Kind.GROUP:
99+
child = self.build(
100+
child_entry, overrides=item_overrides, registry=registry
101+
)
102+
elif child_entry.kind is Kind.TAG:
103+
child = {}
104+
for task_name in child_entry.tags:
105+
namespaced = f"{group_name}::{task_name}"
106+
child.update(
107+
self.build(
108+
registry[task_name],
109+
overrides={"task": namespaced, **item_overrides},
110+
registry=registry,
111+
)
112+
)
113+
else: # TASK or PY_TASK
114+
namespaced = f"{group_name}::{base_name}"
115+
child = self.build(
116+
child_entry,
117+
overrides={"task": namespaced, **item_overrides},
118+
registry=registry,
119+
)
120+
121+
children.update(child)
114122

115123
return {group_obj: children}
116124

@@ -119,7 +127,7 @@ def _build_tag(
119127
entry: Entry,
120128
overrides: dict[str, Any] | None,
121129
registry: Mapping[str, Entry],
122-
) -> dict:
130+
):
123131
"""Build all tasks in a tag and return merged dict."""
124132
result = {}
125133
for name in entry.tags:
@@ -130,7 +138,7 @@ def _load_full_config(
130138
self, entry: Entry, overrides: dict[str, Any] | None
131139
) -> dict[str, Any]:
132140
if entry.yaml_path:
133-
cfg = deepcopy(load_cfg_cached(entry.yaml_path, resolve_func=True))
141+
cfg = deepcopy(load_yaml(entry.yaml_path, resolve_func=True))
134142
else:
135143
cfg: dict[str, Any] = {
136144
"metadata": {"config": "unknown"}

lm_eval/tasks/index.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def process_cfg(
9090
if kind is Kind.GROUP:
9191
grp_name = cfg["group"]
9292
if grp_name in index:
93-
log.warning(
93+
log.debug(
9494
f"Duplicate group name '{grp_name}' found. "
9595
f"Already registered from: {index[grp_name].yaml_path}. "
9696
f"Skipping duplicate from: {path}"
@@ -105,26 +105,7 @@ def process_cfg(
105105
)
106106
return
107107

108-
if kind is Kind.PY_TASK:
109-
name = cfg["task"]
110-
if name in index:
111-
log.warning(
112-
f"Duplicate task name '{name}' found. "
113-
f"Already registered from: {index[name].yaml_path}. "
114-
f"Skipping duplicate from: {path}"
115-
)
116-
return
117-
index[name] = Entry(
118-
name=name,
119-
kind=Kind.PY_TASK,
120-
yaml_path=path,
121-
tags=TaskIndex._str_to_set(cfg.get("tag")),
122-
cfg=cfg,
123-
)
124-
TaskIndex._register_tags(name, cfg.get("tag"), index)
125-
return
126-
127-
if kind is Kind.TASK:
108+
if kind is Kind.TASK or kind is Kind.PY_TASK:
128109
name = cfg["task"]
129110
if name in index:
130111
log.warning(
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Parent group containing a subgroup - simulates mmlu pattern
2+
# Structure: tag_parent_group -> tag_subgroup -> test_tag_tasks (TAG) -> tag_task_1, tag_task_2, tag_task_3
3+
group: tag_parent_group
4+
task:
5+
- tag_subgroup
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Subgroup that references a TAG - simulates mmlu_humanities pattern
2+
# This group contains a tag reference (test_tag_tasks) which expands to multiple tasks
3+
group: tag_subgroup
4+
task:
5+
- test_tag_tasks

tests/test_configs/tag_task_1.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Task 1 with tag - simulates mmlu_formal_logic pattern
2+
task: tag_task_1
3+
tag: test_tag_tasks
4+
dataset_path: json
5+
dataset_kwargs:
6+
data_files:
7+
test: tests/test_configs/test_data.json
8+
output_type: multiple_choice
9+
doc_to_text: "{{question}}"
10+
doc_to_target: "{{choices[answer]}}"
11+
test_split: test
12+
num_fewshot: 0
13+
metric_list:
14+
- metric: acc
15+
aggregation: mean
16+
higher_is_better: true

tests/test_configs/tag_task_2.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Task 2 with tag - simulates mmlu_high_school_history pattern
2+
task: tag_task_2
3+
tag: test_tag_tasks
4+
dataset_path: json
5+
dataset_kwargs:
6+
data_files:
7+
test: tests/test_configs/test_data.json
8+
output_type: multiple_choice
9+
doc_to_text: "{{question}}"
10+
doc_to_target: "{{choices[answer]}}"
11+
test_split: test
12+
num_fewshot: 0
13+
metric_list:
14+
- metric: acc
15+
aggregation: mean
16+
higher_is_better: true

tests/test_configs/tag_task_3.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Task 3 with tag - simulates mmlu_philosophy pattern
2+
task: tag_task_3
3+
tag: test_tag_tasks
4+
dataset_path: json
5+
dataset_kwargs:
6+
data_files:
7+
test: tests/test_configs/test_data.json
8+
output_type: multiple_choice
9+
doc_to_text: "{{question}}"
10+
doc_to_target: "{{choices[answer]}}"
11+
test_split: test
12+
num_fewshot: 0
13+
metric_list:
14+
- metric: acc
15+
aggregation: mean
16+
higher_is_better: true

tests/test_task_manager.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,3 +745,64 @@ def test_include_defaults_true_with_new_tasks(self, shared_task_manager):
745745
assert len(task_manager.all_tasks) > len(shared_task_manager.all_tasks), (
746746
"Should have more tasks when including custom path"
747747
)
748+
749+
def test_tag_expansion_in_group(self, test_configs_task_manager):
750+
"""Test that TAGs inside groups are expanded and each task is namespaced individually.
751+
752+
This tests the MMLU-like structure: GROUP -> TAG -> multiple tasks
753+
Without proper TAG handling, all tasks in the tag get the same namespaced name
754+
and collide, leaving only one task.
755+
"""
756+
# Load the subgroup that contains a TAG reference
757+
result = test_configs_task_manager.load_task_or_group(["tag_subgroup"])
758+
759+
# Get the children dict from the group
760+
group_key = list(result.keys())[0]
761+
children = result[group_key]
762+
763+
# All 3 tasks from the tag should be expanded and namespaced
764+
assert "tag_subgroup::tag_task_1" in children, (
765+
"tag_task_1 should be namespaced under tag_subgroup"
766+
)
767+
assert "tag_subgroup::tag_task_2" in children, (
768+
"tag_task_2 should be namespaced under tag_subgroup"
769+
)
770+
assert "tag_subgroup::tag_task_3" in children, (
771+
"tag_task_3 should be namespaced under tag_subgroup"
772+
)
773+
774+
# Verify we have exactly 3 tasks (not 1 due to collision)
775+
assert len(children) == 3, (
776+
f"Should have 3 tasks from TAG expansion, got {len(children)}"
777+
)
778+
779+
def test_nested_group_with_tag(self, test_configs_task_manager):
780+
"""Test nested groups with TAG: parent_group -> subgroup -> TAG -> tasks.
781+
782+
This simulates the full MMLU structure where:
783+
- mmlu (GROUP) contains mmlu_humanities (GROUP)
784+
- mmlu_humanities contains mmlu_humanities_tasks (TAG)
785+
- The TAG expands to individual tasks
786+
"""
787+
# Load the parent group
788+
result = test_configs_task_manager.load_task_or_group(["tag_parent_group"])
789+
790+
# Navigate the nested structure
791+
parent_key = list(result.keys())[0]
792+
parent_children = result[parent_key]
793+
794+
# Should contain the subgroup
795+
assert len(parent_children) == 1, "Parent should have 1 child (the subgroup)"
796+
797+
# Get the subgroup
798+
subgroup_key = list(parent_children.keys())[0]
799+
subgroup_children = parent_children[subgroup_key]
800+
801+
# The subgroup should have all 3 tasks expanded from the TAG
802+
# Tasks are namespaced under their immediate parent group (tag_subgroup)
803+
assert "tag_subgroup::tag_task_1" in subgroup_children
804+
assert "tag_subgroup::tag_task_2" in subgroup_children
805+
assert "tag_subgroup::tag_task_3" in subgroup_children
806+
assert len(subgroup_children) == 3, (
807+
f"Subgroup should have 3 tasks, got {len(subgroup_children)}"
808+
)

0 commit comments

Comments
 (0)