Skip to content

Commit ff39c24

Browse files
committed
namescape groups
1 parent 28728af commit ff39c24

File tree

6 files changed

+118
-33
lines changed

6 files changed

+118
-33
lines changed

lm_eval/tasks/_config_loader.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,14 @@ def load_yaml(
173173
merged = {}
174174
for inc in includes if isinstance(includes, list) else [includes]:
175175
inc_path = (path.parent / inc) if not Path(inc).is_absolute() else Path(inc)
176-
merged.update(
177-
load_yaml(
178-
inc_path,
179-
resolve_func=resolve_func,
180-
recursive=True,
181-
_seen=_seen,
182-
),
176+
inc_cfg = load_yaml(
177+
inc_path,
178+
resolve_func=resolve_func,
179+
recursive=True,
180+
_seen=_seen,
183181
)
182+
# Don't inherit task_list - it defines tasks for the included file only
183+
inc_cfg.pop("task_list", None)
184+
merged.update(inc_cfg)
184185
merged.update(cfg) # local keys win
185186
return merged
Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
include: arc_easy.yaml
2-
task: arc_challenge
3-
dataset_name: ARC-Challenge

lm_eval/tasks/arc/arc_easy.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,7 @@ metric_list:
2121
higher_is_better: true
2222
metadata:
2323
version: 1.0
24+
25+
task_list:
26+
- task: arc_challenge
27+
dataset_name: ARC-Challenge

lm_eval/tasks/factory.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,29 +73,45 @@ def _build_group(
7373
grp_cfg["metadata"] = grp_cfg.get("metadata", {}) | self._meta
7474
# Use ConfigurableGroup (hashable) instead of GroupConfig (dict, unhashable)
7575
group_obj = ConfigurableGroup(config=grp_cfg)
76+
group_name = entry.name
7677

7778
children: dict[str, Any] = {}
7879
for item in group_obj.config["task"]:
79-
if isinstance(item, str): # task: hellaswag
80+
if isinstance(item, str):
81+
# Case 1: String reference - look up in registry
82+
base_name = item
8083
child = self.build(
8184
registry[item],
8285
overrides=overrides, # group-level overrides propagate
8386
registry=registry,
8487
)
85-
elif isinstance(item, dict): # task: {task: hellaswag, num_fewshot: 5}
88+
elif isinstance(item, dict):
8689
base_name = item["task"]
87-
child = self.build(
88-
registry[base_name],
89-
overrides=item, # per-item override
90-
registry=registry,
91-
)
90+
if base_name in registry:
91+
# Case 2: Modify existing indexed task
92+
child = self.build(
93+
registry[base_name],
94+
overrides=item, # per-item override
95+
registry=registry,
96+
)
97+
else:
98+
# Case 3: Create new task inline (not indexed)
99+
task_cfg = {**item}
100+
task_cfg["metadata"] = task_cfg.get("metadata", {}) | self._meta
101+
task_obj = ConfigurableTask(config=task_cfg)
102+
child = {base_name: task_obj}
92103
else:
93104
raise TypeError(
94105
f"Unsupported sub-entry {item!r} in group '{entry.name}'"
95106
)
96107

97-
# `child` itself is a mapping (task-name -> obj) or {ConfigurableGroup: ...}
98-
children.update(child)
108+
# Namespace ALL child tasks with group_name::task_name
109+
namespaced_child = {}
110+
for task_name, task_obj in child.items():
111+
namespaced_name = f"{group_name}::{task_name}"
112+
namespaced_child[namespaced_name] = task_obj
113+
children.update(namespaced_child)
114+
99115
return {group_obj: children}
100116

101117
def _build_tag(

lm_eval/tasks/index.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,16 +144,45 @@ def process_cfg(
144144
return
145145

146146
if kind is Kind.TASK_LIST:
147+
# If config also has a top-level "task", register it as the base task
148+
if "task" in cfg and isinstance(cfg["task"], str):
149+
base_name = cfg["task"]
150+
if base_name not in index:
151+
index[base_name] = Entry(
152+
name=base_name,
153+
kind=Kind.TASK,
154+
yaml_path=path,
155+
tags=TaskIndex._str_to_set(cfg.get("tag")),
156+
cfg=cfg,
157+
)
158+
TaskIndex._register_tags(base_name, cfg.get("tag"), index)
159+
160+
# Register each task in task_list
161+
base_tag = cfg.get("tag")
147162
for entry in cfg["task_list"]:
148163
task_name = entry["task"] if isinstance(entry, dict) else entry
164+
if task_name in index:
165+
log.warning(
166+
f"Duplicate task name '{task_name}' found. "
167+
f"Already registered from: {index[task_name].yaml_path}. "
168+
f"Skipping duplicate from: {path}"
169+
)
170+
continue
171+
# Combine base tag with per-entry tag
172+
entry_tag = entry.get("tag") if isinstance(entry, dict) else None
173+
combined_tags = TaskIndex._str_to_set(base_tag) | TaskIndex._str_to_set(
174+
entry_tag
175+
)
149176
index[task_name] = Entry(
150177
name=task_name,
151178
kind=Kind.TASK,
152179
yaml_path=path,
153-
tags=TaskIndex._str_to_set(cfg.get("tag")),
180+
tags=combined_tags,
154181
cfg=cfg,
155182
)
156-
TaskIndex._register_tags(task_name, entry.get("tag"), index)
183+
# Register both base config's tag and per-entry tag
184+
TaskIndex._register_tags(task_name, base_tag, index)
185+
TaskIndex._register_tags(task_name, entry_tag, index)
157186
return
158187

159188
@staticmethod

tests/test_task_manager.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,13 @@ def shared_task_manager():
375375
return TaskManager()
376376

377377

378+
@pytest.fixture(scope="module")
379+
def test_configs_task_manager():
380+
"""TaskManager with only test_configs tasks (fast - no default task scanning)"""
381+
test_configs_path = Path(__file__).parent / "test_configs"
382+
return TaskManager(include_path=str(test_configs_path), include_defaults=False)
383+
384+
378385
class TestTaskManagerIntegration:
379386
def test_initialization(self, shared_task_manager):
380387
"""TaskManager initializes with default tasks"""
@@ -407,16 +414,20 @@ def test_all_tags_property(self, shared_task_manager):
407414
for t in tags[:5]: # Check first 5
408415
assert shared_task_manager._name_is_tag(t)
409416

410-
def test_load_task_by_name(self, shared_task_manager):
417+
def test_load_task_by_name(self, test_configs_task_manager):
411418
"""Load a single task by name"""
412-
result = shared_task_manager.load_task_or_group(["arc_easy"])
413-
assert "arc_easy" in result
414-
415-
def test_load_group_by_name(self, shared_task_manager):
416-
"""Load a group and get nested structure"""
417-
result = shared_task_manager.load_task_or_group(["ai2_arc"])
418-
# ai2_arc is a tag that contains arc_easy and arc_challenge
419-
assert "arc_easy" in result or "arc_challenge" in result
419+
result = test_configs_task_manager.load_task_or_group(["simple_task"])
420+
assert "simple_task" in result
421+
422+
def test_load_group_by_name(self, test_configs_task_manager):
423+
"""Load a group and get nested structure with namespaced task names"""
424+
result = test_configs_task_manager.load_task_or_group(["test_group"])
425+
# Result is {ConfigurableGroup: {task_name: task_obj}}
426+
# Get the children dict from the group
427+
children = list(result.values())[0]
428+
# test_group contains inline tasks, namespaced as group_name::task_name
429+
assert "test_group::group_task_fs0" in children
430+
assert "test_group::group_task_fs2" in children
420431

421432
def test_load_tag_by_name(self, shared_task_manager):
422433
"""Load all tasks in a tag"""
@@ -522,6 +533,33 @@ def test_task_list_overrides(self):
522533
assert "acc" in metric_names
523534
assert "acc_norm" in metric_names
524535

536+
def test_task_list_base_field_inheritance(self):
537+
"""Test that task_list tasks inherit base fields from the shared config"""
538+
test_configs_path = Path(__file__).parent / "test_configs"
539+
tm = TaskManager(include_path=str(test_configs_path), include_defaults=False)
540+
541+
result = tm.load_task_or_group(["task_list_fs0"])
542+
task = result["task_list_fs0"]
543+
544+
# Base fields should be inherited from the shared config
545+
assert task.config.dataset_path == "json", (
546+
"Should inherit dataset_path from base"
547+
)
548+
assert task.config.output_type == "multiple_choice", (
549+
"Should inherit output_type from base"
550+
)
551+
assert task.config.doc_to_text == "{{question}}", (
552+
"Should inherit doc_to_text from base"
553+
)
554+
assert task.config.test_split == "test", "Should inherit test_split from base"
555+
556+
# Default metric_list should be inherited (task_list_fs0 doesn't override it)
557+
metric_names = [m["metric"] for m in task.config.metric_list]
558+
assert "acc" in metric_names, "Should inherit metric_list from base"
559+
560+
# Per-task override should still be applied
561+
assert task.config.num_fewshot == 0, "Should have per-task num_fewshot override"
562+
525563
def test_match_tasks_glob(self, shared_task_manager):
526564
"""match_tasks handles glob patterns"""
527565
matches = shared_task_manager.match_tasks(["arc_*"])
@@ -543,7 +581,7 @@ def test_name_is_tag(self, shared_task_manager):
543581
assert shared_task_manager._name_is_tag("ai2_arc")
544582
assert not shared_task_manager._name_is_tag("arc_easy") # This is a task
545583

546-
def test_include_path_precedence(self):
584+
def test_include_path_precedence(self, shared_task_manager):
547585
"""Test that user-specified include paths take precedence over default paths when tasks have the same name."""
548586
with tempfile.TemporaryDirectory() as custom_dir:
549587
# Create a custom arc_easy.yaml that has a different metric
@@ -589,8 +627,8 @@ def test_include_path_precedence(self):
589627
)
590628

591629
# Test 2: Verify default is used when no custom path is provided
592-
default_task_manager = TaskManager(include_defaults=True)
593-
default_task_dict = default_task_manager.load_task_or_group(["arc_easy"])
630+
# Use shared_task_manager instead of creating a new one (saves ~9s)
631+
default_task_dict = shared_task_manager.load_task_or_group(["arc_easy"])
594632
default_arc_easy = default_task_dict["arc_easy"]
595633

596634
# Default should not have f1 metric or custom text

0 commit comments

Comments
 (0)