Skip to content

Commit 5297157

Browse files
recursixgasse
authored andcommitted
add support for tab visibility in observation flags and update related components
1 parent cf05bc6 commit 5297157

File tree

4 files changed

+26
-38
lines changed

4 files changed

+26
-38
lines changed

src/agentlab/agents/dynamic_prompting.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
import bgym
1212
from browsergym.core.action.base import AbstractActionSet
13-
from browsergym.core.action.highlevel import HighLevelActionSet
14-
from browsergym.core.action.python import PythonActionSet
1513
from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, overlay_som, prune_html
1614

1715
from agentlab.llm.llm_utils import (
@@ -71,6 +69,7 @@ class ObsFlags(Flags):
7169

7270
use_html: bool = True
7371
use_ax_tree: bool = False
72+
use_tabs: bool = False
7473
use_focused_element: bool = False
7574
use_error_logs: bool = False
7675
use_history: bool = False
@@ -386,11 +385,7 @@ def _prompt(self) -> str:
386385
URL: {page_url}
387386
"""
388387
prompt_pieces.append(prompt_piece)
389-
self._prompt = "\n".join(prompt_pieces)
390-
391-
392-
def has_tab_action(action_set: bgym.HighLevelActionSetArgs):
393-
return "tab" in action_set.subsets
388+
return "\n".join(prompt_pieces)
394389

395390

396391
class Observation(Shrinkable):
@@ -399,14 +394,14 @@ class Observation(Shrinkable):
399394
Contains the html, the accessibility tree and the error logs.
400395
"""
401396

402-
def __init__(self, obs, flags: ObsFlags, use_tabs=False) -> None:
397+
def __init__(self, obs, flags: ObsFlags) -> None:
403398
super().__init__()
404399
self.flags = flags
405400
self.obs = obs
406401

407402
self.tabs = Tabs(
408403
obs,
409-
visible=use_tabs,
404+
visible=lambda: flags.use_tabs,
410405
prefix="## ",
411406
)
412407

src/agentlab/agents/generic_agent/generic_agent.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
3232
if benchmark.name.startswith("miniwob"):
3333
self.flags.obs.use_html = True
3434

35+
self.flags.obs.use_tabs = benchmark.is_multi_tab
3536
self.flags.action.action_set = deepcopy(benchmark.high_level_action_set_args)
3637

3738
# for backward compatibility with old traces
@@ -268,5 +269,3 @@ def get_action_post_hoc(agent: GenericAgent, obs: dict, ans_dict: dict):
268269
output += f"\n<action>\n{action}\n</action>"
269270

270271
return system_prompt, instruction_prompt, output
271-
return system_prompt, instruction_prompt, output
272-
return system_prompt, instruction_prompt, output

src/agentlab/agents/generic_agent/generic_agent_prompt.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def __init__(
7777
self.obs = dp.Observation(
7878
obs_history[-1],
7979
self.flags.obs,
80-
use_tabs=dp.has_tab_action(self.flags.action.action_set),
8180
)
8281

8382
self.action_prompt = dp.ActionPrompt(action_set, action_flags=flags.action)

tests/agents/test_generic_prompt.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,33 +20,28 @@
2020
</html>
2121
"""
2222

23+
base_obs = {
24+
"goal": "do this and that",
25+
"goal_object": [{"type": "text", "text": "do this and that"}],
26+
"chat_messages": [{"role": "user", "message": "do this and that"}],
27+
"axtree_txt": "[1] Click me",
28+
"focused_element_bid": "45-256",
29+
"open_pages_urls": ["https://example.com"],
30+
"open_pages_titles": ["Example"],
31+
"active_page_index": 0,
32+
}
2333

2434
OBS_HISTORY = [
25-
{
26-
"goal": "do this and that",
27-
"goal_object": [{"type": "text", "text": "do this and that"}],
28-
"chat_messages": [{"role": "user", "message": "do this and that"}],
35+
base_obs | {
2936
"pruned_html": html_template.format(1),
30-
"axtree_txt": "[1] Click me",
31-
"focused_element_bid": "45-256",
3237
"last_action_error": "",
3338
},
34-
{
35-
"goal": "do this and that",
36-
"goal_object": [{"type": "text", "text": "do this and that"}],
37-
"chat_messages": [{"role": "user", "message": "do this and that"}],
39+
base_obs | {
3840
"pruned_html": html_template.format(2),
39-
"axtree_txt": "[1] Click me",
40-
"focused_element_bid": "45-256",
4141
"last_action_error": "Hey, this is an error in the past",
4242
},
43-
{
44-
"goal": "do this and that",
45-
"goal_object": [{"type": "text", "text": "do this and that"}],
46-
"chat_messages": [{"role": "user", "message": "do this and that"}],
43+
base_obs | {
4744
"pruned_html": html_template.format(3),
48-
"axtree_txt": "[1] Click me",
49-
"focused_element_bid": "45-256",
5045
"last_action_error": "Hey, there is an error now",
5146
},
5247
]
@@ -58,6 +53,7 @@
5853
obs=dp.ObsFlags(
5954
use_html=True,
6055
use_ax_tree=True,
56+
use_tabs=True,
6157
use_focused_element=True,
6258
use_error_logs=True,
6359
use_history=True,
@@ -104,6 +100,10 @@
104100
"obs.use_ax_tree",
105101
("AXTree:", "Click me"),
106102
),
103+
(
104+
"obs.use_tabs",
105+
("Currently open tabs:","(active tab)"),
106+
),
107107
(
108108
"obs.use_focused_element",
109109
("Focused element:", "bid='45-256'"),
@@ -251,11 +251,6 @@ def test_main_prompt_elements_present():
251251
# for debugging
252252
test_shrinking_observation()
253253
test_main_prompt_elements_present()
254-
for flag, expected_prompts in FLAG_EXPECTED_PROMPT:
255-
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
256-
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
257-
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
258-
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
259-
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
260-
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
261-
test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
254+
# for flag, expected_prompts in FLAG_EXPECTED_PROMPT:
255+
# test_main_prompt_elements_gone_one_at_a_time(flag, expected_prompts)
256+

0 commit comments

Comments
 (0)