Skip to content

Commit be1998c

Browse files
gassefrankxu2004TLSDC
committed
(Visual)WebArena agent again (#142)
--------- Co-authored-by: Frank Xu <[email protected]> Co-authored-by: Thibault Le Sellier de Chezelles <[email protected]>
1 parent 4dde673 commit be1998c

File tree

15 files changed

+635
-5
lines changed

15 files changed

+635
-5
lines changed

main.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
"""
88

99
import logging
10+
1011
from agentlab.agents.generic_agent import (
12+
AGENT_LLAMA3_70B,
13+
AGENT_LLAMA31_70B,
1114
RANDOM_SEARCH_AGENT,
1215
AGENT_4o,
1316
AGENT_4o_MINI,
14-
AGENT_LLAMA3_70B,
15-
AGENT_LLAMA31_70B,
1617
)
1718
from agentlab.experiments.study import Study
1819

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ requests
2222
matplotlib
2323
ray[default]
2424
python-slugify
25+
pillow

src/agentlab/agents/dynamic_prompting.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010

1111
import bgym
1212
from browsergym.core.action.base import AbstractActionSet
13-
from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, overlay_som, prune_html
13+
from browsergym.utils.obs import (
14+
flatten_axtree_to_str,
15+
flatten_dom_to_str,
16+
overlay_som,
17+
prune_html,
18+
)
1419

1520
from agentlab.llm.llm_utils import (
1621
BaseMessage,
Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
import base64
2+
import importlib.resources
3+
import io
4+
from copy import deepcopy
5+
from dataclasses import dataclass
6+
from functools import partial
7+
from typing import Any, Literal
8+
9+
import numpy as np
10+
import PIL.Image
11+
from browsergym.core.action.highlevel import HighLevelActionSet
12+
from browsergym.experiments import Agent, AgentInfo
13+
from browsergym.experiments.benchmark import Benchmark, HighLevelActionSetArgs
14+
from browsergym.utils.obs import overlay_som
15+
16+
from agentlab.llm.base_api import AbstractChatModel
17+
from agentlab.llm.chat_api import BaseModelArgs, make_system_message, make_user_message
18+
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
19+
from agentlab.llm.llm_utils import ParseError, extract_code_blocks, retry
20+
from agentlab.llm.tracking import cost_tracker_decorator
21+
22+
from ..agent_args import AgentArgs
23+
from . import few_shots
24+
from .prompts import TEMPLATES
25+
26+
FEW_SHOT_FILES = importlib.resources.files(few_shots)
27+
VisualWebArenaObservationType = Literal["axtree", "axtree_som", "axtree_screenshot"]
28+
29+
30+
def image_data_to_uri(
31+
image_data: bytes | np.ndarray, output_format: Literal["png", "jpeg"] = "png"
32+
) -> str:
33+
assert output_format in ("png", "jpeg")
34+
# load input image data (auto-detect input format)
35+
if isinstance(image_data, np.ndarray):
36+
image = PIL.Image.fromarray(image_data)
37+
else:
38+
image = PIL.Image.open(io.BytesIO(image_data))
39+
# TODO: is this necessary?
40+
if image.mode in ("RGBA", "LA"):
41+
image = image.convert("RGB")
42+
# convert image to desired output format
43+
with io.BytesIO() as image_buffer:
44+
image.save(image_buffer, format=output_format.upper())
45+
image_data = image_buffer.getvalue()
46+
# convert to base64 data/image URI
47+
image_b64 = base64.b64encode(image_data).decode("utf-8")
48+
image_b64 = f"data:image/{output_format};base64," + image_b64
49+
return image_b64
50+
51+
52+
@dataclass
53+
class VisualWebArenaAgentArgs(AgentArgs):
54+
agent_name: str = "VisualWebArenaAgent"
55+
temperature: float = 0.1
56+
chat_model_args: BaseModelArgs = None
57+
action_set_args: HighLevelActionSetArgs = None
58+
observation_type: VisualWebArenaObservationType = "axtree_som"
59+
with_few_shot_examples: bool = True
60+
61+
def __post_init__(self):
62+
self.agent_name = (
63+
f"{self.agent_name}-{self.observation_type}-{self.chat_model_args.model_name}".replace(
64+
"/", "_"
65+
)
66+
)
67+
68+
def make_agent(self) -> Agent:
69+
return VisualWebArenaAgent(
70+
temperature=self.temperature,
71+
chat_model=self.chat_model_args.make_model(),
72+
action_set=self.action_set_args.make_action_set(),
73+
observation_type=self.observation_type,
74+
with_few_shot_examples=self.with_few_shot_examples,
75+
)
76+
77+
def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
78+
self.action_set_args = deepcopy(benchmark.high_level_action_set_args)
79+
80+
def set_reproducibility_mode(self):
81+
self.temperature = 0.0
82+
83+
def prepare(self):
84+
return self.chat_model_args.prepare_server()
85+
86+
def close(self):
87+
return self.chat_model_args.close_server()
88+
89+
90+
def parser(response: str) -> dict:
91+
blocks = extract_code_blocks(response)
92+
if len(blocks) == 0:
93+
raise ParseError("No code block found in the response")
94+
action = blocks[0][1]
95+
thought = response
96+
return {"action": action, "think": thought}
97+
98+
99+
class VisualWebArenaAgent(Agent):
100+
def __init__(
101+
self,
102+
temperature: float,
103+
chat_model: AbstractChatModel,
104+
action_set: HighLevelActionSet,
105+
observation_type: VisualWebArenaObservationType,
106+
with_few_shot_examples: bool,
107+
):
108+
self.temperature = temperature
109+
self.chat_model = chat_model
110+
self.action_set = action_set
111+
self.observation_type = observation_type
112+
self.with_few_shot_examples = with_few_shot_examples
113+
114+
self.action_history = ["None"]
115+
116+
self.intro_messages: list[dict] = []
117+
118+
# pre-build the prompt's intro message
119+
self.intro_messages.append(
120+
{
121+
"type": "text",
122+
"text": TEMPLATES[observation_type]["intro"].format(
123+
action_space_description=self.action_set.describe(
124+
with_long_description=True, with_examples=False
125+
)
126+
),
127+
}
128+
)
129+
130+
self.few_shot_messages: list[dict] = []
131+
132+
# pre-build the prompt's few-shot example messages
133+
if with_few_shot_examples:
134+
examples = TEMPLATES[observation_type]["examples"]
135+
for i, example in enumerate(examples):
136+
if len(example) == 2:
137+
# text-only example
138+
observation, action = example
139+
self.few_shot_messages.append(
140+
{
141+
"type": "text",
142+
"text": f"""\
143+
Example {i + 1}/{len(examples)}:
144+
145+
{observation}
146+
ACTION: {action}
147+
""",
148+
}
149+
)
150+
elif len(example) == 3:
151+
# example with screenshot
152+
observation, action, screenshot_filename = example
153+
screenshot_data = FEW_SHOT_FILES.joinpath(screenshot_filename).read_bytes()
154+
self.few_shot_messages.extend(
155+
[
156+
{
157+
"type": "text",
158+
"text": f"""\
159+
Example {i + 1}/{len(examples)}:
160+
161+
{observation}
162+
""",
163+
},
164+
{
165+
"type": "text",
166+
"text": """\
167+
SCREENSHOT:
168+
""",
169+
},
170+
{
171+
"type": "image_url",
172+
"image_url": {"url": image_data_to_uri(screenshot_data)},
173+
},
174+
{
175+
"type": "text",
176+
"text": f"""\
177+
ACTION: {action}
178+
""",
179+
},
180+
]
181+
)
182+
else:
183+
raise ValueError("Unexpected format for few-shot example.")
184+
185+
@cost_tracker_decorator
186+
def get_action(self, obs: Any) -> tuple[str, dict]:
187+
"""
188+
Replica of VisualWebArena agent
189+
https://github.com/web-arena-x/visualwebarena/blob/89f5af29305c3d1e9f97ce4421462060a70c9a03/agent/prompts/prompt_constructor.py#L211
190+
https://github.com/web-arena-x/visualwebarena/blob/89f5af29305c3d1e9f97ce4421462060a70c9a03/agent/prompts/prompt_constructor.py#L272
191+
"""
192+
user_messages = []
193+
194+
# 1. add few-shot examples (if any)
195+
user_messages.extend(self.few_shot_messages)
196+
197+
# 2. add the current observation to the user prompt
198+
active_tab = obs["active_page_index"][0]
199+
open_tab_titles = obs["open_pages_titles"]
200+
cur_tabs_txt = " | ".join(
201+
f"Tab {i}{' (current)' if i == active_tab else ''}: {title}"
202+
for i, title in enumerate(open_tab_titles)
203+
)
204+
cur_axtree_txt = obs["axtree_txt"]
205+
cur_url = obs["url"]
206+
user_messages.append(
207+
{
208+
"type": "text",
209+
"text": f"""\
210+
OBSERVATION:
211+
212+
{cur_tabs_txt}
213+
214+
{cur_axtree_txt}
215+
216+
URL: {cur_url}
217+
218+
PREVIOUS ACTION: {self.action_history[-1]}
219+
""",
220+
}
221+
)
222+
223+
# if desired, add current page's screenshot
224+
if self.observation_type in ("axtree_som", "axtree_screenshot"):
225+
cur_screenshot = obs["screenshot"]
226+
# if desired, overlay set-of-marks on the screenshot
227+
if self.observation_type == "axtree_som":
228+
cur_screenshot = overlay_som(cur_screenshot, obs["extra_element_properties"])
229+
user_messages.extend(
230+
[
231+
{
232+
"type": "text",
233+
"text": """\
234+
SCREENSHOT:
235+
""",
236+
},
237+
{"type": "image_url", "image_url": {"url": image_data_to_uri(cur_screenshot)}},
238+
]
239+
)
240+
241+
# 3. add the objective (goal) to the user prompt
242+
user_messages.append(
243+
{
244+
"type": "text",
245+
"text": f"""\
246+
OBJECTIVE:
247+
""",
248+
}
249+
)
250+
user_messages.extend(obs["goal_object"])
251+
252+
messages = [
253+
# intro prompt
254+
make_system_message(content=self.intro_messages),
255+
# few-shot examples + observation + goal
256+
make_user_message(content=user_messages),
257+
]
258+
259+
# finally, query the chat model
260+
answer: dict = retry(self.chat_model, messages, n_retry=3, parser=parser)
261+
262+
action = answer.get("action", None)
263+
thought = answer.get("think", None)
264+
265+
self.action_history.append(action)
266+
267+
return (
268+
action,
269+
AgentInfo(
270+
think=thought,
271+
chat_messages=messages,
272+
),
273+
)
274+
275+
276+
# A WebArena agent is a VisualWebArena agent with only axtree observation
277+
WebArenaAgent = partial(
278+
VisualWebArenaAgentArgs,
279+
agent_name="WebArenaAgent",
280+
observation_type="axtree",
281+
)
282+
283+
WA_AGENT_4O_MINI = WebArenaAgent(
284+
temperature=0.1,
285+
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"],
286+
)
287+
288+
WA_AGENT_4O = WebArenaAgent(
289+
temperature=0.1,
290+
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"],
291+
)
292+
293+
WA_AGENT_SONNET = WebArenaAgent(
294+
temperature=0.1,
295+
chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"],
296+
)
297+
298+
VWA_AGENT_4O_MINI = VisualWebArenaAgentArgs(
299+
temperature=0.1,
300+
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"],
301+
)
302+
303+
VWA_AGENT_4O = VisualWebArenaAgentArgs(
304+
temperature=0.1,
305+
chat_model_args=CHAT_MODEL_ARGS_DICT["azure/gpt-4o-2024-08-06"],
306+
)
307+
308+
VWA_AGENT_SONNET = VisualWebArenaAgentArgs(
309+
temperature=0.1,
310+
chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"],
311+
)

src/agentlab/agents/visualwebarena/few_shots/__init__.py

Whitespace-only changes.
253 KB
Loading
210 KB
Loading
312 KB
Loading
282 KB
Loading
222 KB
Loading

0 commit comments

Comments
 (0)