Skip to content

Commit f8ad869

Browse files
authored
Add --task and --task-category filtering for eval command (#72)
1 parent b3a7b49 commit f8ad869

File tree

3 files changed

+307
-1
lines changed

3 files changed

+307
-1
lines changed

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ readme = "README.md"
1010
requires-python = ">=3.10"
1111
dependencies = [
1212
"click",
13-
"inspect-ai>=0.3.104",
13+
# Pin inspect-ai to avoid breaking changes in 0.3.137+ (Event classes moved)
14+
# See allenai/astabench-issues#275 for upgrade process
15+
"inspect-ai>=0.3.104,<0.3.137",
1416
# pin litellm so that we know what model costs we're using
1517
# see the Development.md doc before changing
1618
"litellm>=1.67.4.post1,<=1.75.8",

src/agenteval/cli.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,18 @@ def view_command(
972972
help="Display format. Defaults to plain.",
973973
default="plain",
974974
)
975+
@click.option(
976+
"--task",
977+
"task_filters",
978+
multiple=True,
979+
help="Filter to only run tasks whose name contains this string (can be specified multiple times).",
980+
)
981+
@click.option(
982+
"--task-category",
983+
"task_category_filters",
984+
multiple=True,
985+
help="Filter to only run tasks with this tag (can be specified multiple times).",
986+
)
975987
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
976988
def eval_command(
977989
log_dir: str | None,
@@ -980,12 +992,43 @@ def eval_command(
980992
ignore_git: bool,
981993
config_only: bool,
982994
display: str,
995+
task_filters: tuple[str, ...],
996+
task_category_filters: tuple[str, ...],
983997
args: tuple[str],
984998
):
985999
"""Run inspect eval-set with arguments and append tasks"""
9861000
suite_config = load_suite_config(config_path)
9871001
tasks = suite_config.get_tasks(split)
9881002

1003+
# Apply task filtering
1004+
if task_filters or task_category_filters:
1005+
original_count = len(tasks)
1006+
filtered_tasks = []
1007+
for task in tasks:
1008+
# Check task name filter (substring match)
1009+
if task_filters:
1010+
name_match = any(f in task.name for f in task_filters)
1011+
if not name_match:
1012+
continue
1013+
1014+
# Check task category filter (exact tag match)
1015+
if task_category_filters:
1016+
task_tags = task.get_tag_names()
1017+
category_match = any(cat in task_tags for cat in task_category_filters)
1018+
if not category_match:
1019+
continue
1020+
1021+
filtered_tasks.append(task)
1022+
1023+
tasks = filtered_tasks
1024+
click.echo(f"Filtered to {len(tasks)} of {original_count} tasks")
1025+
1026+
if not tasks:
1027+
raise click.ClickException(
1028+
"No tasks match the specified filters. "
1029+
f"Task filters: {task_filters}, Category filters: {task_category_filters}"
1030+
)
1031+
9891032
# Verify git status for reproducibility
9901033
if not ignore_git:
9911034
verify_git_reproducibility()

tests/test_cli.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import os
2+
import tempfile
3+
14
from click.testing import CliRunner
25

36
from agenteval.cli import cli
@@ -8,3 +11,261 @@ def test_help_displays_usage():
811
result = runner.invoke(cli, ["--help"])
912
assert result.exit_code == 0
1013
assert "Usage:" in result.output
14+
15+
16+
class TestEvalTaskFiltering:
17+
"""Tests for --task and --task-category filtering in eval command."""
18+
19+
def _create_test_config(self, tmpdir):
20+
"""Create a test config file with multiple tasks and tags."""
21+
config_content = """
22+
name: test-suite
23+
version: "1.0.0"
24+
splits:
25+
- name: test
26+
tasks:
27+
- name: ArxivDIGESTables_Clean_train
28+
path: tasks/arxiv_clean
29+
primary_metric: accuracy
30+
tags:
31+
- lit
32+
- data
33+
- name: CodeGenTask_v1
34+
path: tasks/codegen
35+
primary_metric: pass_rate
36+
tags:
37+
- code
38+
- name: DiscoveryBenchmark_2024
39+
path: tasks/discovery
40+
primary_metric: f1_score
41+
tags:
42+
- discovery
43+
- lit
44+
"""
45+
config_path = os.path.join(tmpdir, "test_config.yml")
46+
with open(config_path, "w") as f:
47+
f.write(config_content)
48+
return config_path
49+
50+
def test_eval_shows_task_filter_options_in_help(self):
51+
"""Test that --task and --task-category options appear in eval help."""
52+
runner = CliRunner()
53+
result = runner.invoke(cli, ["eval", "--help"])
54+
assert result.exit_code == 0
55+
assert "--task TEXT" in result.output
56+
assert "--task-category TEXT" in result.output
57+
assert "Filter to only run tasks whose name contains" in result.output
58+
assert "Filter to only run tasks with this tag" in result.output
59+
60+
def test_eval_filter_by_task_name(self):
61+
"""Test filtering by task name substring."""
62+
runner = CliRunner()
63+
with tempfile.TemporaryDirectory() as tmpdir:
64+
config_path = self._create_test_config(tmpdir)
65+
log_dir = os.path.join(tmpdir, "logs")
66+
67+
result = runner.invoke(
68+
cli,
69+
[
70+
"eval",
71+
"--config-path",
72+
config_path,
73+
"--split",
74+
"test",
75+
"--ignore-git",
76+
"--config-only",
77+
"--log-dir",
78+
log_dir,
79+
"--task",
80+
"CodeGen",
81+
],
82+
)
83+
84+
assert result.exit_code == 0
85+
assert "Filtered to 1 of 3 tasks" in result.output
86+
assert "tasks/codegen" in result.output
87+
assert "tasks/arxiv_clean" not in result.output
88+
assert "tasks/discovery" not in result.output
89+
90+
def test_eval_filter_by_task_category(self):
91+
"""Test filtering by task category/tag."""
92+
runner = CliRunner()
93+
with tempfile.TemporaryDirectory() as tmpdir:
94+
config_path = self._create_test_config(tmpdir)
95+
log_dir = os.path.join(tmpdir, "logs")
96+
97+
result = runner.invoke(
98+
cli,
99+
[
100+
"eval",
101+
"--config-path",
102+
config_path,
103+
"--split",
104+
"test",
105+
"--ignore-git",
106+
"--config-only",
107+
"--log-dir",
108+
log_dir,
109+
"--task-category",
110+
"lit",
111+
],
112+
)
113+
114+
assert result.exit_code == 0
115+
assert "Filtered to 2 of 3 tasks" in result.output
116+
assert "tasks/arxiv_clean" in result.output
117+
assert "tasks/discovery" in result.output
118+
assert "tasks/codegen" not in result.output
119+
120+
def test_eval_filter_by_task_and_category_combined(self):
121+
"""Test filtering by both task name and category (AND logic)."""
122+
runner = CliRunner()
123+
with tempfile.TemporaryDirectory() as tmpdir:
124+
config_path = self._create_test_config(tmpdir)
125+
log_dir = os.path.join(tmpdir, "logs")
126+
127+
result = runner.invoke(
128+
cli,
129+
[
130+
"eval",
131+
"--config-path",
132+
config_path,
133+
"--split",
134+
"test",
135+
"--ignore-git",
136+
"--config-only",
137+
"--log-dir",
138+
log_dir,
139+
"--task",
140+
"Arxiv",
141+
"--task-category",
142+
"lit",
143+
],
144+
)
145+
146+
assert result.exit_code == 0
147+
assert "Filtered to 1 of 3 tasks" in result.output
148+
assert "tasks/arxiv_clean" in result.output
149+
# Discovery has "lit" tag but doesn't match "Arxiv"
150+
assert "tasks/discovery" not in result.output
151+
assert "tasks/codegen" not in result.output
152+
153+
def test_eval_filter_multiple_task_names(self):
154+
"""Test filtering with multiple --task options (OR logic within names)."""
155+
runner = CliRunner()
156+
with tempfile.TemporaryDirectory() as tmpdir:
157+
config_path = self._create_test_config(tmpdir)
158+
log_dir = os.path.join(tmpdir, "logs")
159+
160+
result = runner.invoke(
161+
cli,
162+
[
163+
"eval",
164+
"--config-path",
165+
config_path,
166+
"--split",
167+
"test",
168+
"--ignore-git",
169+
"--config-only",
170+
"--log-dir",
171+
log_dir,
172+
"--task",
173+
"CodeGen",
174+
"--task",
175+
"Discovery",
176+
],
177+
)
178+
179+
assert result.exit_code == 0
180+
assert "Filtered to 2 of 3 tasks" in result.output
181+
assert "tasks/codegen" in result.output
182+
assert "tasks/discovery" in result.output
183+
assert "tasks/arxiv_clean" not in result.output
184+
185+
def test_eval_filter_multiple_categories(self):
186+
"""Test filtering with multiple --task-category options (OR logic)."""
187+
runner = CliRunner()
188+
with tempfile.TemporaryDirectory() as tmpdir:
189+
config_path = self._create_test_config(tmpdir)
190+
log_dir = os.path.join(tmpdir, "logs")
191+
192+
result = runner.invoke(
193+
cli,
194+
[
195+
"eval",
196+
"--config-path",
197+
config_path,
198+
"--split",
199+
"test",
200+
"--ignore-git",
201+
"--config-only",
202+
"--log-dir",
203+
log_dir,
204+
"--task-category",
205+
"code",
206+
"--task-category",
207+
"discovery",
208+
],
209+
)
210+
211+
assert result.exit_code == 0
212+
assert "Filtered to 2 of 3 tasks" in result.output
213+
assert "tasks/codegen" in result.output
214+
assert "tasks/discovery" in result.output
215+
assert "tasks/arxiv_clean" not in result.output
216+
217+
def test_eval_no_filter_runs_all_tasks(self):
218+
"""Test that no filter runs all tasks."""
219+
runner = CliRunner()
220+
with tempfile.TemporaryDirectory() as tmpdir:
221+
config_path = self._create_test_config(tmpdir)
222+
log_dir = os.path.join(tmpdir, "logs")
223+
224+
result = runner.invoke(
225+
cli,
226+
[
227+
"eval",
228+
"--config-path",
229+
config_path,
230+
"--split",
231+
"test",
232+
"--ignore-git",
233+
"--config-only",
234+
"--log-dir",
235+
log_dir,
236+
],
237+
)
238+
239+
assert result.exit_code == 0
240+
# Should not show "Filtered to" message
241+
assert "Filtered to" not in result.output
242+
assert "tasks/arxiv_clean" in result.output
243+
assert "tasks/codegen" in result.output
244+
assert "tasks/discovery" in result.output
245+
246+
def test_eval_filter_no_matches_fails(self):
247+
"""Test that filtering with no matches raises an error."""
248+
runner = CliRunner()
249+
with tempfile.TemporaryDirectory() as tmpdir:
250+
config_path = self._create_test_config(tmpdir)
251+
log_dir = os.path.join(tmpdir, "logs")
252+
253+
result = runner.invoke(
254+
cli,
255+
[
256+
"eval",
257+
"--config-path",
258+
config_path,
259+
"--split",
260+
"test",
261+
"--ignore-git",
262+
"--config-only",
263+
"--log-dir",
264+
log_dir,
265+
"--task",
266+
"NonExistentTask",
267+
],
268+
)
269+
270+
assert result.exit_code != 0
271+
assert "No tasks match the specified filters" in result.output

0 commit comments

Comments
 (0)