Skip to content

Commit 544bf9f

Browse files
ollmer and TLSDC authored
MCP Server for the gym (#337)
* basic MCP server that exposes all tools from the action set
* separate pre- and post-processing of a gym step from the step execution
* MCP server wraps every function in an async wrapper that calls the gym and sets up the required global vars

---------

Co-authored-by: ThibaultLSDC <thibault.de.chezelles@gmail.com>
1 parent c3336ef commit 544bf9f

File tree

5 files changed

+265
-25
lines changed

5 files changed

+265
-25
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,5 @@ tests/assistantbench/assistantbench-predictions-test.jsonl
150150

151151
# weblinx
152152
bg_wl_data/
153+
154+
uv.lock

browsergym/core/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ pyparsing>=3
55
Pillow>=10.1
66
beautifulsoup4>=4.12
77
lxml>=4.9
8+
mcp[cli]>=1.6.0

browsergym/core/src/browsergym/core/env.py

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
from abc import ABC
66
from pathlib import Path
7-
from typing import Literal, Optional
7+
from typing import Any, Callable, Literal, Optional
88

99
import gymnasium as gym
1010
import numpy as np
@@ -371,10 +371,7 @@ def override_property(task, env, property):
371371

372372
return obs, info
373373

374-
def step(self, action: str) -> tuple:
375-
376-
self.last_action = action
377-
374+
def pre_step(self) -> tuple[dict[str, Any], Callable, Callable]:
378375
info = {}
379376
info["action_exec_start"] = time.time()
380377
info["action_exec_timeout"] = 0
@@ -391,7 +388,25 @@ def report_infeasible_instructions(reason: str):
391388
self.infeasible_message_received = True
392389

393390
# try to execute the action
394-
logger.debug(f"Executing action")
391+
logger.debug("Executing action")
392+
return info, send_message_to_user, report_infeasible_instructions
393+
394+
def step(self, action: str) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
395+
"""
396+
Execute the action in the environment.
397+
398+
Args:
399+
action: the action to execute. This should be a string with code or a function call
400+
401+
Returns:
402+
obs: the observation after executing the action
403+
reward: the reward received after executing the action
404+
terminated: whether the episode is terminated or not
405+
truncated: whether the episode is truncated or not
406+
info: additional information about the step
407+
"""
408+
self.last_action = action
409+
info, send_message_to_user, report_infeasible_instructions = self.pre_step()
395410
try:
396411
if self.action_mapping:
397412
code = self.action_mapping(action)
@@ -409,7 +424,25 @@ def report_infeasible_instructions(reason: str):
409424
match = re.match("TimeoutError: Timeout ([0-9]+)ms exceeded.", self.last_action_error)
410425
if match:
411426
info["action_exec_timeout"] = float(match.groups()[0]) / 1000 # ms to sec
412-
logger.debug(f"Action executed")
427+
return self.post_step(info)
428+
429+
def post_step(
430+
self, info: dict[str, Any], validate: bool = True
431+
) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
432+
"""
433+
Post step method, called after executing the action.
434+
This method is responsible for extracting the observation after the action.
435+
It also prepares reward, task status, user message and other step info.
436+
Args:
437+
info: dictionary containing information about the step
438+
Returns:
439+
obs: the observation after executing the action
440+
reward: the reward received after executing the action
441+
terminated: whether the episode is terminated or not
442+
truncated: whether the episode is truncated or not
443+
info: additional information about the step
444+
"""
445+
logger.debug("Action executed")
413446
info["action_exec_stop"] = time.time()
414447

415448
# wait a bit (for the JavaScript callback to set the active page)
@@ -419,35 +452,41 @@ def report_infeasible_instructions(reason: str):
419452
# wait for the network to idle before extracting the observation, reward etc.
420453
self._wait_dom_loaded()
421454

422-
# after the action is executed, the active page might have changed
423-
# perform a safety check
424-
self._active_page_check()
425-
logger.debug(f"Active page checked")
426-
427-
# if asked, wait for user message
428-
self._wait_for_user_message()
429-
logger.debug(f"User message done")
430-
431-
logger.debug(f"Initiating task validation")
432-
# extract reward, done, user_message, info (task-specific)
433-
reward, done, user_message, task_info = self._task_validate()
434-
info["task_info"] = task_info
435-
logger.debug(f"Task validation done")
455+
if validate:
456+
# after the action is executed, the active page might have changed
457+
# perform a safety check
458+
self._active_page_check()
459+
logger.debug("Active page checked")
460+
461+
# if asked, wait for user message
462+
self._wait_for_user_message()
463+
logger.debug("User message done")
464+
465+
logger.debug("Initiating task validation")
466+
# extract reward, done, user_message, info (task-specific)
467+
reward, done, user_message, task_info = self._task_validate()
468+
info["task_info"] = task_info
469+
logger.debug("Task validation done")
470+
else:
471+
reward = 0
472+
done = False
473+
user_message = None
474+
info["task_info"] = {}
475+
logger.debug("Task validation skipped")
436476

437477
# add any user message sent by the task to the chat
438478
if user_message:
439479
self.chat.add_message(role="user", msg=user_message)
440480

441481
# extract observation (generic)
442482
obs = self._get_obs()
443-
logger.debug(f"Observation extracted")
483+
logger.debug("Observation extracted")
444484

445485
# new step API wants a 5-tuple (gymnasium)
446486
terminated = done or (
447487
self.terminate_on_infeasible and self.infeasible_message_received
448488
) # task or agent can terminate the episode
449-
truncated = False
450-
489+
truncated: bool = False
451490
return obs, reward, terminated, truncated, info
452491

453492
def _task_validate(self):
@@ -506,7 +545,7 @@ def _active_page_check(self):
506545
# make sure there is always a page open
507546
# if all pages have been closed, create a new page
508547
if len(self.context.pages) == 0:
509-
logger.warning(f"All pages are closed, opening a new page.")
548+
logger.warning("All pages are closed, opening a new page.")
510549
self.page = self.context.new_page()
511550

512551
# if the active page got closed, get the last active page from the history
browsergym/core/src/browsergym/utils/mcp_server.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
# MCP server for BrowserGym
2+
import argparse
3+
import asyncio
4+
import re
5+
from collections.abc import AsyncIterator
6+
from contextlib import asynccontextmanager
7+
from dataclasses import dataclass, field
8+
from typing import Callable
9+
10+
import gymnasium as gym
11+
from mcp.server.fastmcp import FastMCP
12+
13+
from browsergym.core.action.highlevel import ACTION_SUBSETS, HighLevelActionSet
14+
from browsergym.core.env import BrowserEnv
15+
16+
17+
@dataclass
class BgymConfig:
    """Settings for the BrowserGym environment that backs the MCP server."""

    # Passed to gym.make as `headless`.
    headless: bool = True
    # Step timeout in milliseconds; passed to gym.make as `timeout`.
    timeout_ms: int = 10000
    # Directory to save recorded videos; passed to gym.make as `record_video_dir`.
    record_video_dir: str | None = None
    # Demo mode given to HighLevelActionSet and to the action functions module.
    demo_mode: HighLevelActionSet.DemoMode = "default"
    # Names of actions for which validation is performed after the call.
    validate_actions: list[str] = field(default_factory=list)
24+
25+
26+
@dataclass
class AppContext:
    """Objects built at server startup and exposed to tools via the lifespan context."""

    # Environment returned by gym.make; tool wrappers unwrap it down to the BrowserEnv.
    gym: BrowserEnv
    # Settings the environment was created with.
    config: BgymConfig
    # Task ID the environment was created for.
    task_id: str
    # Action set whose `to_python_code` is used as the environment's action mapping.
    actions: HighLevelActionSet
32+
33+
34+
def get_cli_args(argv=None):
    """Parse the command-line options of the BrowserGym MCP server.

    Unrecognised arguments are ignored rather than rejected, because the
    parser is run with ``parse_known_args`` and the leftovers are discarded.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with ``task_id``, ``headless``, ``record_video_dir``,
        ``demo_mode``, ``timeout_ms``, ``subset`` and ``validate_actions``.
    """
    parser = argparse.ArgumentParser(
        description="BrowserGym MCP server",
        usage="python browsergym/core/src/browsergym/utils/%(prog)s [options]",
        epilog="To run Dev UI: mcp dev browsergym/core/src/browsergym/utils/mcp_server.py -e browsergym/core/",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-t",
        "--task_id",
        type=str,
        default="browsergym/openended",
        help="Task ID to run",
    )
    # NOTE: a store_true flag defaults to False, so the browser is headed
    # unless -l/--headless is given (BgymConfig's own default is True).
    parser.add_argument(
        "-l",
        "--headless",
        action="store_true",
        help="Run in headless mode",
    )
    parser.add_argument(
        "-r",
        "--record_video_dir",
        type=str,
        default=None,
        help="Directory to save recorded videos",
    )
    parser.add_argument(
        "--demo_mode",
        type=str,
        default="off",
        choices=["off", "default", "all_blue", "only_visible_elements"],
        help="Demo mode for action set",
    )
    parser.add_argument(
        "--timeout_ms",
        type=int,
        default=10000,
        help="Timeout in milliseconds for each step",
    )
    parser.add_argument(
        "--subset",
        type=str,
        default="workarena++",
        choices=ACTION_SUBSETS.keys(),
        help="Subset of actions to use",
    )
    parser.add_argument(
        "--validate_actions",
        type=str,
        nargs="+",
        default=["click", "goto"],
        help="Names of actions for which validation should be performed",
    )
    # parse_known_args tolerates options meant for other programs on the command line.
    args, _ = parser.parse_known_args(argv)
    return args
90+
91+
92+
# Parsed once at import time: the lifespan hook, the tool wrappers and the
# tool-registration loop further down all read these module-level values.
args = get_cli_args()
task_id = args.task_id
config = BgymConfig(
    headless=args.headless,
    timeout_ms=args.timeout_ms,
    record_video_dir=args.record_video_dir,
    demo_mode=args.demo_mode,
    validate_actions=args.validate_actions,
)
101+
102+
103+
@asynccontextmanager
async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
    """Create the BrowserGym environment on startup and close it on shutdown.

    The environment is built from the module-level ``task_id``, ``config`` and
    ``args.subset``. ``gym.make``, ``reset`` and ``close`` are all run through
    ``asyncio.to_thread`` instead of being called directly in the coroutine.

    Args:
        server: the FastMCP instance this lifespan belongs to (not used here).

    Yields:
        AppContext holding the environment, config, task ID and action set.
    """
    actions = HighLevelActionSet(demo_mode=config.demo_mode, subsets=args.subset)
    _gym: BrowserEnv = await asyncio.to_thread(
        gym.make,
        task_id,
        headless=config.headless,
        record_video_dir=config.record_video_dir,
        action_mapping=actions.to_python_code,
        timeout=config.timeout_ms,
        task_kwargs={"start_url": "about:blank"},
    )  # type: ignore

    try:
        # reset sits inside the try so that a failing reset still closes the
        # environment that gym.make already created.
        await asyncio.to_thread(_gym.reset)
        yield AppContext(gym=_gym, config=config, task_id=task_id, actions=actions)
    finally:
        # Cleanup on shutdown
        await asyncio.to_thread(_gym.close)
124+
125+
126+
# Server instance; app_lifespan is registered as its lifespan hook and builds the environment.
mcp = FastMCP("BrowserGym", lifespan=app_lifespan)
127+
128+
129+
def format_func_call(func: Callable, args, kwargs) -> str:
    """Render a call as source-like text, e.g. ``click('a12', button='left')``.

    Args:
        func: the callable being invoked. Its ``__name__`` is used when
            available; callables without one (such as ``functools.partial``
            objects) fall back to their ``repr``.
        args: positional arguments, each rendered with ``repr``.
        kwargs: keyword arguments, rendered as ``name=repr(value)``.

    Returns:
        The string ``name(arg1, arg2, key=value, ...)``.
    """
    name = getattr(func, "__name__", None) or repr(func)
    pieces = [repr(arg) for arg in args]
    pieces.extend(f"{key}={value!r}" for key, value in kwargs.items())
    return f"{name}({', '.join(pieces)})"
134+
135+
136+
def fn_wrapper(func: Callable, validate: bool = True):
    """Wrap an action-set function as an async tool executed in the context of the gym.

    Each call of the returned coroutine function:
    1. Takes the environment from the lifespan context and unwraps it to the BrowserEnv
    2. Looks the function up by name in ``browsergym.core.action.functions``
    3. Executes the pre_step method of the gym
    4. Sets up that module's vars from the current state of the gym
    5. Executes the function and records any exception in ``last_action_error``
    6. Executes the post_step method of the gym and returns its result

    Args:
        func: function from the action set; its name and docstring are copied
            onto the wrapper and it is exposed as ``__wrapped__``.
        validate: forwarded to ``post_step`` as its ``validate`` argument.

    Returns:
        The async wrapper function.
    """

    async def decorator(*args, **kwargs):
        # Local is called `env` so the imported `gym` module is not shadowed.
        env: BrowserEnv = mcp.get_context().request_context.lifespan_context.gym  # type: ignore
        # gym library wraps the BrowserEnv in a few layers (usually 2) of wrappers,
        # this loop unwraps them
        while not isinstance(env, BrowserEnv):
            env = env.env

        # Load the parent module of the function to use as function context
        import browsergym.core.action.functions as fn_context

        fn = getattr(fn_context, func.__name__)

        env.last_action = format_func_call(fn, args, kwargs)
        info, send_message_to_user, report_infeasible_instructions = await asyncio.to_thread(
            env.pre_step
        )

        # Set up the module vars from the current state of the gym
        fn_context.send_message_to_user = send_message_to_user
        fn_context.report_infeasible_instructions = report_infeasible_instructions
        fn_context.page = env.page
        fn_context.demo_mode = config.demo_mode

        try:
            fn(*args, **kwargs)
            env.last_action_error = ""
        except Exception as e:
            # Any failure of the action is reported through last_action_error
            # instead of propagating to the caller.
            env.last_action_error = f"{type(e).__name__}: {e}"
            match = re.match(
                r"TimeoutError: Timeout ([0-9]+)ms exceeded\.", env.last_action_error
            )
            if match:
                info["action_exec_timeout"] = float(match.group(1)) / 1000  # ms to sec

        return await asyncio.to_thread(env.post_step, info, validate)

    # Present the wrapper under the wrapped function's identity.
    decorator.__wrapped__ = func  # type: ignore
    decorator.__name__ = func.__name__
    decorator.__doc__ = func.__doc__
    return decorator
185+
186+
187+
# Expose every function of the selected action subset as a tool; validation
# is enabled only for the action names listed in config.validate_actions.
for fn in ACTION_SUBSETS[args.subset]:
    validate = fn.__name__ in config.validate_actions
    mcp.add_tool(fn_wrapper(fn, validate))

if __name__ == "__main__":
    # Serve over stdio when run as a script.
    mcp.run(transport="stdio")

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
[project]
2+
name = "browsergym-meta"
3+
description = "BrowserGym: a gym environment for web task automation in the Chromium browser"
4+
dynamic = ["version"]
5+
[tool.setuptools]
6+
packages = [] # meta distribution, packages are included as dependencies
17
[tool.black]
28
line-length = 100
39
include = '\.pyi?$'

0 commit comments

Comments
 (0)