Skip to content

Commit 6e75e81

Browse files
authored
Create a simple pure visual agent. (#235)
* Add GenericAgent and prompt builder implementations for AgentLab * Implement VisualAgent and associated prompt flags for enhanced agent functionality * less filtering * Remove unused VisualAgentArgs for computer use from agent_configs.py
1 parent 6e052b6 commit 6e75e81

File tree

4 files changed

+362
-2
lines changed

4 files changed

+362
-2
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
2+
3+
from .visual_agent import VisualAgentArgs
4+
from .visual_agent_prompts import PromptFlags
5+
import agentlab.agents.dynamic_prompting as dp
6+
import bgym
7+
8+
# Only the flags below are relevant; the other dynamic_prompting flags are
# ignored for this agent.
DEFAULT_OBS_FLAGS = dp.ObsFlags(
    use_tabs=True,  # will be overridden by the benchmark when set_benchmark is called after initializing the agent
    use_error_logs=True,  # show the error message of the last action, if any
    use_past_error_logs=False,  # do not replay error logs from earlier steps
    use_screenshot=True,  # pure visual agent: the screenshot is the main observation
    use_som=False,  # no set-of-marks (bid bounding-box) annotations on the screenshot
    openai_vision_detail="auto",  # let the vision API choose the image resolution
)

# "coord" subset: actions address the page by coordinates rather than element bids.
DEFAULT_ACTION_FLAGS = dp.ActionFlags(
    action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]),
    long_description=True,  # include the full action-set description in the prompt
    individual_examples=False,
)


DEFAULT_PROMPT_FLAGS = PromptFlags(
    obs=DEFAULT_OBS_FLAGS,
    action=DEFAULT_ACTION_FLAGS,
    use_thinking=True,  # ask the model to think (dp.Think element) before acting
    use_concrete_example=False,
    use_abstract_example=True,
    enable_chat=False,  # goal mode by default; set True for conversational tasks
    extra_instructions=None,
)

# Ready-to-use agent configurations for two vision-capable chat models.
VISUAL_AGENT_4o = VisualAgentArgs(
    chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"],
    flags=DEFAULT_PROMPT_FLAGS,
)


