Skip to content

Commit a16f024

Browse files
Merge remote-tracking branch 'origin/allac/next-agent' into aj/tool_use_agent_chat_completion_support
2 parents d70433f + 60eed9e commit a16f024

File tree

2 files changed

+154
-48
lines changed

2 files changed

+154
-48
lines changed

src/agentlab/agents/agent_utils.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
from logging import warning
2-
from playwright.sync_api import Page
32

43
from PIL import Image, ImageDraw
5-
from logging import warning
6-
from playwright.sync_api import Page
7-
4+
from playwright.sync_api import Page
85

96
"""
107
This module contains utility functions for handling observations and actions in the context of agent interactions.
@@ -87,6 +84,57 @@ def draw_mouse_pointer(image: Image.Image, x: int, y: int) -> Image.Image:
8784

8885
return Image.alpha_composite(image.convert("RGBA"), overlay)
8986

87+
88+
def draw_click_indicator(image: Image.Image, x: int, y: int) -> Image.Image:
    """
    Render a crosshair-style click marker centred on (x, y).

    The marker is a "+" made of four separate arms that stop short of the centre,
    leaving the clicked point itself uncovered. Drawing happens on an RGBA copy,
    and a new composited image is returned.
    """
    arm = 10  # length of each arm
    offset = 4  # distance between the centre point and the start of each arm
    thickness = 2  # width of the black core of each arm

    base = image.convert("RGBA")
    overlay = base.copy()
    pen = ImageDraw.Draw(overlay)

    # Endpoints of the four arms, in drawing order: top, bottom, left, right.
    arms = (
        [(x, y - offset - arm), (x, y - offset)],
        [(x, y + offset), (x, y + offset + arm)],
        [(x - offset - arm, y), (x - offset, y)],
        [(x + offset, y), (x + offset + arm, y)],
    )

    for segment in arms:
        # A wider white stroke goes down first and the black core on top of it,
        # so the marker stays visible on both light and dark backgrounds.
        pen.line(segment, fill=(255, 255, 255, 200), width=thickness + 2)
        pen.line(segment, fill=(0, 0, 0, 255), width=thickness)

    return Image.alpha_composite(base, overlay)
136+
137+
90138
def zoom_webpage(page: Page, zoom_factor: float = 1.5):
91139
"""
92140
Zooms the webpage to the specified zoom factor.

src/agentlab/agents/tool_use_agent/multi_tool_agent.py

Lines changed: 102 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
from abc import ABC, abstractmethod
44
from copy import copy
5-
from dataclasses import asdict, dataclass
5+
from dataclasses import asdict, dataclass, field
66
from pathlib import Path
77
from typing import Any
88

@@ -54,6 +54,55 @@ def apply(self, llm, messages: list[MessageBuilder], **kwargs):
5454
pass
5555

5656

57+
@dataclass
class MsgGroup:
    """A named run of messages, optionally paired with a summary message."""

    name: str = None  # label of the group; None when not provided
    messages: list[MessageBuilder] = field(default_factory=list)  # messages in insertion order
    summary: MessageBuilder = None  # emitted by StructuredDiscussion.flatten() in place of old groups


