Skip to content

Commit 1b1c4dc

Browse files
committed
Add GenericAgent and prompt builder implementations for AgentLab
1 parent 6e052b6 commit 1b1c4dc

File tree

2 files changed

+414
-0
lines changed

2 files changed

+414
-0
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
"""
2+
GenericAgent implementation for AgentLab
3+
4+
This module defines a `GenericAgent` class and its associated arguments for use in the AgentLab framework. \
5+
The `GenericAgent` class is designed to interact with a chat-based model to determine actions based on \
6+
observations. It includes methods for preprocessing observations, generating actions, and managing internal \
7+
state such as plans, memories, and thoughts. The `GenericAgentArgs` class provides configuration options for \
8+
the agent, including model arguments and flags for various behaviors.
9+
"""
10+
11+
from copy import deepcopy
12+
from dataclasses import asdict, dataclass
13+
from warnings import warn
14+
15+
import bgym
16+
from browsergym.experiments.agent import Agent, AgentInfo
17+
18+
from agentlab.agents import dynamic_prompting as dp
19+
from agentlab.agents.agent_args import AgentArgs
20+
from agentlab.llm.chat_api import BaseModelArgs
21+
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
22+
from agentlab.llm.tracking import cost_tracker_decorator
23+
24+
from .visual_agent_prompts import GenericPromptFlags, MainPrompt
25+
from functools import partial
26+
27+
28+
@dataclass
class ToolAgentFlags:
    """Placeholder flag container for ToolAgent; defines no fields yet."""
31+
32+
33+
@dataclass
class ToolAgentArgs(AgentArgs):
    """Configuration for building a `ToolAgent`.

    Bundles the chat-model configuration, the prompt flags, and the retry
    budget, and knows how to adapt the flags to a given benchmark.
    """

    chat_model_args: BaseModelArgs = None
    flags: GenericPromptFlags = None
    max_retry: int = 4

    def __post_init__(self):
        # Some attributes might temporarily be args.CrossProd during
        # hyperparameter generation, in which case model_name is unavailable
        # and we skip naming the agent for now.
        # NOTE(review): the "GenericAgent-" prefix predates the ToolAgent
        # rename; kept as-is since recorded traces may key on it.
        try:
            self.agent_name = "GenericAgent-{}".format(
                self.chat_model_args.model_name
            ).replace("/", "_")
        except AttributeError:
            pass

    def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
        """Override some flags based on the benchmark's capabilities."""
        obs_flags = self.flags.obs
        action_flags = self.flags.action

        if benchmark.name.startswith("miniwob"):
            obs_flags.use_html = True

        obs_flags.use_tabs = benchmark.is_multi_tab
        # Deep-copy so later per-agent mutations don't leak into the benchmark.
        action_flags.action_set = deepcopy(benchmark.high_level_action_set_args)

        # for backward compatibility with old traces
        if action_flags.multi_actions is not None:
            action_flags.action_set.multiaction = action_flags.multi_actions
        if action_flags.is_strict is not None:
            action_flags.action_set.strict = action_flags.is_strict

        # verify if we can remove this
        if demo_mode:
            action_flags.action_set.demo_mode = "all_blue"

    def set_reproducibility_mode(self):
        """Make model sampling deterministic by zeroing the temperature."""
        self.chat_model_args.temperature = 0

    def prepare(self):
        """Delegate server startup to the chat-model configuration."""
        return self.chat_model_args.prepare_server()

    def close(self):
        """Delegate server shutdown to the chat-model configuration."""
        return self.chat_model_args.close_server()

    def make_agent(self):
        """Instantiate the ToolAgent described by this configuration."""
        return ToolAgent(
            chat_model_args=self.chat_model_args,
            flags=self.flags,
            max_retry=self.max_retry,
        )
