Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion browsergym/core/src/browsergym/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .action.highlevel import HighLevelActionSet
from .chat import Chat
from .constants import BROWSERGYM_ID_ATTRIBUTE, EXTRACT_OBS_MAX_TRIES
from .hint_labeling import HintLabeling
from .observation import (
MarkingError,
_post_extract,
Expand Down Expand Up @@ -68,6 +69,8 @@ def __init__(
tags_to_mark: Literal["all", "standard_html"] = "standard_html",
# interactive / debugging arguments
headless: bool = True,
use_chat_ui: bool = False,
use_hint_labeling_ui: bool = False,
wait_for_user_message: bool = False,
terminate_on_infeasible: bool = True,
resizeable_window: bool = False,
Expand Down Expand Up @@ -120,6 +123,8 @@ def __init__(
self.action_mapping = action_mapping
self.use_raw_page_output = use_raw_page_output
self.pre_observation_delay = pre_observation_delay
self.use_chat_ui = use_chat_ui
self.use_hint_labeling_ui = use_hint_labeling_ui

# check argument values
assert tags_to_mark in ("all", "standard_html")
Expand All @@ -136,6 +141,10 @@ def __init__(
# chat
self.chat: Chat = None

# hint labeling
if self.use_hint_labeling_ui:
self.hint_labeling: HintLabeling = None

# observation space
if use_raw_page_output:
self.observation_space = gym.spaces.Dict(
Expand Down Expand Up @@ -211,6 +220,10 @@ def close(self):
if self.chat:
self.chat.close()
self.chat = None

if self.use_hint_labeling_ui and self.hint_labeling:
self.hint_labeling.close()
self.hint_labeling = None
# close the browser context
if self.context:
self.context.close()
Expand All @@ -228,6 +241,8 @@ def reset(self, seed=None, *args, **kwargs):
self.task.teardown()
self.context.close()
self.chat.close()
if self.use_hint_labeling_ui and self.hint_labeling:
self.hint_labeling.close()
self.browser.close()

# create a new task
Expand Down Expand Up @@ -327,12 +342,22 @@ def override_property(task, env, property):
)

# create the chat
# chat always exist even is use_chat_ui is False since we use it to keep track of messages
# we only show the chat ui if both headless is False and use_chat_ui is True
self.chat = Chat(
headless=self.headless,
headless=not (self.headless is False and self.use_chat_ui is True),
chat_size=(500, max(viewport["height"], 800)),
record_video_dir=self.record_video_dir,
)

# create the hint labeling ui
if self.use_hint_labeling_ui:
self.hint_labeling = HintLabeling(
headless=self.headless,
window_size=(600, max(viewport["height"], 1000)),
record_video_dir=self.record_video_dir,
)

# create a new page
self.page = self.context.new_page()
recording_start_time = time.time()
Expand Down
150 changes: 150 additions & 0 deletions browsergym/core/src/browsergym/core/hint_labeling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import json
import logging
from importlib import resources
from queue import Queue
from typing import Dict, List, Optional

import playwright.sync_api
from pydantic import BaseModel, Field

from . import _get_global_playwright, hint_labeling_files

logger = logging.getLogger(__name__)

HINT_LABELING_DIR = resources.files(hint_labeling_files)


class HintLabelingInputs(BaseModel):
goal: str
error_feedback: str = ""
screenshot: str # base64 screenshot
axtree: str
history: List[Dict[str, str]] = Field(default_factory=list)
hint: str = ""
suggestions: List[Dict[str, str]] = Field(default_factory=list)


class HintLabeling:
def __init__(self, headless: bool, window_size=(600, 1000), *args, **kwargs):

pw: playwright.sync_api.Playwright = _get_global_playwright()
self.browser = pw.chromium.launch(
headless=headless, args=[f"--window-size={window_size[0]},{window_size[1]}"]
)
self.context = self.browser.new_context(
no_viewport=True,
)
self.page = self.context.new_page()
self._resp_queue: "Queue[dict]" = Queue()

self.page.route("**/api/reprompt", self._route_reprompt)
self.page.route("**/api/submit", self._route_submit)
self.page.set_content(get_hint_labeling_ui(HINT_LABELING_DIR))

# internal state
self._context: HintLabelingInputs = None
self._running = False

def _route_reprompt(
self, route: playwright.sync_api.Route, request: playwright.sync_api.Request
):
logger.info("Route hit: %s %s", request.method, request.url)
try:
body = json.loads(request.post_data() or "{}")
except Exception:
body = {}
# enqueue output 1 (reprompt)
msg = {"type": "reprompt", "payload": {"hint": body.get("hint", "")}}
self._resp_queue.put(msg)
# Respond something minimal so UI doesn’t break; it will be refreshed by a later update_context()
route.fulfill(
status=200,
content_type="application/json",
body=json.dumps({"suggestions": []}),
)

def _route_submit(self, route: playwright.sync_api.Route, request: playwright.sync_api.Request):
logger.info("Route hit: %s %s", request.method, request.url)
try:
body = json.loads(request.post_data() or "{}")
except Exception:
body = {}
# Map UI payload -> your step shape
msg = {
"type": "step",
"payload": {
"think": body.get("think", ""),
"action": body.get("action", ""),
},
}
self._resp_queue.put(msg)
# UI expects 200 JSON; we can optionally send new suggestions here too.
route.fulfill(
status=200,
content_type="application/json",
body=json.dumps({"suggestions": []}),
)

def _to_ui_bootstrap(self, ctx: HintLabelingInputs) -> dict:
return {
"goal": ctx.goal,
"error_feedback": ctx.error_feedback,
"screenshot": ctx.screenshot,
"axtree": ctx.axtree,
"history": ctx.history,
"hint": ctx.hint,
"suggestions": ctx.suggestions,
}

def update_context(self, context: HintLabelingInputs):
self._context = context
ui_payload = self._to_ui_bootstrap(context)
# call JS function with arg (no string concat)
self.page.evaluate("(d) => updateContext(d)", ui_payload)

def wait_for_response(self, timeout: Optional[float] = 600) -> dict:
"""
Wait until the page makes a request to /api/reprompt or /api/submit,
then parse the request body and return it in your schema.
"""
logger.info("Waiting for response from Hint Labeling UI...")

def is_api(req: playwright.sync_api.Request) -> bool:
u = req.url
return (
u.endswith("/api/reprompt") or u.endswith("/api/submit")
) and req.method == "POST"

# This pumps Playwright internally; no busy waiting.
with self.page.expect_request(
is_api, timeout=(timeout * 1000 if timeout else 0)
) as req_info:
req = req_info.value

body_text = req.post_data or "{}"
try:
body = json.loads(body_text)
except Exception as e:
print("JSON parse error:", e)
body = {}

if req.url.endswith("/api/reprompt"):
msg = {"type": "reprompt", "payload": {"hint": body.get("hint", "")}}
else:
msg = {
"type": "step",
"payload": {"think": body.get("think", ""), "action": body.get("action", "")},
}

logger.info("Response received: %s", msg)
return msg

def close(self):
self.context.close()
self.browser.close()


def get_hint_labeling_ui(hint_labeling_dir) -> str:
with open(hint_labeling_dir / "hint_labeling_ui.html", "r") as file:
hint_labeling_html = file.read()
return hint_labeling_html
Loading
Loading