Skip to content

Commit f112e15

Browse files
committed
enhance the action overlay mechanism
1 parent dce2633 commit f112e15

File tree

4 files changed

+535
-172
lines changed

4 files changed

+535
-172
lines changed

src/agentlab/agents/agent_utils.py

Lines changed: 0 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,6 @@
1-
from logging import warning
2-
from typing import Optional, Tuple
3-
4-
import numpy as np
51
from PIL import Image, ImageDraw
62
from playwright.sync_api import Page
73

8-
"""
9-
This module contains utility functions for handling observations and actions in the context of agent interactions.
10-
"""
11-
12-
13-
def tag_screenshot_with_action(screenshot: Image, action: str) -> Image:
14-
"""
15-
If action is a coordinate action, try to render it on the screenshot.
16-
17-
e.g. mouse_click(120, 130) -> draw a dot at (120, 130) on the screenshot
18-
19-
Args:
20-
screenshot: The screenshot to tag.
21-
action: The action to tag the screenshot with.
22-
23-
Returns:
24-
The tagged screenshot.
25-
26-
Raises:
27-
ValueError: If the action parsing fails.
28-
"""
29-
if action.startswith("mouse_click"):
30-
try:
31-
coords = action[action.index("(") + 1 : action.index(")")].split(",")
32-
coords = [c.strip() for c in coords]
33-
if len(coords) not in [2, 3]:
34-
raise ValueError(f"Invalid coordinate format: {coords}")
35-
if coords[0].startswith("x="):
36-
coords[0] = coords[0][2:]
37-
if coords[1].startswith("y="):
38-
coords[1] = coords[1][2:]
39-
x, y = float(coords[0].strip()), float(coords[1].strip())
40-
draw = ImageDraw.Draw(screenshot)
41-
radius = 5
42-
draw.ellipse(
43-
(x - radius, y - radius, x + radius, y + radius), fill="blue", outline="blue"
44-
)
45-
except (ValueError, IndexError) as e:
46-
warning(f"Failed to parse action '{action}': {e}")
47-
48-
elif action.startswith("mouse_drag_and_drop"):
49-
try:
50-
func_name, parsed_args = parse_func_call_string(action)
51-
if func_name == "mouse_drag_and_drop" and parsed_args is not None:
52-
args, kwargs = parsed_args
53-
x1, y1, x2, y2 = None, None, None, None
54-
55-
if args and len(args) >= 4:
56-
# Positional arguments: mouse_drag_and_drop(x1, y1, x2, y2)
57-
x1, y1, x2, y2 = map(float, args[:4])
58-
elif kwargs:
59-
# Keyword arguments: mouse_drag_and_drop(from_x=x1, from_y=y1, to_x=x2, to_y=y2)
60-
x1 = float(kwargs.get("from_x", 0))
61-
y1 = float(kwargs.get("from_y", 0))
62-
x2 = float(kwargs.get("to_x", 0))
63-
y2 = float(kwargs.get("to_y", 0))
64-
65-
if all(coord is not None for coord in [x1, y1, x2, y2]):
66-
draw = ImageDraw.Draw(screenshot)
67-
# Draw the main line
68-
draw.line((x1, y1, x2, y2), fill="red", width=2)
69-
# Draw arrowhead at the end point using the helper function
70-
draw_arrowhead(draw, (x1, y1), (x2, y2))
71-
except (ValueError, IndexError) as e:
72-
warning(f"Failed to parse action '{action}': {e}")
73-
return screenshot
74-
75-
76-
def add_mouse_pointer_from_action(screenshot: Image, action: str) -> Image.Image:
77-
78-
if action.startswith("mouse_click"):
79-
try:
80-
coords = action[action.index("(") + 1 : action.index(")")].split(",")
81-
coords = [c.strip() for c in coords]
82-
if len(coords) not in [2, 3]:
83-
raise ValueError(f"Invalid coordinate format: {coords}")
84-
if coords[0].startswith("x="):
85-
coords[0] = coords[0][2:]
86-
if coords[1].startswith("y="):
87-
coords[1] = coords[1][2:]
88-
x, y = int(coords[0].strip()), int(coords[1].strip())
89-
screenshot = draw_mouse_pointer(screenshot, x, y)
90-
except (ValueError, IndexError) as e:
91-
warning(f"Failed to parse action '{action}': {e}")
92-
return screenshot
93-
944