76+
77+
78+
class ToolAgent(Agent):
    """Agent that queries a chat LLM with a dynamically-built prompt to choose
    actions, while accumulating per-episode state (plan, memories, thoughts,
    action and observation histories)."""

    def __init__(
        self,
        chat_model_args: BaseModelArgs,
        flags: GenericPromptFlags,
        max_retry: int = 4,
    ):
        """Set up the chat model, action set, and observation preprocessor.

        Args:
            chat_model_args: configuration/factory for the chat LLM backend.
            flags: prompt, observation, and action configuration flags.
            max_retry: number of re-queries allowed when the LLM answer
                cannot be parsed.
        """

        self.chat_llm = chat_model_args.make_model()
        self.chat_model_args = chat_model_args
        self.max_retry = max_retry

        self.flags = flags
        self.action_set = self.flags.action.action_set.make_action_set()
        self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)

        # May disable inconsistent flags in-place (with a warning); see
        # _check_flag_constancy below.
        self._check_flag_constancy()
        self.reset(seed=None)

    def obs_preprocessor(self, obs: dict) -> dict:
        """Apply the flag-configured observation preprocessing pipeline."""
        return self._obs_preprocessor(obs)

    @cost_tracker_decorator
    def get_action(self, obs):
        """Ask the LLM for the next action given the latest observation.

        Appends `obs` to the history, builds the prompt from the full episode
        state, shrinks it to the token budget, and parses the LLM's answer
        with up to `max_retry` re-queries.

        Returns:
            tuple: (action, AgentInfo). `action` is None when parsing failed
            after all retries (`busted_retry` is then set to 1 in the stats).
        """

        self.obs_history.append(obs)
        # Prompt is rebuilt every step from the whole episode state.
        main_prompt = MainPrompt(
            action_set=self.action_set,
            obs_history=self.obs_history,
            actions=self.actions,
            memories=self.memories,
            thoughts=self.thoughts,
            previous_plan=self.plan,
            step=self.plan_step,
            flags=self.flags,
        )

        max_prompt_tokens, max_trunc_itr = self._get_maxes()

        system_prompt = SystemMessage(dp.SystemPrompt().prompt)

        # Iteratively shrink the main prompt until it fits the token budget
        # (system prompt counted via additional_prompts).
        human_prompt = dp.fit_tokens(
            shrinkable=main_prompt,
            max_prompt_tokens=max_prompt_tokens,
            model_name=self.chat_model_args.model_name,
            max_iterations=max_trunc_itr,
            additional_prompts=system_prompt,
        )
        try:
            # TODO, we would need to further shrink the prompt if the retry
            # cause it to be too long

            chat_messages = Discussion([system_prompt, human_prompt])
            ans_dict = retry(
                self.chat_llm,
                chat_messages,
                n_retry=self.max_retry,
                parser=main_prompt._parse_answer,
            )
            ans_dict["busted_retry"] = 0
            # inferring the number of retries, TODO: make this less hacky
            # (presumably retry() appends an assistant+user message pair per
            # attempt on top of the initial 3 messages — confirm in retry()).
            ans_dict["n_retry"] = (len(chat_messages) - 3) / 2
        except ParseError as e:
            # All retries exhausted: record a null action and mark the step
            # as busted so downstream stats can filter it.
            ans_dict = dict(
                action=None,
                n_retry=self.max_retry + 1,
                busted_retry=1,
            )

        stats = self.chat_llm.get_stats()
        stats["n_retry"] = ans_dict["n_retry"]
        stats["busted_retry"] = ans_dict["busted_retry"]

        # Update episode state; keep the previous plan/step when the answer
        # does not provide new ones.
        self.plan = ans_dict.get("plan", self.plan)
        self.plan_step = ans_dict.get("step", self.plan_step)
        self.actions.append(ans_dict["action"])
        self.memories.append(ans_dict.get("memory", None))
        self.thoughts.append(ans_dict.get("think", None))

        agent_info = AgentInfo(
            think=ans_dict.get("think", None),
            chat_messages=chat_messages,
            stats=stats,
            extra_info={"chat_model_args": asdict(self.chat_model_args)},
        )
        return ans_dict["action"], agent_info

    def reset(self, seed=None):
        """Clear all per-episode state; called from __init__ and between episodes."""
        self.seed = seed
        self.plan = "No plan yet"
        self.plan_step = -1
        self.memories = []
        self.thoughts = []
        self.actions = []
        self.obs_history = []

    def _check_flag_constancy(self):
        """Disable flag combinations that conflict with each other or with
        the model's capabilities, warning on each change.

        Mutates `self.flags` in place (the return value is unused by callers
        in this file).
        """
        flags = self.flags
        if flags.obs.use_som:
            if not flags.obs.use_screenshot:
                warn(
                    """
Warning: use_som=True requires use_screenshot=True. Disabling use_som."""
                )
                flags.obs.use_som = False
        if flags.obs.use_screenshot:
            if not self.chat_model_args.vision_support:
                warn(
                    """
Warning: use_screenshot is set to True, but the chat model \
does not support vision. Disabling use_screenshot."""
                )
                flags.obs.use_screenshot = False
        return flags

    def _get_maxes(self):
        """Return (max_prompt_tokens, max_trunc_itr).

        max_prompt_tokens is the tightest of the configured caps (flag cap,
        model total-token cap, model input-token cap), or None when no cap
        is set; max_trunc_itr bounds the shrink loop in dp.fit_tokens.
        """
        maxes = (
            self.flags.max_prompt_tokens,
            self.chat_model_args.max_total_tokens,
            self.chat_model_args.max_input_tokens,
        )
        maxes = [m for m in maxes if m is not None]
        max_prompt_tokens = min(maxes) if maxes else None
        max_trunc_itr = (
            self.flags.max_trunc_itr
            if self.flags.max_trunc_itr
            else 20  # dangerous to change the default value here?
        )
        return max_prompt_tokens, max_trunc_itr

0 commit comments

Comments
 (0)