Skip to content

Commit 7f4c018

Browse files
committed
Merge branch 'aj/tool_use_agent_chat_completion_support' into allac/next-agent
2 parents b813df7 + 459afad commit 7f4c018

File tree

4 files changed

+84
-5
lines changed

4 files changed

+84
-5
lines changed

src/agentlab/agents/agent_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from logging import warning
2+
from typing import Optional, Tuple
23

4+
import numpy as np
35
from PIL import Image, ImageDraw
46
from playwright.sync_api import Page
57

@@ -42,10 +44,37 @@ def tag_screenshot_with_action(screenshot: Image, action: str) -> Image:
4244
)
4345
except (ValueError, IndexError) as e:
4446
warning(f"Failed to parse action '{action}': {e}")
47+
48+
elif action.startswith("mouse_drag_and_drop"):
49+
try:
50+
func_name, parsed_args = parse_func_call_string(action)
51+
if func_name == "mouse_drag_and_drop" and parsed_args is not None:
52+
args, kwargs = parsed_args
53+
x1, y1, x2, y2 = None, None, None, None
54+
55+
if args and len(args) >= 4:
56+
# Positional arguments: mouse_drag_and_drop(x1, y1, x2, y2)
57+
x1, y1, x2, y2 = map(float, args[:4])
58+
elif kwargs:
59+
# Keyword arguments: mouse_drag_and_drop(from_x=x1, from_y=y1, to_x=x2, to_y=y2)
60+
x1 = float(kwargs.get("from_x", 0))
61+
y1 = float(kwargs.get("from_y", 0))
62+
x2 = float(kwargs.get("to_x", 0))
63+
y2 = float(kwargs.get("to_y", 0))
64+
65+
if all(coord is not None for coord in [x1, y1, x2, y2]):
66+
draw = ImageDraw.Draw(screenshot)
67+
# Draw the main line
68+
draw.line((x1, y1, x2, y2), fill="red", width=2)
69+
# Draw arrowhead at the end point using the helper function
70+
draw_arrowhead(draw, (x1, y1), (x2, y2))
71+
except (ValueError, IndexError) as e:
72+
warning(f"Failed to parse action '{action}': {e}")
4573
return screenshot
4674

4775

4876
def add_mouse_pointer_from_action(screenshot: Image, action: str) -> Image.Image:
77+
4978
if action.startswith("mouse_click"):
5079
try:
5180
coords = action[action.index("(") + 1 : action.index(")")].split(",")
@@ -85,6 +114,23 @@ def draw_mouse_pointer(image: Image.Image, x: int, y: int) -> Image.Image:
85114
return Image.alpha_composite(image.convert("RGBA"), overlay)
86115

87116