class StructuredDiscussion:
    """
    Conversation history organised as an ordered list of message groups.

    Messages are always appended to the most recent group. When the history is
    flattened, the last `keep_last_n_obs` groups are emitted in full; an older group
    that has a summary is replaced by that single summary message, and an older
    group without a summary is still emitted in full.
    """

    def __init__(self, keep_last_n_obs=None):
        """
        Args:
            keep_last_n_obs: number of trailing groups emitted in full by `flatten()`.
                None keeps every group in full; 0 keeps none (every group that has a
                summary is replaced by it).
        """
        self.groups: list[MsgGroup] = []
        self.keep_last_n_obs = keep_last_n_obs

    def _last_group(self) -> MsgGroup:
        """Return the most recent group, raising a descriptive IndexError if none exists."""
        if not self.groups:
            raise IndexError("StructuredDiscussion has no group yet; call new_group() first.")
        return self.groups[-1]

    def append(self, message: MessageBuilder):
        """Append a message to the last group.

        Raises:
            IndexError: if no group has been created yet.
        """
        self._last_group().messages.append(message)

    def new_group(self, name: str = None):
        """Start a new group of messages; unnamed groups are called ``group_<index>``."""
        if name is None:
            name = f"group_{len(self.groups)}"
        self.groups.append(MsgGroup(name))

    def flatten(self) -> list[MessageBuilder]:
        """Flatten the groups into a single list of messages.

        Returns:
            The messages in group order, with each summarised group outside the last
            `keep_last_n_obs` groups collapsed to its summary message.
        """
        n_groups = len(self.groups)
        # Explicit None check: `or` would treat 0 as "unset" and keep every group in full.
        keep_last_n_obs = n_groups if self.keep_last_n_obs is None else self.keep_last_n_obs
        first_tail_index = n_groups - keep_last_n_obs

        messages = []
        for i, group in enumerate(self.groups):
            is_tail = i >= first_tail_index
            logging.debug(
                "Processing group %d (%s), is_tail=%s, len(group)=%d",
                i,
                group.name,
                is_tail,
                len(group.messages),
            )
            if not is_tail and group.summary is not None:
                messages.append(group.summary)
            else:
                messages.extend(group.messages)

        return messages

    def set_last_summary(self, summary: MessageBuilder):
        """Attach `summary` to the last group, replacing any previous summary.

        Raises:
            IndexError: if no group has been created yet.
        """
        self._last_group().summary = summary

    def is_goal_set(self) -> bool:
        """Return True once at least one group exists.

        Only the existence of a group is checked, not its content.
        NOTE(review): callers appear to create the goal group first -- confirm.
        """
        return len(self.groups) > 0
104+
105+
57106
# @dataclass
58107
# class BlockArgs(ABC):
59108

@@ -74,9 +123,9 @@ class Goal(Block):
74123

75124
goal_as_system_msg: bool = True
76125

77-
def apply(self, llm, messages: list[MessageBuilder], obs: dict) -> dict:
126+
def apply(self, llm, discussion: StructuredDiscussion, obs: dict) -> dict:
78127
system_message = llm.msg.system().add_text(SYS_MSG)
79-
messages.append(system_message)
128+
discussion.append(system_message)
80129

81130
if self.goal_as_system_msg:
82131
goal_message = llm.msg.system()
@@ -89,7 +138,7 @@ def apply(self, llm, messages: list[MessageBuilder], obs: dict) -> dict:
89138
goal_message.add_text(content["text"])
90139
elif content["type"] == "image_url":
91140
goal_message.add_image(content["image_url"])
92-
messages.append(goal_message)
141+
discussion.append(goal_message)
93142

94143

