Skip to content

Commit bd80bf4

Browse files
committed
add hierarchical groups
1 parent 276d3cf commit bd80bf4

File tree

3 files changed

+393
-14
lines changed

3 files changed

+393
-14
lines changed

lm_eval/tasks/factory.py

Lines changed: 100 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,25 @@ def build(
3232
* entry.kind == TASK / PY_TASK -> returns instantiated task object
3333
* entry.kind == GROUP -> returns (GroupConfig, mapping-of-subtasks)
3434
* entry.kind == TAG -> returns mapping-of-tasks (tag expansion)
35+
* entry with ref_target -> resolves reference and builds target
36+
* entry with tag_ref -> expands tag and builds tasks
3537
"""
38+
# Handle external references (ref: in children)
39+
if entry.ref_target:
40+
if entry.ref_target not in registry:
41+
raise KeyError(
42+
f"Reference '{entry.ref_target}' not found for '{entry.name}'"
43+
)
44+
target_entry = registry[entry.ref_target]
45+
return self.build(target_entry, overrides=overrides, registry=registry)
46+
47+
# Handle tag expansion (tag: in children)
48+
if entry.tag_ref:
49+
if entry.tag_ref not in registry:
50+
raise KeyError(f"Tag '{entry.tag_ref}' not found for '{entry.name}'")
51+
tag_entry = registry[entry.tag_ref]
52+
return self._build_tag(tag_entry, overrides, registry)
53+
3654
if entry.kind is Kind.TAG:
3755
return self._build_tag(entry, overrides, registry)
3856

@@ -44,6 +62,11 @@ def build(
4462
def _build_task(self, entry: Entry, overrides: dict[str, Any] | None):
4563
"""Build a task and return it wrapped in a dict {task_name: task_obj}."""
4664
cfg = self._load_full_config(entry, overrides)
65+
66+
# Remove structural keys that aren't part of task config
67+
for key in ("children", "ref", "tag", "group"):
68+
cfg.pop(key, None)
69+
4770
# Use cfg["task"] as key (may be overridden, e.g., for namespacing)
4871
task_name = cfg["task"]
4972

@@ -71,25 +94,91 @@ def _build_group(
7194
group_name = entry.name
7295

7396
children: dict[str, Any] = {}
74-
for item in group_obj.config["task"]:
97+
98+
# Handle new-style children: dict (hierarchical)
99+
if "children" in raw_cfg:
100+
children.update(
101+
self._build_children(
102+
raw_cfg["children"], group_name, overrides, registry
103+
)
104+
)
105+
106+
# Handle old-style task: list (backward compatibility)
107+
if "task" in grp_cfg and isinstance(grp_cfg["task"], list):
108+
children.update(
109+
self._build_task_list(grp_cfg["task"], group_name, overrides, registry)
110+
)
111+
112+
return {group_obj: children}
113+
114+
def _build_children(
115+
self,
116+
children_cfg: dict[str, Any],
117+
group_name: str,
118+
overrides: dict[str, Any] | None,
119+
registry: Mapping[str, Entry],
120+
) -> dict[str, Any]:
121+
"""Build children defined via children: dict."""
122+
result: dict[str, Any] = {}
123+
124+
for child_name, child_cfg in children_cfg.items():
125+
child_path = f"{group_name}::{child_name}"
126+
127+
# Look up pre-registered entry from index
128+
if child_path in registry:
129+
child_entry = registry[child_path]
130+
child_overrides = overrides or {}
131+
132+
# Merge any inline overrides from child_cfg (excluding structural keys)
133+
inline_overrides = {
134+
k: v
135+
for k, v in child_cfg.items()
136+
if k not in ("ref", "tag", "children")
137+
}
138+
if inline_overrides:
139+
child_overrides = {**child_overrides, **inline_overrides}
140+
141+
child = self.build(
142+
child_entry, overrides=child_overrides, registry=registry
143+
)
144+
result.update(child)
145+
else:
146+
# Fallback: inline task not pre-registered (shouldn't normally happen)
147+
task_cfg = {**child_cfg, "task": child_path}
148+
task_cfg["metadata"] = task_cfg.get("metadata", {}) | self._meta
149+
result[child_path] = ConfigurableTask(config=task_cfg)
150+
151+
return result
152+
153+
def _build_task_list(
154+
self,
155+
task_list: list,
156+
group_name: str,
157+
overrides: dict[str, Any] | None,
158+
registry: Mapping[str, Entry],
159+
) -> dict[str, Any]:
160+
"""Build children defined via task: list (backward compatibility)."""
161+
result: dict[str, Any] = {}
162+
163+
for item in task_list:
75164
# Step 1: Normalize - extract base_name and item_overrides
76165
if isinstance(item, str):
77166
base_name = item
78167
item_overrides = overrides or {}
79168
elif isinstance(item, dict):
80169
base_name = item["task"]
81-
item_overrides = item
170+
item_overrides = {**overrides, **item}
82171
else:
83172
raise TypeError(
84-
f"Unsupported sub-entry {item!r} in group '{entry.name}'"
173+
f"Unsupported sub-entry {item!r} in group '{group_name}'"
85174
)
86175

87176
# Step 2: Handle inline task (not in registry)
88177
if base_name not in registry:
89178
namespaced = f"{group_name}::{base_name}"
90179
task_cfg = {**item_overrides, "task": namespaced}
91180
task_cfg["metadata"] = task_cfg.get("metadata", {}) | self._meta
92-
children[namespaced] = ConfigurableTask(config=task_cfg)
181+
result[namespaced] = ConfigurableTask(config=task_cfg)
93182
continue
94183

95184
# Step 3: Build based on entry kind
@@ -118,9 +207,9 @@ def _build_group(
118207
registry=registry,
119208
)
120209

121-
children.update(child)
210+
result.update(child)
122211

123-
return {group_obj: children}
212+
return result
124213

125214
def _build_tag(
126215
self,
@@ -137,7 +226,11 @@ def _build_tag(
137226
def _load_full_config(
138227
self, entry: Entry, overrides: dict[str, Any] | None
139228
) -> dict[str, Any]:
140-
if entry.yaml_path:
229+
# For inline children (have parent), use the stored cfg directly
230+
# instead of loading from YAML (which would load the parent's full config)
231+
if entry.parent and entry.cfg:
232+
cfg = deepcopy(entry.cfg)
233+
elif entry.yaml_path:
141234
cfg = deepcopy(load_yaml(entry.yaml_path, resolve_func=True))
142235
else:
143236
cfg: dict[str, Any] = {

lm_eval/tasks/index.py

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ class Entry:
2929
cfg: dict[str, str] | None = None
3030
tags: set[str] = field(default_factory=set)
3131
task_list_path: Path | None = None
32+
# Hierarchical task support
33+
parent: str | None = (
34+
None # parent path for inline children (e.g., "mmlu" for "mmlu::stem")
35+
)
36+
ref_target: str | None = None # for children with ref: points to external entry
37+
tag_ref: str | None = None # for children with tag: expands to tagged tasks
3238

3339

3440
log = logging.getLogger(__name__)
@@ -85,24 +91,33 @@ def process_cfg(
8591
cfg: dict[str, Any],
8692
path: Path,
8793
index: dict[str, Entry],
94+
parent_path: str | None = None,
8895
) -> None:
8996
kind = TaskIndex._kind_of(cfg)
9097
if kind is Kind.GROUP:
9198
grp_name = cfg["group"]
92-
if grp_name in index:
99+
# Build full path for hierarchical addressing
100+
full_path = f"{parent_path}::{grp_name}" if parent_path else grp_name
101+
102+
if full_path in index:
93103
log.debug(
94-
f"Duplicate group name '{grp_name}' found. "
95-
f"Already registered from: {index[grp_name].yaml_path}. "
104+
f"Duplicate group name '{full_path}' found. "
105+
f"Already registered from: {index[full_path].yaml_path}. "
96106
f"Skipping duplicate from: {path}"
97107
)
98108
return
99-
index[grp_name] = Entry(
100-
name=grp_name,
109+
index[full_path] = Entry(
110+
name=full_path,
101111
kind=Kind.GROUP,
102112
yaml_path=path,
103113
tags=TaskIndex._str_to_set(cfg.get("tag")),
104114
cfg=cfg,
115+
parent=parent_path,
105116
)
117+
118+
# Process inline children if present
119+
if "children" in cfg:
120+
TaskIndex._process_children(cfg["children"], full_path, path, index)
106121
return
107122

108123
if kind is Kind.TASK or kind is Kind.PY_TASK:
@@ -204,3 +219,77 @@ def _str_to_set(tags: str | list[str] | None = None) -> set[str]:
204219
if isinstance(tags, str)
205220
else set()
206221
)
222+
223+
@staticmethod
224+
def _process_children(
225+
children: dict[str, Any],
226+
parent_path: str,
227+
yaml_path: Path,
228+
index: dict[str, Entry],
229+
) -> None:
230+
"""Process inline children definitions within a group.
231+
232+
Children can be:
233+
- Inline task: dict with task config fields (dataset_path, doc_to_text, etc.)
234+
- Inline subgroup: dict with 'children' key
235+
- External ref: dict with 'ref' key pointing to existing entry
236+
- Tag expansion: dict with 'tag' key to expand tagged tasks
237+
"""
238+
for child_name, child_cfg in children.items():
239+
if not isinstance(child_cfg, dict):
240+
log.warning(
241+
f"Invalid child config for '{child_name}' in '{parent_path}': "
242+
f"expected dict, got {type(child_cfg).__name__}"
243+
)
244+
continue
245+
246+
child_path = f"{parent_path}::{child_name}"
247+
248+
if child_path in index:
249+
log.debug(f"Duplicate child '{child_path}' found, skipping.")
250+
continue
251+
252+
if "ref" in child_cfg:
253+
# External reference - register with ref_target for build-time resolution
254+
index[child_path] = Entry(
255+
name=child_path,
256+
kind=Kind.GROUP, # Assume group, will resolve at build time
257+
yaml_path=yaml_path,
258+
parent=parent_path,
259+
ref_target=child_cfg["ref"],
260+
cfg=child_cfg,
261+
tags=TaskIndex._str_to_set(child_cfg.get("tag")),
262+
)
263+
264+
elif "tag" in child_cfg:
265+
# Tag expansion - register with tag_ref for build-time expansion
266+
index[child_path] = Entry(
267+
name=child_path,
268+
kind=Kind.TAG,
269+
yaml_path=yaml_path,
270+
parent=parent_path,
271+
tag_ref=child_cfg["tag"],
272+
cfg=child_cfg,
273+
tags=TaskIndex._str_to_set(child_cfg.get("tag")),
274+
)
275+
276+
elif "children" in child_cfg:
277+
# Nested inline group - recurse
278+
nested_cfg = {**child_cfg, "group": child_name}
279+
TaskIndex.process_cfg(
280+
nested_cfg, yaml_path, index, parent_path=parent_path
281+
)
282+
283+
else:
284+
# Inline task definition
285+
task_cfg = {**child_cfg, "task": child_path}
286+
index[child_path] = Entry(
287+
name=child_path,
288+
kind=Kind.TASK,
289+
yaml_path=yaml_path,
290+
parent=parent_path,
291+
cfg=task_cfg,
292+
tags=TaskIndex._str_to_set(child_cfg.get("tag")),
293+
)
294+
# Register tags for inline tasks
295+
TaskIndex._register_tags(child_path, child_cfg.get("tag"), index)

0 commit comments

Comments
 (0)