VISUAL_AGENT_CLAUDE_3_5 = VisualAgentArgs(
    chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"],
    flags=DEFAULT_PROMPT_FLAGS,
)
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""
2+
GenericAgent implementation for AgentLab
3+
4+
This module defines a `GenericAgent` class and its associated arguments for use in the AgentLab framework. \
5+
The `GenericAgent` class is designed to interact with a chat-based model to determine actions based on \
6+
observations. It includes methods for preprocessing observations, generating actions, and managing internal \
7+
state such as plans, memories, and thoughts. The `GenericAgentArgs` class provides configuration options for \
8+
the agent, including model arguments and flags for various behaviors.
9+
"""
10+
11+
from dataclasses import asdict, dataclass
12+
13+
import bgym
14+
from browsergym.experiments.agent import Agent, AgentInfo
15+
16+
from agentlab.agents import dynamic_prompting as dp
17+
from agentlab.agents.agent_args import AgentArgs
18+
from agentlab.llm.chat_api import BaseModelArgs
19+
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
20+
from agentlab.llm.tracking import cost_tracker_decorator
21+
22+
from .visual_agent_prompts import PromptFlags, MainPrompt
23+
24+
25+
@dataclass
class VisualAgentArgs(AgentArgs):
    """Configuration / factory for :class:`VisualAgent`.

    Dataclass fields can be swept with ``args.CrossProd`` for hyperparameter
    search, which is why ``__post_init__`` tolerates missing attributes.
    """

    chat_model_args: BaseModelArgs = None  # which chat model the agent will use
    flags: PromptFlags = None  # prompt-building flags (see visual_agent_prompts)
    max_retry: int = 4  # max LLM re-queries when the answer cannot be parsed

    def __post_init__(self):
        try:  # some attributes might be missing temporarily due to args.CrossProd for hyperparameter generation
            self.agent_name = f"VisualAgent-{self.chat_model_args.model_name}".replace("/", "_")
        except AttributeError:
            pass

    def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
        """Override some flags based on the benchmark (here: multi-tab support).

        ``demo_mode`` is accepted for interface compatibility but not used.
        """
        self.flags.obs.use_tabs = benchmark.is_multi_tab

    def set_reproducibility_mode(self):
        # Greedy decoding so repeated runs produce the same completions.
        self.chat_model_args.temperature = 0

    def prepare(self):
        # Delegate server lifecycle to the model args (e.g. spin up a local server).
        return self.chat_model_args.prepare_server()

    def close(self):
        return self.chat_model_args.close_server()

    def make_agent(self):
        """Instantiate the actual agent from this configuration."""
        return VisualAgent(
            chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry
        )
54+
55+
56+
class VisualAgent(Agent):
    """Simple pure-visual agent.

    Chooses an action from the current (screenshot-based) observation, the
    user/goal instructions, and its own history of past actions and thoughts.
    """

    def __init__(
        self,
        chat_model_args: BaseModelArgs,
        flags: PromptFlags,
        max_retry: int = 4,
    ):

        self.chat_llm = chat_model_args.make_model()
        self.chat_model_args = chat_model_args
        self.max_retry = max_retry

        self.flags = flags
        self.action_set = self.flags.action.action_set.make_action_set()
        # Preprocessor derived from the obs flags (e.g. screenshot handling).
        self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)

        self.reset(seed=None)

    def obs_preprocessor(self, obs: dict) -> dict:
        """Preprocess the raw observation according to the obs flags."""
        return self._obs_preprocessor(obs)

    @cost_tracker_decorator
    def get_action(self, obs):
        """Query the LLM for the next action.

        Returns a ``(action, AgentInfo)`` pair; ``action`` is ``None`` when all
        parse retries are exhausted.
        """

        main_prompt = MainPrompt(
            action_set=self.action_set,
            obs=obs,
            actions=self.actions,
            thoughts=self.thoughts,
            flags=self.flags,
        )

        # NOTE(review): this uses dp.SystemPrompt, not the SystemPrompt class
        # defined in visual_agent_prompts — confirm which one is intended.
        system_prompt = SystemMessage(dp.SystemPrompt().prompt)
        try:
            # TODO, we would need to further shrink the prompt if the retry
            # cause it to be too long

            chat_messages = Discussion([system_prompt, main_prompt.prompt])
            ans_dict = retry(
                self.chat_llm,
                chat_messages,
                n_retry=self.max_retry,
                parser=main_prompt._parse_answer,
            )
            ans_dict["busted_retry"] = 0
            # inferring the number of retries, TODO: make this less hacky
            ans_dict["n_retry"] = (len(chat_messages) - 3) / 2
        except ParseError:
            # All retries failed: record a null action and mark the step busted.
            # NOTE(review): if Discussion(...) itself raised, chat_messages
            # would be unbound below — assumed only retry() raises ParseError.
            ans_dict = dict(
                action=None,
                n_retry=self.max_retry + 1,
                busted_retry=1,
            )

        stats = self.chat_llm.get_stats()
        stats["n_retry"] = ans_dict["n_retry"]
        stats["busted_retry"] = ans_dict["busted_retry"]

        # Keep history for the next step's prompt (think may be absent).
        self.actions.append(ans_dict["action"])
        self.thoughts.append(ans_dict.get("think", None))

        agent_info = AgentInfo(
            think=ans_dict.get("think", None),
            chat_messages=chat_messages,
            stats=stats,
            extra_info={"chat_model_args": asdict(self.chat_model_args)},
        )
        return ans_dict["action"], agent_info

    def reset(self, seed=None):
        """Clear the per-episode history (thoughts and actions)."""
        self.seed = seed
        self.thoughts = []
        self.actions = []
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
"""
2+
Prompt builder for GenericAgent
3+
4+
It is based on the dynamic_prompting module from the agentlab package.
5+
"""
6+
7+
import logging
8+
from dataclasses import dataclass
9+
import bgym
10+
11+
from browsergym.core.action.base import AbstractActionSet
12+
13+
from agentlab.agents import dynamic_prompting as dp
14+
from agentlab.llm.llm_utils import BaseMessage, HumanMessage, image_to_jpg_base64_url
15+
16+
17+
@dataclass
class PromptFlags(dp.Flags):
    """Flags controlling how the VisualAgent prompt is assembled."""

    obs: dp.ObsFlags = None  # observation-related flags (tabs, errors, screenshot)
    action: dp.ActionFlags = None  # action-set description flags
    use_thinking: bool = True  # include the dp.Think element (chain-of-thought)
    use_concrete_example: bool = False  # append a fully worked answer example
    use_abstract_example: bool = True  # append an abstract answer template
    enable_chat: bool = False  # read instructions from chat instead of the goal
    extra_instructions: str | None = None  # optional extra text for instructions
30+
31+
32+
class SystemPrompt(dp.PromptElement):
    """System message describing the agent's role.

    NOTE(review): VisualAgent.get_action builds its system message from
    ``dp.SystemPrompt()``, so this class appears unused here — confirm intent.
    """

    _prompt = """\
You are an agent trying to solve a web task based on the content of the page and
user instructions. You can interact with the page and explore, and send messages to the user. Each time you
submit an action it will be sent to the browser and you will receive a new page."""
37+
38+
39+
def make_instructions(obs: dict, from_chat: bool, extra_instructions: str | None):
    """Build the instruction prompt element from either the chat or the goal."""
    if from_chat:
        return dp.ChatInstructions(obs["chat_messages"], extra_instructions=extra_instructions)

    # Goal mode: warn if the chat already holds a multi-turn conversation.
    user_msg_count = sum(msg["role"] == "user" for msg in obs.get("chat_messages", []))
    if user_msg_count > 1:
        logging.warning(
            "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`."
        )
    return dp.GoalInstructions(obs["goal_object"], extra_instructions=extra_instructions)
54+
55+
56+
class History(dp.PromptElement):
    """
    Format the actions and thoughts of previous steps."""

    def __init__(self, actions, thoughts) -> None:
        super().__init__()
        # One formatted section per past step, in chronological order.
        steps = [
            f"""