95144
AXTREE_NOTE = """
@@ -108,11 +157,11 @@ class Obs(Block):
108157
use_dom: bool = False
109158
use_som: bool = False
110159
use_tabs: bool = False
111-
add_mouse_pointer: bool = True
160+
add_mouse_pointer: bool = False
112161
use_zoomed_webpage: bool = False
113162

114163
def apply(
115-
self, llm, messages: list[MessageBuilder], obs: dict, last_llm_output: LLMOutput
164+
self, llm, discussion: StructuredDiscussion, obs: dict, last_llm_output: LLMOutput
116165
) -> dict:
117166

118167
if last_llm_output.tool_calls is None:
@@ -147,7 +196,7 @@ def apply(
147196
if self.use_tabs:
148197
obs_msg.add_text(_format_tabs(obs))
149198

150-
messages.append(obs_msg)
199+
discussion.append(obs_msg)
151200
return obs_msg
152201

153202

@@ -172,7 +221,7 @@ class GeneralHints(Block):
172221

173222
use_hints: bool = True
174223

175-
def apply(self, llm, messages: list[MessageBuilder]) -> dict:
224+
def apply(self, llm, discussion: StructuredDiscussion) -> dict:
176225
if not self.use_hints:
177226
return
178227

@@ -182,7 +231,7 @@ def apply(self, llm, messages: list[MessageBuilder]) -> dict:
182231
"""Use ControlOrMeta instead of Control and Meta for keyboard shortcuts, to be cross-platform compatible. E.g. use ControlOrMeta for mutliple selection in lists.\n"""
183232
)
184233

185-
messages.append(llm.msg.user().add_text("\n".join(hints)))
234+
discussion.append(llm.msg.user().add_text("\n".join(hints)))
186235

187236

188237
@dataclass
@@ -192,20 +241,22 @@ class Summarizer(Block):
192241
do_summary: bool = False
193242
high_details: bool = True
194243

195-
def apply(self, llm, messages: list[MessageBuilder]) -> dict:
244+
def apply(self, llm, discussion: StructuredDiscussion) -> dict:
196245
if not self.do_summary:
197246
return
198247

199248
msg = llm.msg.user().add_text("""Summarize\n""")
200249

201-
messages.append(msg)
250+
discussion.append(msg)
202251
# TODO need to make sure we don't force tool use here
203-
summary_response = llm(messages=messages, tool_choice="none")
252+
summary_response = llm(messages=discussion.flatten(), tool_choice="none")
204253

205254
summary_msg = llm.msg.assistant().add_text(summary_response.think)
206-
messages.append(summary_msg)
255+
discussion.append(summary_msg)
256+
discussion.set_last_summary(summary_msg)
257+
return summary_msg
207258

208-
def apply_init(self, llm, messages: list[MessageBuilder]) -> dict:
259+
def apply_init(self, llm, discussion: StructuredDiscussion) -> dict:
209260
"""Initialize the summarizer block."""
210261
if not self.do_summary:
211262
return
@@ -215,19 +266,18 @@ def apply_init(self, llm, messages: list[MessageBuilder]) -> dict:
215266
# Add a system message to the LLM to indicate that it should summarize
216267
system_msg.add_text(
217268
"""# Summarizer instructions:\nWhen asked to summarize, do the following:
218-
1) Summarize the effect of the last action, with attention to details.
219-
2) Give a semantic description of the current state of the environment, with attention to details. If there was a repeating mistake, mention the cause of it.
220-
3) Reason about the overall task at a high level.
221-
4) What hint can be relevant for the next action? Only chose from the hints provided in the task description. Or select none.
222-
5) What is the currently activated item if any.
223-
6) Reason about the next action to take, based on the current state and the goal.
224-
"""
269+
1) Summarize the effect of the last action, with attention to details.
270+
2) Give a semantic description of the current state of the environment, with attention to details. If there was a repeating mistake, mention the cause of it.
271+
3) Reason about the overall task at a high level.
272+
4) What hint can be relevant for the next action? Only chose from the hints provided in the task description. Or select none.
273+
5) Reason about the next action to take, based on the current state and the goal.
274+
"""
225275
)
226276
else:
227277
system_msg.add_text(
228278
"""When asked to summarize, give a semantic description of the current state of the environment."""
229279
)
230-
messages.append(system_msg)
280+
discussion.append(system_msg)
231281

232282

233283
@dataclass
@@ -243,7 +293,7 @@ def _init(self):
243293
# index the task_name for fast lookup
244294
# self.hint_db.set_index("task_name", inplace=True, drop=False)
245295

246-
def apply(self, llm, messages: list[MessageBuilder], task_name: str) -> dict:
296+
def apply(self, llm, discussion: StructuredDiscussion, task_name: str) -> dict:
247297
if not self.use_task_hint:
248298
return
249299

