
Commit fbd4e8e

refactor tests
1 parent dca8e91 commit fbd4e8e

File tree: 3 files changed (+145, -144 lines)


src/browsergym/workarena/__init__.py

Lines changed: 12 additions & 4 deletions
@@ -111,18 +111,19 @@ def get_task_category(task_name):
     return benchmark, TASK_CATEGORY_MAP.get(task_name, None)


-def get_all_tasks_agents(filter="l2", meta_seed=42, n_seed_l1=10, is_agent_curriculum=True):
+def get_all_tasks_agents(
+    filter="l2", meta_seed=42, n_seed_l1=10, is_agent_curriculum=True, task_bucket=None
+):
     OFFSET = 42
     all_task_tuples = []
     filter = filter.split(".")
+    rng = np.random.RandomState(meta_seed)
     if len(filter) > 2:
         raise Exception("Unsupported filter used.")
     if len(filter) == 1:
         level = filter[0]
         if level not in ["l1", "l2", "l3"]:
             raise Exception("Unsupported category of tasks.")
-        else:
-            rng = np.random.RandomState(meta_seed)
         if level == "l1":
             for task in ATOMIC_TASKS:
                 for seed in rng.randint(0, 1000, n_seed_l1):
@@ -151,9 +152,16 @@ def get_all_tasks_agents(filter="l2", meta_seed=42, n_seed_l1=10, is_agent_curri
     for category, items in ALL_COMPOSITIONAL_TASKS_CATEGORIES.items():
         if filter_category and category != filter_category:
             continue
+        # If a task_bucket is specified, check if it exists in the current category
+        if task_bucket and task_bucket not in items["buckets"]:
+            continue
         for curr_seed in rng.randint(0, 1000, items["num_seeds"]):
             random_gen = np.random.RandomState(curr_seed)
-            for task_set, count in zip(items["buckets"], items["weights"]):
+            for i, task_set in enumerate(items["buckets"]):
+                # if a task_bucket is specified, only select tasks from that bucket
+                if task_bucket and task_set != task_bucket:
+                    continue
+                count = items["weights"][i]
                 tasks = random_gen.choice(task_set, count, replace=False)
                 for task in tasks:
                     all_task_tuples.append((task, int(curr_seed)))

tests/test_compositional.py

Lines changed: 19 additions & 140 deletions
@@ -2,7 +2,6 @@
22
Tests that are not specific to any particular kind of task.
33
44
"""
5-
65
import logging
76
import os
87

@@ -14,70 +13,25 @@
1413
from playwright.sync_api import Page, TimeoutError
1514
from tenacity import retry, stop_after_attempt, retry_if_exception_type
1615

17-
from browsergym.workarena import ALL_COMPOSITIONAL_TASKS, get_all_tasks_agents
18-
from browsergym.workarena.tasks.compositional.utils.curriculum import AGENT_CURRICULUM
19-
20-
21-
AGENT_L2_SAMPLED_SET = get_all_tasks_agents(filter="l2")
22-
23-
AGENT_L2_SAMPLED_TASKS, AGENT_L2_SEEDS = [sampled_set[0] for sampled_set in AGENT_L2_SAMPLED_SET], [
24-
sampled_set[1] for sampled_set in AGENT_L2_SAMPLED_SET
25-
]
26-
27-
AGENT_L3_SAMPLED_SET = get_all_tasks_agents(filter="l3")
28-
29-
AGENT_L3_SAMPLED_TASKS, AGENT_L3_SEEDS = [sampled_set[0] for sampled_set in AGENT_L3_SAMPLED_SET], [
30-
sampled_set[1] for sampled_set in AGENT_L3_SAMPLED_SET
31-
]
16+
from browsergym.workarena import get_all_tasks_agents
17+
from browsergym.workarena.tasks.compositional.base import CompositionalTask
3218

19+
# Combine all tasks into a single list for parameterization
20+
AGENT_L2_SAMPLED_SET = get_all_tasks_agents(filter="l2", is_agent_curriculum=True)
21+
AGENT_L3_SAMPLED_SET = get_all_tasks_agents(filter="l3", is_agent_curriculum=True)
3322
HUMAN_L2_SAMPLED_SET = get_all_tasks_agents(filter="l2", is_agent_curriculum=False)
34-
35-
HUMAN_L2_SAMPLED_TASKS, HUMAN_L2_SEEDS = [sampled_set[0] for sampled_set in HUMAN_L2_SAMPLED_SET], [
36-
sampled_set[1] for sampled_set in HUMAN_L2_SAMPLED_SET
37-
]
38-
3923
HUMAN_L3_SAMPLED_SET = get_all_tasks_agents(filter="l3", is_agent_curriculum=False)
4024

41-
HUMAN_L3_SAMPLED_TASKS, HUMAN_L3_SEEDS = [sampled_set[0] for sampled_set in HUMAN_L3_SAMPLED_SET], [
42-
sampled_set[1] for sampled_set in HUMAN_L3_SAMPLED_SET
43-
]
25+
all_tasks_to_test = (
26+
AGENT_L2_SAMPLED_SET + AGENT_L3_SAMPLED_SET + HUMAN_L2_SAMPLED_SET + HUMAN_L3_SAMPLED_SET
27+
)
4428

4529
test_category = os.environ.get("TEST_CATEGORY")
46-
4730
if test_category:
48-
tasks_to_test = []
49-
items = AGENT_CURRICULUM.get(test_category)
50-
if items:
51-
for bucket in items["buckets"]:
52-
tasks_to_test.extend(bucket)
31+
# If a category is specified, filter the tasks to test
32+
tasks_to_test = get_all_tasks_agents(filter=f"l3.{test_category}", is_agent_curriculum=True)
5333
else:
54-
tasks_to_test = ALL_COMPOSITIONAL_TASKS
55-
56-
57-
@retry(
58-
stop=stop_after_attempt(5),
59-
reraise=True,
60-
before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
61-
)
62-
@pytest.mark.parametrize("task_entrypoint", tasks_to_test)
63-
@pytest.mark.parametrize("random_seed", range(1))
64-
@pytest.mark.parametrize("level", range(2, 4))
65-
@pytest.mark.pricy
66-
def test_cheat_compositional(task_entrypoint, random_seed, level, page: Page):
67-
task = task_entrypoint(seed=random_seed, level=level)
68-
goal, info = task.setup(page=page)
69-
chat_messages = []
70-
for i in range(len(task)):
71-
page.wait_for_timeout(1000)
72-
task.cheat(page=page, chat_messages=chat_messages, subtask_idx=i)
73-
page.wait_for_timeout(1000)
74-
reward, done, message, info = task.validate(page=page, chat_messages=chat_messages)
75-
if i < len(task) - 1:
76-
assert done is False and reward == 0.0
77-
78-
task.teardown()
79-
80-
assert done is True and reward == 1.0
34+
tasks_to_test = all_tasks_to_test
8135

8236

8337
@retry(
@@ -86,89 +40,14 @@ def test_cheat_compositional(task_entrypoint, random_seed, level, page: Page):
8640
reraise=True,
8741
before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
8842
)
89-
@pytest.mark.parametrize("task_entrypoint, seed", zip(AGENT_L2_SAMPLED_TASKS, AGENT_L2_SEEDS))
90-
@pytest.mark.slow
91-
@pytest.mark.skip(reason="Tests are too slow")
92-
def test_cheat_compositional_sampled_agent_set_l2(task_entrypoint, seed, page: Page):
93-
task = task_entrypoint(seed=seed)
94-
goal, info = task.setup(page=page)
95-
chat_messages = []
96-
for i in range(len(task)):
97-
page.wait_for_timeout(1000)
98-
task.cheat(page=page, chat_messages=chat_messages, subtask_idx=i)
99-
page.wait_for_timeout(1000)
100-
reward, done, message, info = task.validate(page=page, chat_messages=chat_messages)
101-
if i < len(task) - 1:
102-
assert done is False and reward == 0.0
103-
104-
task.teardown()
105-
106-
assert done is True and reward == 1.0
107-
108-
109-
@retry(
110-
stop=stop_after_attempt(5),
111-
retry=retry_if_exception_type(TimeoutError),
112-
reraise=True,
113-
before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
114-
)
115-
@pytest.mark.parametrize("task_entrypoint, seed", zip(AGENT_L3_SAMPLED_TASKS, AGENT_L3_SEEDS))
116-
@pytest.mark.slow
117-
@pytest.mark.skip(reason="Tests are too slow")
118-
def test_cheat_compositional_sampled_agent_set_l3(task_entrypoint, seed, page: Page):
119-
task = task_entrypoint(seed=seed)
120-
goal, info = task.setup(page=page)
121-
chat_messages = []
122-
for i in range(len(task)):
123-
page.wait_for_timeout(1000)
124-
task.cheat(page=page, chat_messages=chat_messages, subtask_idx=i)
125-
page.wait_for_timeout(1000)
126-
reward, done, message, info = task.validate(page=page, chat_messages=chat_messages)
127-
if i < len(task) - 1:
128-
assert done is False and reward == 0.0
129-
130-
task.teardown()
131-
132-
assert done is True and reward == 1.0
133-
134-
135-
@retry(
136-
stop=stop_after_attempt(5),
137-
retry=retry_if_exception_type(TimeoutError),
138-
reraise=True,
139-
before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
140-
)
141-
@pytest.mark.parametrize("task_entrypoint, seed", zip(HUMAN_L2_SAMPLED_TASKS, HUMAN_L2_SEEDS))
142-
@pytest.mark.slow
143-
@pytest.mark.skip(reason="Tests are too slow")
144-
def test_cheat_compositional_sampled_human_set_l2(task_entrypoint, seed, page: Page):
145-
task = task_entrypoint(seed=seed)
146-
goal, info = task.setup(page=page)
147-
chat_messages = []
148-
for i in range(len(task)):
149-
page.wait_for_timeout(1000)
150-
task.cheat(page=page, chat_messages=chat_messages, subtask_idx=i)
151-
page.wait_for_timeout(1000)
152-
reward, done, message, info = task.validate(page=page, chat_messages=chat_messages)
153-
if i < len(task) - 1:
154-
assert done is False and reward == 0.0
155-
156-
task.teardown()
157-
158-
assert done is True and reward == 1.0
159-
160-
161-
@retry(
162-
stop=stop_after_attempt(5),
163-
retry=retry_if_exception_type(TimeoutError),
164-
reraise=True,
165-
before_sleep=lambda _: logging.info("Retrying due to a TimeoutError..."),
166-
)
167-
@pytest.mark.parametrize("task_entrypoint, seed", zip(HUMAN_L3_SAMPLED_TASKS, HUMAN_L3_SEEDS))
168-
@pytest.mark.slow
169-
@pytest.mark.skip(reason="Tests are too slow")
170-
def test_cheat_compositional_sampled_human_set_l3(task_entrypoint, seed, page: Page):
171-
task = task_entrypoint(seed=seed)
43+
@pytest.mark.parametrize("task_class, seed", tasks_to_test)
44+
@pytest.mark.pricy
45+
def test_cheat_compositional(task_class, seed, page: Page):
46+
"""
47+
Test that the cheat method works for all compositional tasks.
48+
This test is parameterized to run for all tasks in the agent and human curricula.
49+
"""
50+
task = task_class(seed=seed)
17251
goal, info = task.setup(page=page)
17352
chat_messages = []
17453
for i in range(len(task)):
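
For reference, a short sketch of how the refactored parameter list can be inspected outside pytest. The TEST_CATEGORY environment variable and the filter strings mirror the test module above; the count printed at the end is illustrative only.

import os
from browsergym.workarena import get_all_tasks_agents

# Mirror the test module's selection logic to see what would be parameterized.
test_category = os.environ.get("TEST_CATEGORY")
if test_category:
    tasks_to_test = get_all_tasks_agents(filter=f"l3.{test_category}", is_agent_curriculum=True)
else:
    tasks_to_test = (
        get_all_tasks_agents(filter="l2", is_agent_curriculum=True)
        + get_all_tasks_agents(filter="l3", is_agent_curriculum=True)
        + get_all_tasks_agents(filter="l2", is_agent_curriculum=False)
        + get_all_tasks_agents(filter="l3", is_agent_curriculum=False)
    )

# Each entry is a (task_class, seed) pair; pytest generates one test case per pair.
print(f"{len(tasks_to_test)} (task_class, seed) pairs collected")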

tests/test_workarena_utils.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+"""
+Tests for workarena utility functions.
+"""
+import pytest
+from browsergym.workarena import get_all_tasks_agents
+from browsergym.workarena.tasks.compositional import (
+    AGENT_CURRICULUM_L2,
+    AGENT_CURRICULUM_L3,
+    HUMAN_CURRICULUM_L2,
+    HUMAN_CURRICULUM_L3,
+    specialize_task_class_to_level,
+)
+from browsergym.workarena.tasks.compositional.base import CompositionalTask
+from browsergym.workarena.tasks.compositional.mark_duplicate_problems import (
+    BasicFilterProblemsAndMarkDuplicatesSmallTask,
+    PriorityFilterProblemsAndMarkDuplicatesSmallTask,
+)
+from browsergym.workarena.tasks.compositional.navigate_and_do_infeasible import (
+    InfeasibleNavigateAndCreateUserWithReasonTask,
+)
+
+
+def get_tasks_from_curriculum(curriculum):
+    """Helper function to extract all unique tasks from a curriculum."""
+    all_tasks = set()
+    for category, items in curriculum.items():
+        for bucket in items["buckets"]:
+            for task in bucket:
+                all_tasks.add(task)
+    return all_tasks
+
+
+def test_get_all_tasks_agents():
+    """Test that get_all_tasks_agents returns the correct tasks from the curricula."""
+    # Test L1 filter (atomic tasks)
+    tasks_with_seeds_l1 = get_all_tasks_agents(filter="l1")
+    assert len(tasks_with_seeds_l1) > 0
+    for task, seed in tasks_with_seeds_l1:
+        assert not issubclass(task, CompositionalTask)
+        assert isinstance(seed, int)
+
+    # Test L2 Human Curriculum
+    tasks_with_seeds_l2_human = get_all_tasks_agents(filter="l2", is_agent_curriculum=False)
+    expected_l2_human_tasks = get_tasks_from_curriculum(HUMAN_CURRICULUM_L2)
+    assert len(tasks_with_seeds_l2_human) > 0
+    for task, seed in tasks_with_seeds_l2_human:
+        assert task in expected_l2_human_tasks
+
+    # Test L3 Human Curriculum
+    tasks_with_seeds_l3_human = get_all_tasks_agents(filter="l3", is_agent_curriculum=False)
+    expected_l3_human_tasks = get_tasks_from_curriculum(HUMAN_CURRICULUM_L3)
+    assert len(tasks_with_seeds_l3_human) > 0
+    for task, seed in tasks_with_seeds_l3_human:
+        assert task in expected_l3_human_tasks
+
+    # Test category filtering
+    category = "planning_and_problem_solving"
+    tasks_with_seeds_cat = get_all_tasks_agents(
+        filter=f"l3.{category}", is_agent_curriculum=True
+    )
+    assert len(tasks_with_seeds_cat) > 0
+    # Expected tasks from the specified category's buckets
+    expected_cat_tasks = set()
+    for bucket in AGENT_CURRICULUM_L3[category]["buckets"]:
+        expected_cat_tasks.update(bucket)
+
+    returned_tasks = {task for task, seed in tasks_with_seeds_cat}
+    assert returned_tasks.issubset(expected_cat_tasks)
+
+    # Check that tasks from other categories are not present
+    for other_category, items in AGENT_CURRICULUM_L3.items():
+        if other_category != category:
+            for bucket in items["buckets"]:
+                for task in bucket:
+                    assert task not in returned_tasks
+
+    # Test task_bucket filtering
+    category = "planning_and_problem_solving"
+    # This bucket contains BasicFilterProblemsAndMarkDuplicatesSmallTask
+    bucket_to_test = AGENT_CURRICULUM_L3[category]["buckets"][0]
+
+    tasks_with_seeds_bucket = get_all_tasks_agents(
+        filter=f"l3.{category}", is_agent_curriculum=True, task_bucket=bucket_to_test
+    )
+    assert len(tasks_with_seeds_bucket) > 0
+
+    returned_tasks_from_bucket = {task for task, seed in tasks_with_seeds_bucket}
+
+    # 1. All returned tasks are from the specified bucket
+    assert returned_tasks_from_bucket.issubset(set(bucket_to_test))
+
+    # 2. A specific task from the bucket is present
+    expected_task_base = BasicFilterProblemsAndMarkDuplicatesSmallTask
+    # Find the specialized task in the bucket that corresponds to the base task
+    expected_task_specialized = next(
+        task
+        for task in bucket_to_test
+        if expected_task_base in task.__mro__
+    )
+    assert expected_task_specialized in returned_tasks_from_bucket
+
+    # A task from a different category is not present
+    unexpected_task = specialize_task_class_to_level(
+        InfeasibleNavigateAndCreateUserWithReasonTask, level=3
+    )
+    assert unexpected_task not in returned_tasks_from_bucket
+
+    # Test invalid filter
+    with pytest.raises(Exception):
+        get_all_tasks_agents(filter="invalid")
+
+    # Test invalid category filter
+    with pytest.raises(Exception):
+        get_all_tasks_agents(filter="l3.invalid_category")
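
For context, the curriculum dictionaries walked by get_tasks_from_curriculum and by get_all_tasks_agents appear to map each category to task buckets plus sampling parameters. The shape below is inferred from how the diffs index items["buckets"], items["weights"], and items["num_seeds"]; the task names and numeric values are placeholders, not real WorkArena classes or settings.

# Illustrative only: placeholder strings stand in for WorkArena task classes.
EXAMPLE_CURRICULUM = {
    "planning_and_problem_solving": {
        "buckets": [
            ["TaskClassA", "TaskClassB"],  # one bucket of related task classes
            ["TaskClassC"],
        ],
        "weights": [2, 1],  # tasks drawn per bucket for each sampled seed
        "num_seeds": 5,     # seeds drawn for this category
    },
}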