955
def draw_mouse_pointer(image: Image.Image, x: int, y: int) -> Image.Image:
966
"""
@@ -218,50 +128,3 @@ def zoom_webpage(page: Page, zoom_factor: float = 1.5):
218128

219129
page.evaluate(f"document.documentElement.style.zoom='{zoom_factor*100}%'")
220130
return page
221-
222-
223-
def parse_func_call_string(call_str: str) -> Tuple[Optional[str], Optional[Tuple[list, dict]]]:
224-
"""
225-
Parse a function call string and extract the function name and arguments.
226-
227-
Args:
228-
call_str (str): A string like "mouse_click(100, 200)" or "mouse_drag_and_drop(x=10, y=20)"
229-
230-
Returns:
231-
Tuple (func_name, (args, kwargs)), or (None, None) if parsing fails
232-
"""
233-
import ast
234-
235-
try:
236-
tree = ast.parse(call_str.strip(), mode="eval")
237-
if not isinstance(tree.body, ast.Call):
238-
return None, None
239-
240-
call_node = tree.body
241-
242-
# Function name
243-
if isinstance(call_node.func, ast.Name):
244-
func_name = call_node.func.id
245-
else:
246-
return None, None
247-
248-
# Positional arguments
249-
args = []
250-
for arg in call_node.args:
251-
try:
252-
args.append(ast.literal_eval(arg))
253-
except (ValueError, TypeError):
254-
return None, None
255-
256-
# Keyword arguments
257-
kwargs = {}
258-
for kw in call_node.keywords:
259-
try:
260-
kwargs[kw.arg] = ast.literal_eval(kw.value)
261-
except (ValueError, TypeError):
262-
return None, None
263-
264-
return func_name, (args, kwargs)
265-
266-
except (SyntaxError, ValueError, TypeError):
267-
return None, None

src/agentlab/analyze/agent_xray.py

Lines changed: 60 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
from attr import dataclass
1515
from langchain.schema import BaseMessage, HumanMessage
1616
from openai import OpenAI
17+
from openai.types.responses import ResponseFunctionToolCall
1718
from PIL import Image
1819

19-
from agentlab.agents import agent_utils
2020
from agentlab.analyze import inspect_results
21+
from agentlab.analyze.overlay_utils import annotate_action
2122
from agentlab.experiments.exp_utils import RESULTS_DIR
2223
from agentlab.experiments.loop import ExpResult, StepInfo
2324
from agentlab.experiments.study import get_most_recent_study
@@ -351,7 +352,7 @@ def run_gradio(results_dir: Path):
351352
pruned_html_code = gr.Code(language="html", **code_args)
352353

353354
with gr.Tab("AXTree") as tab_axtree:
354-
axtree_code = gr.Code(language=None, **code_args)
355+
axtree_code = gr.Markdown()
355356

356357
with gr.Tab("Chat Messages") as tab_chat:
357358
chat_messages = gr.Markdown()
@@ -536,38 +537,46 @@ def wrapper(*args, **kwargs):
536537

537538
def update_screenshot(som_or_not: str):
538539
global info
539-
action = info.exp_result.steps_info[info.step].action
540-
return agent_utils.tag_screenshot_with_action(
541-
get_screenshot(info, som_or_not=som_or_not), action
542-
)
540+
img, action_str = get_screenshot(info, som_or_not=som_or_not, annotate=True)
541+
return img
543542

544543

545-
def get_screenshot(info: Info, step: int = None, som_or_not: str = "Raw Screenshots"):
544+
def get_screenshot(
545+
info: Info, step: int = None, som_or_not: str = "Raw Screenshots", annotate: bool = False
546+
):
546547
if step is None:
547548
step = info.step
549+
step_info = info.exp_result.steps_info[step]
548550
try:
549551
is_som = som_or_not == "SOM Screenshots"
550-
return info.exp_result.get_screenshot(step, som=is_som)
552+
img = info.exp_result.get_screenshot(step, som=is_som)
553+
if annotate:
554+
action_str = step_info.action
555+
properties = step_info.obs.get("extra_element_properties", None)
556+
action_colored = annotate_action(img, action_string=action_str, properties=properties)
557+
else:
558+
action_colored = None
559+
return img, action_colored
551560
except FileNotFoundError:
552-
return None
561+
return None, None
553562

554563

555564
def update_screenshot_pair(som_or_not: str):
556565
global info
557-
s1 = get_screenshot(info, info.step, som_or_not)
558-
s2 = get_screenshot(info, info.step + 1, som_or_not)
559-
560-
if s1 is not None:
561-
s1 = agent_utils.tag_screenshot_with_action(
562-
s1, info.exp_result.steps_info[info.step].action
563-
)
566+
s1, action_str = get_screenshot(info, info.step, som_or_not, annotate=True)
567+
s2, action_str = get_screenshot(info, info.step + 1, som_or_not)
564568
return s1, s2
565569

566570

567571
def update_screenshot_gallery(som_or_not: str):
568572
global info
569-
screenshots = info.exp_result.get_screenshots(som=som_or_not == "SOM Screenshots")
573+
max_steps = len(info.exp_result.steps_info)
574+
som_or_not == "SOM Screenshots"
575+
576+
screenshots = [get_screenshot(info, step=i, som_or_not=som_or_not)[0] for i in range(max_steps)]
577+
570578
screenshots_and_label = [(s, f"Step {i}") for i, s in enumerate(screenshots)]
579+
571580
gallery = gr.Gallery(
572581
value=screenshots_and_label,
573582
columns=2,
@@ -595,7 +604,8 @@ def update_pruned_html():
595604

596605

597606
def update_axtree():
598-
return get_obs(key="axtree_txt", default="No AXTree")
607+
obs = get_obs(key="axtree_txt", default="No AXTree")
608+
return f"```\n{obs}\n```"
599609

600610

601611
def dict_to_markdown(d: dict):
@@ -645,7 +655,7 @@ def dict_msg_to_markdown(d: dict):
645655
case "text":
646656
parts.append(f"\n```\n{item['text']}\n```\n")
647657
case "tool_use":
648-
tool_use = f"Tool Use: {item['name']} {item['input']} (id = {item['id']})"
658+
tool_use = _format_tool_call(item["name"], item["input"], item["call_id"])
649659
parts.append(f"\n```\n{tool_use}\n```\n")
650660
case _:
651661
parts.append(f"\n```\n{str(item)}\n```\n")
@@ -655,27 +665,40 @@ def dict_msg_to_markdown(d: dict):
655665
return markdown
656666

657667

668+
def _format_tool_call(name: str, input: str, call_id: str):
669+
"""
670+
Format a tool call to markdown.
671+
"""
672+
return f"Tool Call: {name} `{input}` (call_id: {call_id})"
673+
674+
675+
def format_chat_message(message: BaseMessage | MessageBuilder | dict):
676+
"""
677+
Format a message to markdown.
678+
"""
679+
if isinstance(message, BaseMessage):
680+
return message.content
681+
elif isinstance(message, MessageBuilder):
682+
return message.to_markdown()
683+
elif isinstance(message, dict):
684+
return dict_msg_to_markdown(message)
685+
elif isinstance(message, ResponseFunctionToolCall): # type: ignore[return]
686+
too_use_str = _format_tool_call(message.name, message.arguments, message.call_id)
687+
return f"### Tool Use\n```\n{too_use_str}\n```\n"
688+
else:
689+
return str(message)
690+
691+
658692
def update_chat_messages():
659693
global info
660694
agent_info = info.exp_result.steps_info[info.step].agent_info
661695
chat_messages = agent_info.get("chat_messages", ["No Chat Messages"])
662696
if isinstance(chat_messages, Discussion):
663697
return chat_messages.to_markdown()
664698

665-
if isinstance(chat_messages, list) and isinstance(chat_messages[0], MessageBuilder):
666-
chat_messages = [
667-
m.to_markdown() if isinstance(m, MessageBuilder) else dict_msg_to_markdown(m)
668-
for m in chat_messages
669-
]
699+
if isinstance(chat_messages, list):
700+
chat_messages = [format_chat_message(m) for m in chat_messages]
670701
return "\n\n".join(chat_messages)
671-
messages = [] # TODO(ThibaultLSDC) remove this at some point
672-
for i, m in enumerate(chat_messages):
673-
if isinstance(m, BaseMessage): # TODO remove once langchain is deprecated
674-
m = m.content
675-
elif isinstance(m, dict):
676-
m = m.get("content", "No Content")
677-
messages.append(f"""# Message {i}\n```\n{m}\n```\n\n""")
678-
return "\n".join(messages)
679702