@@ -257,12 +307,14 @@ def apply(self, llm, messages: list[MessageBuilder], task_name: str) -> dict:
257307
if hint:
258308
hints.append(f"- {hint}")
259309

260-
hints_str = "# Hints:\nHere are some hints for the task you are working on:\n" + "\n".join(
261-
hints
262-
)
263-
msg = llm.msg.user().add_text(hints_str)
310+
if len(hints) > 0:
311+
hints_str = (
312+
"# Hints:\nHere are some hints for the task you are working on:\n"
313+
+ "\n".join(hints)
314+
)
315+
msg = llm.msg.user().add_text(hints_str)
264316

265-
messages.append(msg)
317+
discussion.append(msg)
266318

267319

268320
class ToolCall(Block):
@@ -292,6 +344,7 @@ class PromptConfig:
292344
summarizer: Summarizer = None
293345
general_hints: GeneralHints = None
294346
task_hint: TaskHint = None
347+
keep_last_n_obs: int = 2
295348

296349

297350
@dataclass
@@ -339,8 +392,9 @@ def __init__(
339392
self.llm.msg = self.msg_builder
340393

341394
self.task_hint = self.config.task_hint.make()
395+
self.obs_block = self.config.obs.make()
342396

343-
self.messages: list[MessageBuilder] = []
397+
self.discussion = StructuredDiscussion(self.config.keep_last_n_obs)
344398
self.last_response: LLMOutput = LLMOutput()
345399
self._responses: list[LLMOutput] = []
346400

@@ -380,37 +434,41 @@ def set_task_name(self, task_name: str):
380434
@cost_tracker_decorator
381435
def get_action(self, obs: Any) -> float:
382436
self.llm.reset_stats()
383-
if len(self.messages) == 0:
384-
self.config.goal.apply(self.llm, self.messages, obs)
385-
self.config.summarizer.apply_init(self.llm, self.messages)
386-
self.config.general_hints.apply(self.llm, self.messages)
387-
self.task_hint.apply(self.llm, self.messages, self.task_name)
388-
389-
logging.info("Appending observation to messages")
390-
self.config.obs.apply(self.llm, self.messages, obs, last_llm_output=self.last_response)
391-
logging.info("Calling summarizer")
392-
self.config.summarizer.apply(self.llm, self.messages)
393-
logging.info("Main tool calling")
437+
if not self.discussion.is_goal_set():
438+
self.discussion.new_group("goal")
439+
self.config.goal.apply(self.llm, self.discussion, obs)
440+
self.config.summarizer.apply_init(self.llm, self.discussion)
441+
self.config.general_hints.apply(self.llm, self.discussion)
442+
self.task_hint.apply(self.llm, self.discussion, self.task_name)
443+
444+
self.discussion.new_group()
445+
446+
self.obs_block.apply(self.llm, self.discussion, obs, last_llm_output=self.last_response)
447+
print("flatten for summary")
448+
449+
self.config.summarizer.apply(self.llm, self.discussion)
450+
451+
messages = self.discussion.flatten()
394452
response: LLMOutput = self.llm(
395-
messages=self.messages,
453+
messages=messages,
396454
tool_choice="any",
397455
cache_tool_definition=True,
398456
cache_complete_prompt=True,
399457
)
400-
logging.info(f"Obtained response {response}")
401458

402459
action = response.action
403460
think = response.think
404461

405-
self.messages.append(response.tool_calls)
462+
self.discussion.new_group()
463+
self.discussion.append(response.tool_calls)
406464

407465
self.last_response = response
408466
self._responses.append(response) # may be useful for debugging
409467
# self.messages.append(response.assistant_message) # this is tool call
410468

411469
agent_info = bgym.AgentInfo(
412470
think=think,
413-
chat_messages=self.messages,
471+
chat_messages=messages,
414472
stats=self.llm.stats.stats_dict,
415473
)
416474
return action, agent_info

0 commit comments

Comments
 (0)