22Tests that are not specific to any particular kind of task.
33
44"""
5-
65import logging
76import os
87
1413from playwright .sync_api import Page , TimeoutError
1514from tenacity import retry , stop_after_attempt , retry_if_exception_type
1615
17- from browsergym .workarena import ALL_COMPOSITIONAL_TASKS , get_all_tasks_agents
18- from browsergym .workarena .tasks .compositional .utils .curriculum import AGENT_CURRICULUM
19-
20-
21- AGENT_L2_SAMPLED_SET = get_all_tasks_agents (filter = "l2" )
22-
23- AGENT_L2_SAMPLED_TASKS , AGENT_L2_SEEDS = [sampled_set [0 ] for sampled_set in AGENT_L2_SAMPLED_SET ], [
24- sampled_set [1 ] for sampled_set in AGENT_L2_SAMPLED_SET
25- ]
26-
27- AGENT_L3_SAMPLED_SET = get_all_tasks_agents (filter = "l3" )
28-
29- AGENT_L3_SAMPLED_TASKS , AGENT_L3_SEEDS = [sampled_set [0 ] for sampled_set in AGENT_L3_SAMPLED_SET ], [
30- sampled_set [1 ] for sampled_set in AGENT_L3_SAMPLED_SET
31- ]
16+ from browsergym .workarena import get_all_tasks_agents
17+ from browsergym .workarena .tasks .compositional .base import CompositionalTask
3218
19+ # Combine all tasks into a single list for parameterization
20+ AGENT_L2_SAMPLED_SET = get_all_tasks_agents (filter = "l2" , is_agent_curriculum = True )
21+ AGENT_L3_SAMPLED_SET = get_all_tasks_agents (filter = "l3" , is_agent_curriculum = True )
3322HUMAN_L2_SAMPLED_SET = get_all_tasks_agents (filter = "l2" , is_agent_curriculum = False )
34-
35- HUMAN_L2_SAMPLED_TASKS , HUMAN_L2_SEEDS = [sampled_set [0 ] for sampled_set in HUMAN_L2_SAMPLED_SET ], [
36- sampled_set [1 ] for sampled_set in HUMAN_L2_SAMPLED_SET
37- ]
38-
3923HUMAN_L3_SAMPLED_SET = get_all_tasks_agents (filter = "l3" , is_agent_curriculum = False )
4024
41- HUMAN_L3_SAMPLED_TASKS , HUMAN_L3_SEEDS = [ sampled_set [ 0 ] for sampled_set in HUMAN_L3_SAMPLED_SET ], [
42- sampled_set [ 1 ] for sampled_set in HUMAN_L3_SAMPLED_SET
43- ]
25+ all_tasks_to_test = (
26+ AGENT_L2_SAMPLED_SET + AGENT_L3_SAMPLED_SET + HUMAN_L2_SAMPLED_SET + HUMAN_L3_SAMPLED_SET
27+ )
4428
4529test_category = os .environ .get ("TEST_CATEGORY" )
46-
4730if test_category :
48- tasks_to_test = []
49- items = AGENT_CURRICULUM .get (test_category )
50- if items :
51- for bucket in items ["buckets" ]:
52- tasks_to_test .extend (bucket )
31+ # If a category is specified, filter the tasks to test
32+ tasks_to_test = get_all_tasks_agents (filter = f"l3.{ test_category } " , is_agent_curriculum = True )
5333else :
54- tasks_to_test = ALL_COMPOSITIONAL_TASKS
55-
56-
57- @retry (
58- stop = stop_after_attempt (5 ),
59- reraise = True ,
60- before_sleep = lambda _ : logging .info ("Retrying due to a TimeoutError..." ),
61- )
62- @pytest .mark .parametrize ("task_entrypoint" , tasks_to_test )
63- @pytest .mark .parametrize ("random_seed" , range (1 ))
64- @pytest .mark .parametrize ("level" , range (2 , 4 ))
65- @pytest .mark .pricy
66- def test_cheat_compositional (task_entrypoint , random_seed , level , page : Page ):
67- task = task_entrypoint (seed = random_seed , level = level )
68- goal , info = task .setup (page = page )
69- chat_messages = []
70- for i in range (len (task )):
71- page .wait_for_timeout (1000 )
72- task .cheat (page = page , chat_messages = chat_messages , subtask_idx = i )
73- page .wait_for_timeout (1000 )
74- reward , done , message , info = task .validate (page = page , chat_messages = chat_messages )
75- if i < len (task ) - 1 :
76- assert done is False and reward == 0.0
77-
78- task .teardown ()
79-
80- assert done is True and reward == 1.0
34+ tasks_to_test = all_tasks_to_test
8135
8236
8337@retry (
@@ -86,89 +40,14 @@ def test_cheat_compositional(task_entrypoint, random_seed, level, page: Page):
8640 reraise = True ,
8741 before_sleep = lambda _ : logging .info ("Retrying due to a TimeoutError..." ),
8842)
89- @pytest .mark .parametrize ("task_entrypoint, seed" , zip (AGENT_L2_SAMPLED_TASKS , AGENT_L2_SEEDS ))
90- @pytest .mark .slow
91- @pytest .mark .skip (reason = "Tests are too slow" )
92- def test_cheat_compositional_sampled_agent_set_l2 (task_entrypoint , seed , page : Page ):
93- task = task_entrypoint (seed = seed )
94- goal , info = task .setup (page = page )
95- chat_messages = []
96- for i in range (len (task )):
97- page .wait_for_timeout (1000 )
98- task .cheat (page = page , chat_messages = chat_messages , subtask_idx = i )
99- page .wait_for_timeout (1000 )
100- reward , done , message , info = task .validate (page = page , chat_messages = chat_messages )
101- if i < len (task ) - 1 :
102- assert done is False and reward == 0.0
103-
104- task .teardown ()
105-
106- assert done is True and reward == 1.0
107-
108-
109- @retry (
110- stop = stop_after_attempt (5 ),
111- retry = retry_if_exception_type (TimeoutError ),
112- reraise = True ,
113- before_sleep = lambda _ : logging .info ("Retrying due to a TimeoutError..." ),
114- )
115- @pytest .mark .parametrize ("task_entrypoint, seed" , zip (AGENT_L3_SAMPLED_TASKS , AGENT_L3_SEEDS ))
116- @pytest .mark .slow
117- @pytest .mark .skip (reason = "Tests are too slow" )
118- def test_cheat_compositional_sampled_agent_set_l3 (task_entrypoint , seed , page : Page ):
119- task = task_entrypoint (seed = seed )
120- goal , info = task .setup (page = page )
121- chat_messages = []
122- for i in range (len (task )):
123- page .wait_for_timeout (1000 )
124- task .cheat (page = page , chat_messages = chat_messages , subtask_idx = i )
125- page .wait_for_timeout (1000 )
126- reward , done , message , info = task .validate (page = page , chat_messages = chat_messages )
127- if i < len (task ) - 1 :
128- assert done is False and reward == 0.0
129-
130- task .teardown ()
131-
132- assert done is True and reward == 1.0
133-
134-
135- @retry (
136- stop = stop_after_attempt (5 ),
137- retry = retry_if_exception_type (TimeoutError ),
138- reraise = True ,
139- before_sleep = lambda _ : logging .info ("Retrying due to a TimeoutError..." ),
140- )
141- @pytest .mark .parametrize ("task_entrypoint, seed" , zip (HUMAN_L2_SAMPLED_TASKS , HUMAN_L2_SEEDS ))
142- @pytest .mark .slow
143- @pytest .mark .skip (reason = "Tests are too slow" )
144- def test_cheat_compositional_sampled_human_set_l2 (task_entrypoint , seed , page : Page ):
145- task = task_entrypoint (seed = seed )
146- goal , info = task .setup (page = page )
147- chat_messages = []
148- for i in range (len (task )):
149- page .wait_for_timeout (1000 )
150- task .cheat (page = page , chat_messages = chat_messages , subtask_idx = i )
151- page .wait_for_timeout (1000 )
152- reward , done , message , info = task .validate (page = page , chat_messages = chat_messages )
153- if i < len (task ) - 1 :
154- assert done is False and reward == 0.0
155-
156- task .teardown ()
157-
158- assert done is True and reward == 1.0
159-
160-
161- @retry (
162- stop = stop_after_attempt (5 ),
163- retry = retry_if_exception_type (TimeoutError ),
164- reraise = True ,
165- before_sleep = lambda _ : logging .info ("Retrying due to a TimeoutError..." ),
166- )
167- @pytest .mark .parametrize ("task_entrypoint, seed" , zip (HUMAN_L3_SAMPLED_TASKS , HUMAN_L3_SEEDS ))
168- @pytest .mark .slow
169- @pytest .mark .skip (reason = "Tests are too slow" )
170- def test_cheat_compositional_sampled_human_set_l3 (task_entrypoint , seed , page : Page ):
171- task = task_entrypoint (seed = seed )
43+ @pytest .mark .parametrize ("task_class, seed" , tasks_to_test )
44+ @pytest .mark .pricy
45+ def test_cheat_compositional (task_class , seed , page : Page ):
46+ """
47+ Test that the cheat method works for all compositional tasks.
48+ This test is parameterized to run for all tasks in the agent and human curricula.
49+ """
50+ task = task_class (seed = seed )
17251 goal , info = task .setup (page = page )
17352 chat_messages = []
17453 for i in range (len (task )):
0 commit comments