diff --git a/browsergym/core/src/browsergym/core/env.py b/browsergym/core/src/browsergym/core/env.py index 12c5e8bc..26581d3a 100644 --- a/browsergym/core/src/browsergym/core/env.py +++ b/browsergym/core/src/browsergym/core/env.py @@ -15,6 +15,7 @@ from .action.highlevel import HighLevelActionSet from .chat import Chat from .constants import BROWSERGYM_ID_ATTRIBUTE, EXTRACT_OBS_MAX_TRIES +from .hint_labeling import HintLabeling from .observation import ( MarkingError, _post_extract, @@ -68,6 +69,8 @@ def __init__( tags_to_mark: Literal["all", "standard_html"] = "standard_html", # interactive / debugging arguments headless: bool = True, + use_chat_ui: bool = False, + use_hint_labeling_ui: bool = False, wait_for_user_message: bool = False, terminate_on_infeasible: bool = True, resizeable_window: bool = False, @@ -120,6 +123,8 @@ def __init__( self.action_mapping = action_mapping self.use_raw_page_output = use_raw_page_output self.pre_observation_delay = pre_observation_delay + self.use_chat_ui = use_chat_ui + self.use_hint_labeling_ui = use_hint_labeling_ui # check argument values assert tags_to_mark in ("all", "standard_html") @@ -136,6 +141,10 @@ def __init__( # chat self.chat: Chat = None + # hint labeling + if self.use_hint_labeling_ui: + self.hint_labeling: HintLabeling = None + # observation space if use_raw_page_output: self.observation_space = gym.spaces.Dict( @@ -211,6 +220,10 @@ def close(self): if self.chat: self.chat.close() self.chat = None + + if self.use_hint_labeling_ui and self.hint_labeling: + self.hint_labeling.close() + self.hint_labeling = None # close the browser context if self.context: self.context.close() @@ -228,6 +241,8 @@ def reset(self, seed=None, *args, **kwargs): self.task.teardown() self.context.close() self.chat.close() + if self.use_hint_labeling_ui and self.hint_labeling: + self.hint_labeling.close() self.browser.close() # create a new task @@ -327,12 +342,22 @@ def override_property(task, env, property): ) # create the chat + # 
chat always exists even if use_chat_ui is False since we use it to keep track of messages + # we only show the chat ui if both headless is False and use_chat_ui is True self.chat = Chat( - headless=self.headless, + headless=not (self.headless is False and self.use_chat_ui is True), chat_size=(500, max(viewport["height"], 800)), record_video_dir=self.record_video_dir, ) + # create the hint labeling ui + if self.use_hint_labeling_ui: + self.hint_labeling = HintLabeling( + headless=self.headless, + window_size=(600, max(viewport["height"], 1000)), + record_video_dir=self.record_video_dir, + ) + # create a new page self.page = self.context.new_page() recording_start_time = time.time() diff --git a/browsergym/core/src/browsergym/core/hint_labeling.py b/browsergym/core/src/browsergym/core/hint_labeling.py new file mode 100644 index 00000000..45efac48 --- /dev/null +++ b/browsergym/core/src/browsergym/core/hint_labeling.py @@ -0,0 +1,150 @@ +import json +import logging +from importlib import resources +from queue import Queue +from typing import Dict, List, Optional + +import playwright.sync_api +from pydantic import BaseModel, Field + +from . 
import _get_global_playwright, hint_labeling_files + +logger = logging.getLogger(__name__) + +HINT_LABELING_DIR = resources.files(hint_labeling_files) + + +class HintLabelingInputs(BaseModel): + goal: str + error_feedback: str = "" + screenshot: str # base64 screenshot + axtree: str + history: List[Dict[str, str]] = Field(default_factory=list) + hint: str = "" + suggestions: List[Dict[str, str]] = Field(default_factory=list) + + +class HintLabeling: + def __init__(self, headless: bool, window_size=(600, 1000), *args, **kwargs): + + pw: playwright.sync_api.Playwright = _get_global_playwright() + self.browser = pw.chromium.launch( + headless=headless, args=[f"--window-size={window_size[0]},{window_size[1]}"] + ) + self.context = self.browser.new_context( + no_viewport=True, + ) + self.page = self.context.new_page() + self._resp_queue: "Queue[dict]" = Queue() + + self.page.route("**/api/reprompt", self._route_reprompt) + self.page.route("**/api/submit", self._route_submit) + self.page.set_content(get_hint_labeling_ui(HINT_LABELING_DIR)) + + # internal state + self._context: HintLabelingInputs = None + self._running = False + + def _route_reprompt( + self, route: playwright.sync_api.Route, request: playwright.sync_api.Request + ): + logger.info("Route hit: %s %s", request.method, request.url) + try: + body = json.loads(request.post_data() or "{}") + except Exception: + body = {} + # enqueue output 1 (reprompt) + msg = {"type": "reprompt", "payload": {"hint": body.get("hint", "")}} + self._resp_queue.put(msg) + # Respond something minimal so UI doesn’t break; it will be refreshed by a later update_context() + route.fulfill( + status=200, + content_type="application/json", + body=json.dumps({"suggestions": []}), + ) + + def _route_submit(self, route: playwright.sync_api.Route, request: playwright.sync_api.Request): + logger.info("Route hit: %s %s", request.method, request.url) + try: + body = json.loads(request.post_data() or "{}") + except Exception: + body = {} + # Map 
UI payload -> your step shape + msg = { + "type": "step", + "payload": { + "think": body.get("think", ""), + "action": body.get("action", ""), + }, + } + self._resp_queue.put(msg) + # UI expects 200 JSON; we can optionally send new suggestions here too. + route.fulfill( + status=200, + content_type="application/json", + body=json.dumps({"suggestions": []}), + ) + + def _to_ui_bootstrap(self, ctx: HintLabelingInputs) -> dict: + return { + "goal": ctx.goal, + "error_feedback": ctx.error_feedback, + "screenshot": ctx.screenshot, + "axtree": ctx.axtree, + "history": ctx.history, + "hint": ctx.hint, + "suggestions": ctx.suggestions, + } + + def update_context(self, context: HintLabelingInputs): + self._context = context + ui_payload = self._to_ui_bootstrap(context) + # call JS function with arg (no string concat) + self.page.evaluate("(d) => updateContext(d)", ui_payload) + + def wait_for_response(self, timeout: Optional[float] = 600) -> dict: + """ + Wait until the page makes a request to /api/reprompt or /api/submit, + then parse the request body and return it in your schema. + """ + logger.info("Waiting for response from Hint Labeling UI...") + + def is_api(req: playwright.sync_api.Request) -> bool: + u = req.url + return ( + u.endswith("/api/reprompt") or u.endswith("/api/submit") + ) and req.method == "POST" + + # This pumps Playwright internally; no busy waiting. 
+ with self.page.expect_request( + is_api, timeout=(timeout * 1000 if timeout else 0) + ) as req_info: + req = req_info.value + + body_text = req.post_data or "{}" + try: + body = json.loads(body_text) + except Exception as e: + print("JSON parse error:", e) + body = {} + + if req.url.endswith("/api/reprompt"): + msg = {"type": "reprompt", "payload": {"hint": body.get("hint", "")}} + else: + msg = { + "type": "step", + "payload": {"think": body.get("think", ""), "action": body.get("action", "")}, + } + + logger.info("Response received: %s", msg) + return msg + + def close(self): + self.context.close() + self.browser.close() + + +def get_hint_labeling_ui(hint_labeling_dir) -> str: + with open(hint_labeling_dir / "hint_labeling_ui.html", "r") as file: + hint_labeling_html = file.read() + return hint_labeling_html diff --git a/browsergym/core/src/browsergym/core/hint_labeling_files/hint_labeling_ui.html b/browsergym/core/src/browsergym/core/hint_labeling_files/hint_labeling_ui.html new file mode 100644 index 00000000..6fc9c636 --- /dev/null +++ b/browsergym/core/src/browsergym/core/hint_labeling_files/hint_labeling_ui.html @@ -0,0 +1,348 @@ + + + + + + + Agent Reprompt UI + + + +
+ +
+
+

Goal

+
+
+
+

Error Feedback

+
+
+
+ + +
+
+ + + +
+
+
+ screenshot +
+ + +
+
+ + +
+

Hints

+ + + +
+ + +
+

Suggestions

+
+ + + +
+
+ + + + diff --git a/browsergym/experiments/src/browsergym/experiments/loop.py b/browsergym/experiments/src/browsergym/experiments/loop.py index 0ebb9e94..40d36dd5 100644 --- a/browsergym/experiments/src/browsergym/experiments/loop.py +++ b/browsergym/experiments/src/browsergym/experiments/loop.py @@ -17,16 +17,15 @@ from pathlib import Path from typing import Optional -from browsergym.core.env import BrowserEnv import gymnasium as gym import numpy as np +from browsergym.core.action.parsers import highlevel_action_parser +from browsergym.core.chat import Chat +from browsergym.core.env import BrowserEnv from dataclasses_json import DataClassJsonMixin from PIL import Image from tqdm import tqdm -from browsergym.core.action.parsers import highlevel_action_parser -from browsergym.core.chat import Chat - from .agent import Agent from .utils import count_messages_token, count_tokens @@ -47,6 +46,8 @@ class EnvArgs(DataClassJsonMixin): slow_mo: Optional[int] = None # use default value from BrowserGym storage_state: Optional[str | Path | dict] = None task_kwargs: Optional[dict] = None # use default value from BrowserGym + use_chat_ui: bool = False + use_hint_labeling_ui: bool = False def make_env( self, action_mapping, exp_dir, exp_task_kwargs: dict = {}, use_raw_page_output=False