680703

681704
def update_task_error():
@@ -722,8 +745,8 @@ def update_agent_info_html():
722745
global info
723746
# screenshots from current and next step
724747
try:
725-
s1 = get_screenshot(info, info.step, False)
726-
s2 = get_screenshot(info, info.step + 1, False)
748+
s1, action_str = get_screenshot(info, info.step, False)
749+
s2, action_str = get_screenshot(info, info.step + 1, False)
727750
agent_info = info.exp_result.steps_info[info.step].agent_info
728751
page = agent_info.get("html_page", ["No Agent Info"])
729752
if page is None:
@@ -854,6 +877,8 @@ def get_episode_info(info: Info):
854877

855878
def get_action_info(info: Info):
856879
steps_info = info.exp_result.steps_info
880+
img, action_str = get_screenshot(info, step=info.step, annotate=True) # to update click_mapper
881+
857882
if len(steps_info) == 0:
858883
return "No steps were taken"
859884
if len(steps_info) <= info.step:
@@ -863,7 +888,7 @@ def get_action_info(info: Info):
863888
action_info = f"""\
864889
**Action:**
865890
866-
{code(step_info.action)}
891+
{action_str}
867892
"""
868893
think = step_info.agent_info.get("think", None)
869894
if think is not None:

0 commit comments

Comments
 (0)