Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion browsergym/core/src/browsergym/core/env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import copy
import logging
import re
Expand All @@ -19,6 +20,7 @@
MarkingError,
_post_extract,
_pre_extract,
add_mouse_pointer_to_screenshot,
extract_dom_extra_properties,
extract_dom_snapshot,
extract_focused_element_bid,
Expand Down Expand Up @@ -77,6 +79,7 @@ def __init__(
# agent-related arguments
action_mapping: Optional[callable] = HighLevelActionSet().to_python_code,
use_raw_page_output: bool = False,
show_mouse_pointer: bool = True,
):
"""
Instantiate a ready to use BrowserEnv gym environment.
Expand All @@ -98,6 +101,7 @@ def __init__(
pw_context_kwargs: extra parameters for the playwright BrowserContext. Should only be used for debugging/testing.
action_mapping: if set, the environment will use this function to map every received action to executable Python code.
use_raw_page_output: if set, the environment will use the raw page output instead of the default processing.
show_mouse_pointer: if set to True (default), the environment will add mouse pointer visualization to screenshots for click actions. Set to False to disable.

"""
super().__init__()
Expand All @@ -118,6 +122,7 @@ def __init__(
self.pw_context_kwargs = pw_context_kwargs
self.action_mapping = action_mapping
self.use_raw_page_output = use_raw_page_output
self.show_mouse_pointer = show_mouse_pointer

# check argument values
assert tags_to_mark in ("all", "standard_html")
Expand Down Expand Up @@ -664,7 +669,13 @@ def _get_obs(self):
"open_pages_titles": tuple(page.title() for page in self.context.pages),
"active_page_index": np.asarray([self.context.pages.index(self.page)]),
"url": self.page.url, # redundant with "open_pages_urls" and "active_page_index"
"screenshot": extract_screenshot(self.page),
"screenshot": (
extract_screenshot(self.page)
if not self.show_mouse_pointer
else add_mouse_pointer_to_screenshot(
extract_screenshot(self.page), self.last_action
)
),
"dom_object": dom,
"axtree_object": axtree,
"extra_element_properties": extra_properties,
Expand Down
136 changes: 135 additions & 1 deletion browsergym/core/src/browsergym/core/observation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import ast
import base64
import io
import logging
import pkgutil
import re
from typing import Literal
from typing import Literal, Optional, Tuple

import numpy as np
import PIL.Image
import playwright.sync_api
from PIL import Image, ImageDraw

from .constants import BROWSERGYM_ID_ATTRIBUTE as BID_ATTR
from .constants import BROWSERGYM_SETOFMARKS_ATTRIBUTE as SOM_ATTR
Expand Down Expand Up @@ -573,3 +575,135 @@ def extract_focused_element_bid(page: playwright.sync_api.Page):
focused_bid = ""

return focused_bid


def parse_func_call_string(call_str: str) -> Tuple[Optional[str], Optional[Tuple[list, dict]]]:
"""
Parse a function call string and extract the function name and arguments.

Args:
call_str (str): A string like "mouse_click(100, 200)" or "mouse_drag_and_drop(x=10, y=20)"

Returns:
Tuple (func_name, (args, kwargs)), or (None, None) if parsing fails
"""
try:
tree = ast.parse(call_str.strip(), mode="eval")
if not isinstance(tree.body, ast.Call):
return None, None

call_node = tree.body

# Function name
if isinstance(call_node.func, ast.Name):
func_name = call_node.func.id
else:
return None, None

# Positional arguments
args = []
for arg in call_node.args:
try:
args.append(ast.literal_eval(arg))
except (ValueError, TypeError):
return None, None

# Keyword arguments
kwargs = {}
for kw in call_node.keywords:
try:
kwargs[kw.arg] = ast.literal_eval(kw.value)
except (ValueError, TypeError):
return None, None

return func_name, (args, kwargs)

except (SyntaxError, ValueError, TypeError):
return None, None


def try_extract_coords(source: dict | list, x_key, y_key) -> Optional[Tuple[int, int]]:
try:
x = int(float(source[x_key]))
y = int(float(source[y_key]))
return x, y
except (KeyError, IndexError, ValueError, TypeError):
return None


def extract_mouse_coords_from_action(action: str) -> tuple[int, int] | None:
"""
Extract mouse coordinates from a mouse action string.

Args:
action: A string like "mouse_click(100, 200)" or "click(x=100, y=200)"

Returns:
(x, y) tuple or None if extraction fails
"""
if not action or not isinstance(action, str):
return None

mouse_actions = {
"mouse_click",
"mouse_dblclick",
"mouse_down",
"mouse_up",
"mouse_move",
"scroll_at",
"mouse_upload_file",
}

func_name, parsed_args = parse_func_call_string(action)
if func_name is None or func_name not in mouse_actions or parsed_args is None:
return None

args, kwargs = parsed_args
if args:
# If there are positional arguments, assume the first two are x and y coordinates
if len(args) >= 2:
return try_extract_coords(args, 0, 1)
elif kwargs:
# If there are keyword arguments, look for 'x' and 'y'
return try_extract_coords(kwargs, "x", "y")


def add_mouse_pointer_to_screenshot(screenshot: Image.Image, action: str) -> Image.Image:
"""
Add mouse pointer visualization to screenshot if the action involves mouse clicks.

Args:
screenshot: The screenshot PIL Image
action: The action string to check for mouse coordinates

Returns:
Screenshot with mouse pointer overlay if applicable, otherwise original screenshot
"""
coords = extract_mouse_coords_from_action(action)
if not coords:
return screenshot # No mouse coordinates found, return original screenshot
else:
# Convert numpy array to PIL Image first
if isinstance(screenshot, np.ndarray):
pil_image = Image.fromarray(screenshot)
else:
pil_image = screenshot

x, y = coords
pointer_size = 20 # Length of the pointer
overlay = pil_image.convert("RGBA").copy()
draw = ImageDraw.Draw(overlay)

# Define pointer shape (a simple arrow)
pointer_shape = [
(x, y),
(x + pointer_size, y + pointer_size // 2),
(x + pointer_size // 2, y + pointer_size // 2),
(x + pointer_size // 2, y + pointer_size),
]
Comment on lines +698 to +703
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mouse pointer boundary check missing category Functionality

Tell me more
What is the issue?

The mouse pointer visualization could appear outside the screenshot boundaries when coordinates are near the edges, potentially causing IndexError or visual artifacts.

Why this matters

This could lead to runtime errors or partial/incorrect pointer visualization when the agent interacts with elements near screen edges.

Suggested change ∙ Feature Preview

Add boundary checks before drawing the pointer to ensure coordinates stay within image dimensions:

def add_mouse_pointer_to_screenshot(screenshot: Image.Image, action: str) -> Image.Image:
    coords = extract_mouse_coords_from_action(action)
    if not coords:
        return screenshot

    if isinstance(screenshot, np.ndarray):
        pil_image = Image.fromarray(screenshot)
    else:
        pil_image = screenshot

    width, height = pil_image.size
    x, y = coords
    pointer_size = 20

    # Constrain pointer coordinates to image boundaries
    x = max(0, min(x, width - pointer_size))
    y = max(0, min(y, height - pointer_size))

    overlay = pil_image.convert("RGBA").copy()
    draw = ImageDraw.Draw(overlay)
    pointer_shape = [
        (x, y),
        (x + pointer_size, y + pointer_size // 2),
        (x + pointer_size // 2, y + pointer_size // 2),
        (x + pointer_size // 2, y + pointer_size),
    ]
    draw.polygon(pointer_shape, fill=(0, 0, 0, 128))
    result_image = Image.alpha_composite(pil_image.convert("RGBA"), overlay)
    return np.array(result_image.convert("RGB"))
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.


draw.polygon(pointer_shape, fill=(0, 0, 0, 128)) # 50% transparent black
result_image = Image.alpha_composite(pil_image.convert("RGBA"), overlay)

# Convert back to numpy array to match expected format
return np.array(result_image.convert("RGB"))
Loading