Commit 379508b

TLSDCgasse authored and committed
updates to recent API changes
1 parent da78fa0 commit 379508b

File tree

  • src/agentlab/agents/visualwebarena

1 file changed: +26 −35 lines

src/agentlab/agents/visualwebarena/agent.py

Lines changed: 26 additions & 35 deletions
@@ -2,7 +2,6 @@
 import dataclasses
 import io
 import re
-import tempfile
 from io import BytesIO
 
 from browsergym.core.action.highlevel import HighLevelActionSet
@@ -11,8 +10,9 @@
 from PIL import Image
 
 from agentlab.agents.agent_args import AgentArgs
-from agentlab.llm.chat_api import BaseModelArgs, make_system_message, make_user_message
-from agentlab.llm.llm_utils import ParseError, extract_code_blocks, retry
+from agentlab.llm.chat_api import BaseModelArgs
+from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
+from agentlab.llm.llm_utils import Discussion, HumanMessage, ParseError, SystemMessage, retry
 
 
 def pil_to_b64(img: Image.Image) -> str:
@@ -74,10 +74,20 @@ def get_action(self, obs: dict) -> tuple[str, dict]:
 possible next action to accomplish your goal. Your answer will be interpreted
 and executed by a program, make sure to follow the formatting instructions."""
 
-        user_prompt = f"""\
+        prompt = Discussion(SystemMessage(system_prompt))
+
+        prompt.append(
+            HumanMessage(
+                f"""\
 # Goal:
-{obs["goal_object"][0]["text"]}
+"""
+            )
+        )
+        for goal in obs["goal_object"]:
+            prompt.add_content(goal["type"], goal[goal["type"]])
 
+        prompt.add_text(
+            f"""
 # Current Accessibility Tree:
 {obs["axtree_txt"]}
 
@@ -95,31 +105,12 @@ def get_action(self, obs: dict) -> tuple[str, dict]:
 ```send_msg_to_user("blue")```
 "
 """
-        # prompt
-        user_msgs = [{"type": "text", "text": user_prompt}]
-
-        # screenshot
-        user_msgs = [
-            {
-                "type": "text",
-                "text": "IMAGES: current page screenshot",
-            },
-            {
-                "type": "image_url",
-                "image_url": {
-                    "url": pil_to_b64(
-                        Image.fromarray(overlay_som(obs["screenshot"], obs["extra_properties"]))
-                    )
-                },
-            },
-        ]
-        # additional images
-        user_msgs.extend(obs["goal_object"][1:])
-
-        messages = [
-            make_system_message(system_prompt),
-            make_user_message(user_prompt),
-        ]
+        )
+
+        prompt.add_text("IMAGES: current page screenshot")
+        prompt.add_image(
+            pil_to_b64(Image.fromarray(overlay_som(obs["screenshot"], obs["extra_properties"])))
+        )
 
         def parser(response: str) -> tuple[dict, bool, str]:
             pattern = r"```((.|\\n)*?)```"
@@ -130,12 +121,12 @@ def parser(response: str) -> tuple[dict, bool, str]:
             thought = response
             return {"action": action, "think": thought}
 
-        response = retry(self.chat_llm, messages, n_retry=self.n_retry, parser=parser)
+        response = retry(self.chat_llm, prompt, n_retry=self.n_retry, parser=parser)
 
         action = response.get("action", None)
-        stats = dict(response.usage)
+        stats = self.chat_llm.get_stats()
        return action, AgentInfo(
-            chat_messages=messages,
+            chat_messages=prompt.to_markdown(),
             think=response.get("think", None),
             stats=stats,
         )
@@ -155,10 +146,10 @@ class VWAAgentArgs(AgentArgs):
     chat_model_args: BaseModelArgs = None
 
     def make_agent(self):
-        return VWAAgent()
+        return VWAAgent(chat_model_args=self.chat_model_args, n_retry=3)
 
 
-CONFIG = VWAAgentArgs(model_name="gpt-4-1106-vision-preview")
+CONFIG = VWAAgentArgs(CHAT_MODEL_ARGS_DICT["openai/gpt-4o-mini-2024-07-18"])
 
 
 def main():

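For readers tracking the API migration, the sketch below (not part of the commit) shows the new Discussion-based prompt flow in isolation. The names Discussion, SystemMessage, HumanMessage, add_text, add_image, and retry are taken from the diff above; the build_prompt helper, the exact signatures, and the simplified goal handling are illustrative assumptions.

```python
# Minimal sketch of the Discussion-based prompt flow this commit migrates to.
# Class/method names come from the diff above; the build_prompt helper, the
# exact signatures, and the simplified goal handling are assumptions for
# illustration, not the library's documented API.
from agentlab.llm.llm_utils import Discussion, HumanMessage, SystemMessage


def build_prompt(goal_text: str, axtree_txt: str, screenshot_b64: str) -> Discussion:
    # A Discussion object replaces the old hand-built list of message dicts.
    prompt = Discussion(SystemMessage("You are a web agent. Reply with one action."))
    prompt.append(HumanMessage(f"# Goal:\n{goal_text}\n"))
    prompt.add_text(f"# Current Accessibility Tree:\n{axtree_txt}\n")
    prompt.add_text("IMAGES: current page screenshot")
    prompt.add_image(screenshot_b64)  # base64 image URL, e.g. produced by pil_to_b64()
    return prompt
```

The CONFIG change follows the same direction: VWAAgentArgs is now built from a preset in CHAT_MODEL_ARGS_DICT rather than a bare model name, make_agent() forwards those chat_model_args to the VWAAgent constructor, and token/cost stats come from self.chat_llm.get_stats() instead of the raw response.usage.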