Skip to content

Commit 6481fae

Browse files
committed
adding task hints
1 parent 21be58f commit 6481fae

File tree

3 files changed

+87
-21
lines changed

3 files changed

+87
-21
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
time_stamp,task_name,task_seed,base_llm,agent_name,domain_name,user_name,source,semantic_keys,hint
2+
June 4,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,drop down menu,"some drop down menu will have a list of choice to select from, after typing. Make sure you select an element from that list."
3+
June 4,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,active input,the currently activated input is surrounded by a blue rectangle
4+
June 4,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,active input,GUI elements surrounded by a red rectangle often means there is an error in the content
5+
June 4,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,active input,The scroll bar indicates that there is more than 1 flights available in the search. Make sure to select the one matching the task goal among all possible flights.

src/agentlab/agents/tool_use_agent/multi_tool_agent.py

Lines changed: 79 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
import fnmatch
2+
from abc import ABC, abstractmethod
13
from copy import copy
2-
from dataclasses import dataclass
4+
from dataclasses import asdict, dataclass
5+
from pathlib import Path
36
from typing import Any
47

58
import bgym
69
import numpy as np
10+
import pandas as pd
711
from browsergym.core.observation import extract_screenshot
812
from browsergym.utils.obs import (
913
flatten_axtree_to_str,
@@ -29,10 +33,39 @@
2933
from agentlab.llm.tracking import cost_tracker_decorator
3034

3135

32-
class Block:
36+
@dataclass
37+
class Block(ABC):
38+
39+
def _init(self):
40+
"""Initialize the block."""
41+
pass
42+
43+
def make(self) -> "Block":
44+
"""Returns a copy so the init can start adding some stuff to `self` without changing the
45+
original datatclass that should only contain a config.
46+
The aim is avoid having 2 calss definition for each block, e.g. Block and BlockArgs."""
47+
block = self.__class__(**asdict(self))
48+
block._init()
49+
return block
50+
51+
@abstractmethod
52+
def apply(self, llm, messages: list[MessageBuilder], **kwargs):
53+
pass
3354

34-
def make(self):
35-
return self
55+
56+
# @dataclass
57+
# class BlockArgs(ABC):
58+
59+
# @abstractmethod
60+
# def make(self) -> Block:
61+
# """Make a block from the arguments."""
62+
# return self.__class__(**asdict(self))
63+
64+
65+
SYS_MSG = """You are a web agent. Based on the observation, you will decide which action to take to accomplish your goal.
66+
You strive for excellence and need to be as meticulous as possible. Make sure to explore when not sure.
67+
Your chain of thought should have 3 sections: 1) Analyze the effect of the action, 2) Summarize the current state of the environment, 3) Reflect on the next action to take.
68+
"""
3669

3770

3871
@dataclass
@@ -42,9 +75,7 @@ class Goal(Block):
4275
goal_as_system_msg: bool = True
4376

4477
def apply(self, llm, messages: list[MessageBuilder], obs: dict) -> dict:
45-
system_message = llm.msg.system().add_text(
46-
"You are an agent. Based on the observation, you will decide which action to take to accomplish your goal."
47-
)
78+
system_message = llm.msg.system().add_text(SYS_MSG)
4879
messages.append(system_message)
4980

5081
if self.goal_as_system_msg:
@@ -176,6 +207,39 @@ def apply(self, llm, messages: list[MessageBuilder]) -> dict:
176207
messages.append(summary_msg)
177208

178209

210+
@dataclass
211+
class TaskHint(Block):
212+
use_task_hint: bool = True
213+
hint_db_rel_path: str = "hint_db.csv"
214+
215+
def _init(self):
216+
"""Initialize the block."""
217+
hint_db_path = Path(__file__).parent / self.hint_db_rel_path
218+
self.hint_db = pd.read_csv(hint_db_path, header=0, index_col=None, dtype=str)
219+
220+
# index the task_name for fast lookup
221+
# self.hint_db.set_index("task_name", inplace=True, drop=False)
222+
223+
def apply(self, llm, messages: list[MessageBuilder], task_name: str) -> dict:
224+
if not self.use_task_hint:
225+
return
226+
227+
task_hints = self.hint_db[
228+
self.hint_db["task_name"].apply(lambda x: fnmatch.fnmatch(x, task_name))
229+
]
230+
231+
hints = []
232+
for hint in task_hints["hint"]:
233+
hint = hint.strip()
234+
if hint:
235+
hints.append(f"- {hint}")
236+
237+
hints_str = "Here are some hints for the task you are working on:\n" + "\n".join(hints)
238+
msg = llm.msg.user().add_text(hints_str)
239+
240+
messages.append(msg)
241+
242+
179243
class ToolCall(Block):
180244

181245
def __init__(self, tool_server):
@@ -202,6 +266,7 @@ class PromptConfig:
202266
obs: Obs = None
203267
summarizer: Summarizer = None
204268
general_hints: GeneralHints = None
269+
task_hint: TaskHint = None
205270

206271

207272
@dataclass
@@ -230,11 +295,6 @@ def prepare(self):
230295
def close(self):
231296
return self.model_args.close_server()
232297

233-
def set_benchmark(self, benchmark, demo_mode):
234-
235-
if benchmark in ["miniwob", "miniwob_tiny_test"]:
236-
self.config.obs.use_zoomed_webpage = True
237-
238298

239299
class ToolUseAgent(bgym.Agent):
240300
def __init__(
@@ -253,11 +313,7 @@ def __init__(
253313
self.msg_builder = model_args.get_message_builder()
254314
self.llm.msg = self.msg_builder
255315

256-
# # blocks
257-
# self.goal_block = self.config.goal
258-
# self.obs_block = self.config.obs
259-
# self.summarizer_block = self.config.summarizer
260-
# self.general_hints_block = self.config.general_hints
316+
self.task_hint = self.config.task_hint.make()
261317

262318
self.messages: list[MessageBuilder] = []
263319
self.last_response: LLMOutput = LLMOutput()
@@ -289,19 +345,20 @@ def obs_preprocessor(self, obs):
289345
)
290346
if self.config.obs.use_zoomed_webpage:
291347
pass
292-
# if self.config.tag_screenshot:
293-
# screenshot = Image.fromarray(obs["screenshot"])
294-
# screenshot = agent_utils.tag_screenshot_with_action(screenshot, obs["last_action"])
295-
# obs["screenshot"] = np.array(screenshot)
296348

297349
return obs
298350

351+
def set_task_name(self, task_name: str):
352+
"""Cheater function that is supposed to be called by loop.py before callling get_action"""
353+
self.task_name = task_name
354+
299355
@cost_tracker_decorator
300356
def get_action(self, obs: Any) -> float:
301357
self.llm.reset_stats()
302358
if len(self.messages) == 0:
303359
self.config.goal.apply(self.llm, self.messages, obs)
304360
self.config.general_hints.apply(self.llm, self.messages)
361+
self.task_hint.apply(self.llm, self.messages, self.task_name)
305362

306363
self.config.obs.apply(self.llm, self.messages, obs, last_llm_output=self.last_response)
307364
self.config.summarizer.apply(self.llm, self.messages)
@@ -366,6 +423,7 @@ def get_action(self, obs: Any) -> float:
366423
),
367424
summarizer=Summarizer(),
368425
general_hints=GeneralHints(use_hints=False),
426+
task_hint=TaskHint(use_task_hint=True),
369427
)
370428

371429
AGENT_CONFIG = ToolUseAgentArgs(

src/agentlab/experiments/loop.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,9 @@ def run(self):
399399
try:
400400
logger.info(f"Running experiment {self.exp_name} in:\n {self.exp_dir}")
401401
agent = self.agent_args.make_agent()
402+
if hasattr(agent, "set_task_name"):
403+
agent.set_task_name(self.env_args.task_name)
404+
402405
logger.debug("Agent created.")
403406

404407
env = self.env_args.make_env(

0 commit comments

Comments
 (0)