## Step {i}
### Thoughts:
{thought}
### Action:
{action}
"""
            for i, (action, thought) in enumerate(zip(actions, thoughts))
        ]
        self._prompt = "\n".join(steps) + "\n"
74+
75+
76+
class Observation(dp.PromptElement):
    """Observation of the current step.

    Shows the open tabs (when enabled) and the last action's error (when any),
    and can append the page screenshot to a message via ``add_screenshot``.
    """

    def __init__(self, obs, flags: dp.ObsFlags) -> None:
        super().__init__()
        self.flags = flags
        self.obs = obs

        # for a multi-tab browser, we need to show the current tab
        self.tabs = dp.Tabs(
            obs,
            visible=lambda: flags.use_tabs,
            prefix="## ",
        )

        # if an error is present, we need to show it
        self.error = dp.Error(
            obs["last_action_error"],
            visible=lambda: flags.use_error_logs and obs["last_action_error"],
            prefix="## ",
        )

    @property
    def _prompt(self) -> str:
        # Tabs/error elements render to "" when their visibility lambda is false.
        return f"""
# Observation of current step:
{self.tabs.prompt}{self.error.prompt}

"""

    def add_screenshot(self, prompt: BaseMessage) -> BaseMessage:
        """Append the screenshot (raw or set-of-marks annotated) to *prompt*."""
        if self.flags.use_screenshot:
            if self.flags.use_som:
                screenshot = self.obs["screenshot_som"]
                prompt.add_text(
                    "\n## Screenshot:\nHere is a screenshot of the page, it is annotated with bounding boxes and corresponding bids:"
                )
            else:
                screenshot = self.obs["screenshot"]
                prompt.add_text("\n## Screenshot:\nHere is a screenshot of the page:")
            # Encode as a base64 JPEG data-URL for the vision API.
            img_url = image_to_jpg_base64_url(screenshot)
            prompt.add_image(img_url, detail=self.flags.openai_vision_detail)
        return prompt
122+
123+
124+
class MainPrompt(dp.PromptElement):
    """Per-step human message: instructions, observation, history, the action
    set description, think instructions, optional examples and the screenshot.
    """

    def __init__(
        self,
        action_set: AbstractActionSet,
        obs: dict,
        actions: list[str],
        thoughts: list[str],
        flags: PromptFlags,
    ) -> None:
        super().__init__()
        self.flags = flags
        self.history = History(actions, thoughts)
        self.instructions = make_instructions(obs, flags.enable_chat, flags.extra_instructions)
        self.obs = Observation(obs, self.flags.obs)

        self.action_prompt = dp.ActionPrompt(action_set, action_flags=flags.action)
        self.think = dp.Think(visible=lambda: flags.use_thinking)

    @property
    def _prompt(self) -> HumanMessage:
        # Trailing backslashes in the f-strings keep sub-prompts contiguous
        # (each element is expected to manage its own newlines).
        prompt = HumanMessage(self.instructions.prompt)
        prompt.add_text(
            f"""\
{self.obs.prompt}\
{self.history.prompt}\
{self.action_prompt.prompt}\
{self.think.prompt}\
"""
        )

        if self.flags.use_abstract_example:
            prompt.add_text(
                f"""
# Abstract Example

Here is an abstract version of the answer with description of the content of
each tag. Make sure you follow this structure, but replace the content with your
answer:
{self.think.abstract_ex}\
{self.action_prompt.abstract_ex}\
"""
            )

        if self.flags.use_concrete_example:
            prompt.add_text(
                f"""
# Concrete Example

Here is a concrete example of how to format your answer.
Make sure to follow the template with proper tags:
{self.think.concrete_ex}\
{self.action_prompt.concrete_ex}\
"""
            )
        # The screenshot goes last so the text context precedes the image.
        return self.obs.add_screenshot(prompt)

    def _parse_answer(self, text_answer):
        """Parse the LLM answer into a dict (think + action keys)."""
        ans_dict = {}
        ans_dict.update(self.think.parse_answer(text_answer))
        ans_dict.update(self.action_prompt.parse_answer(text_answer))
        return ans_dict

src/agentlab/experiments/list_openai_models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
df = pd.DataFrame([dict(model) for model in models.data])
77

88
# Filter GPT models or o1 models
9-
df = df[df["id"].str.contains("gpt") | df["id"].str.contains("o1")]
9+
# df = df[df["id"].str.contains("gpt") | df["id"].str.contains("o1")]
1010

1111
# Convert Unix timestamps to dates (YYYY-MM-DD) and remove time
1212
df["created"] = pd.to_datetime(df["created"], unit="s").dt.date
1313
df.sort_values(by="created", inplace=True)
1414
# Print all entries
15-
print(df)
15+
16+
# print all entries
17+
print(df.to_string(index=False))

0 commit comments

Comments
 (0)