117+
def draw_arrowhead(draw, start, end, arrow_length=15, arrow_angle=30):
118+
from math import atan2, cos, radians, sin
119+
120+
angle = atan2(end[1] - start[1], end[0] - start[0])
121+
left = (
122+
end[0] - arrow_length * cos(angle - radians(arrow_angle)),
123+
end[1] - arrow_length * sin(angle - radians(arrow_angle)),
124+
)
125+
right = (
126+
end[0] - arrow_length * cos(angle + radians(arrow_angle)),
127+
end[1] - arrow_length * sin(angle + radians(arrow_angle)),
128+
)
129+
draw.line([end, left], fill="red", width=4)
130+
draw.line([end, right], fill="red", width=4)
131+
132+
133+
88134
def draw_click_indicator(image: Image.Image, x: int, y: int) -> Image.Image:
89135
"""
90136
Draws a click indicator (+ shape with disconnected lines) at (x, y) on the image.

src/agentlab/agents/tool_use_agent/hint_db.csv

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,12 @@ June 4,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-
44
June 4,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Filling up form,GUI elements surrounded by a red rectangle often means there is an error in the content
55
June 4,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Search results,The scroll bar indicates that there is more than 1 flights available in the search. Make sure to select the one matching the task goal among all possible flights.
66
June 7,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Filling up form,"If you suspect an error in an element of a form, ""ControlOrMeta+a"" to select all and overwrite the content"
7-
June 9,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Filling up form,"The main mistake is to type in the wrong field. Make sure the correct field is activated by clicking into it and seeing it activated with a blue rectangle before proceeding with writing. Otherwise, it will append to the currently activated field."
7+
June 9,miniwob.book-flight,2,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,allac,Filling up form,"The main mistake is to type in the wrong field. Make sure the correct field is activated by clicking into it and seeing it activated with a blue rectangle before proceeding with writing. Otherwise, it will append to the currently activated field."
8+
June 11,miniwob.drag-items-grid,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"For dragging tasks, use only mouse_down, move_mouse, and mouse_up. Do not perform a single continuous drag. Instead, move the mouse in steps using move_mouse, and then release with mouse_up when the target is reached."
9+
June 11,miniwob.drag-items-grid,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"Avoid dragging items outside the grid. Ensure the drop coordinates are within valid grid boundaries, or the task will fail."
10+
June 11,miniwob.drag-items-grid,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"Grid items are reactive: dragging an item over another may cause them to swap. Only perform mouse_down when the dragged item is correctly positioned. Intermediate movement helps trigger these reactive behaviors reliably."
11+
June 11,miniwob.drag-items-grid,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"Use intermediate positions when moving items. This often leads to more stable behavior than dragging directly to the final target."
12+
June 11,miniwob.drag-items-grid,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"Hovering near a target item can prompt it to move to an adjacent grid cell. Use this behavior to clear the destination before dropping the dragged item"
13+
June 11,miniwob.drag-items-grid,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"Move items along slightly curved instead of straight lines to better imitate human-like dragging behavior."
14+
June 11,miniwob.drag-items,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"Move items along slightly curved instead of straight lines to better imitate human-like dragging behavior."
15+
June 11,miniwob.drag-items,30,claude-3-7-sonnet-20250219,MultiToolUse-claude-3-7-sonnet-20250219,miniwob,miniwob,aman,Dragging items,"For dragging tasks, use only mouse_down, move_mouse, and mouse_up. Do not perform a single continuous drag. Instead, move the mouse in steps using move_mouse, and then release with mouse_up when the target is reached."

src/agentlab/agents/tool_use_agent/multi_tool_agent.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class StructuredDiscussion:
7272

7373
def __init__(self, keep_last_n_obs=None):
7474
self.groups: list[MsgGroup] = []
75-
self.keep_last_n_obs = keep_last_n_obs
75+
self.keep_last_n_obs: int | None = keep_last_n_obs
7676

7777
def append(self, message: MessageBuilder):
7878
"""Append a message to the last group."""
@@ -96,7 +96,9 @@ def flatten(self) -> list[MessageBuilder]:
9696
messages.append(group.summary)
9797
else:
9898
messages.extend(group.messages)
99-
99+
# Mark all summarized messages for caching
100+
if i == len(self.groups) - keep_last_n_obs:
101+
messages[i].mark_all_previous_msg_for_caching()
100102
return messages
101103

102104
def set_last_summary(self, summary: MessageBuilder):
@@ -452,7 +454,8 @@ def get_action(self, obs: Any) -> float:
452454
messages=messages,
453455
tool_choice="any",
454456
cache_tool_definition=True,
455-
cache_complete_prompt=True,
457+
cache_complete_prompt=False,
458+
use_cache_breakpoints=True,
456459
)
457460

458461
action = response.action

src/agentlab/llm/response_api.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ def add_image_url(self, image_url: str) -> "MessageBuilder":
107107
self.content.append({"image": image_to_png_base64_url(image_url)})
108108
return self
109109

110+
def mark_all_previous_msg_for_caching(self):
111+
"""Insert a cache breakpoint in the message content."""
112+
# This is a placeholder for future implementation.
113+
raise NotImplementedError
114+
110115

111116
# TODO: Support parallel tool calls.
112117

@@ -216,6 +221,10 @@ def transform_content(self, content: ContentItem) -> ContentItem:
216221
else:
217222
raise ValueError(f"Unsupported content type: {content}")
218223

224+
def mark_all_previous_msg_for_caching(self) -> List[Message]:
225+
"""Insert a cache breakpoint in the message content to mark all previous messages for caching."""
226+
self._cache_breakpoint = True
227+
219228

220229
class OpenAIChatCompletionAPIMessageBuilder(MessageBuilder):
221230

@@ -521,7 +530,10 @@ def _call_api(
521530
sys_msg, other_msgs = self.filter_system_messages(messages)
522531
sys_msg_text = "\n".join(c["text"] for m in sys_msg for c in m.content)
523532
for msg in other_msgs:
524-
input.extend(msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg])
533+
temp = msg.prepare_message() if isinstance(msg, MessageBuilder) else [msg]
534+
if kwargs.pop("use_cache_breakpoints", False):
535+
temp = self.apply_cache_breakpoints(msg, temp)
536+
input.extend(temp)
525537

526538
api_params: Dict[str, Any] = {
527539
"model": self.model_name,
@@ -581,6 +593,16 @@ def _parse_response(self, response: dict) -> dict:
581593
result.think += output.text
582594
return result
583595

596+
# def ensure_cache_conditions(self, msgs: List[Message]) -> bool:
597+
# """Ensure API specific cache conditions are met."""
598+
# assert sum(getattr(msg, "_cache_breakpoint", 0) for msg in msgs) <= 4, "Too many cache breakpoints in the message."
599+
600+
def apply_cache_breakpoints(self, msg: Message, prepared_msg: dict) -> List[Message]:
601+
"""Apply cache breakpoints to the messages."""
602+
if getattr(msg, "_cache_breakpoint", False):
603+
prepared_msg[-1]["content"][-1]["cache_control"] = {"type": "ephemeral"}
604+
return prepared_msg
605+
584606

585607
def cua_response_to_text(action):
586608
"""

0 commit comments

Comments
 (0)