Skip to content

Commit d75e3e6

Browse files
committed
selective action validation
1 parent 6b046eb commit d75e3e6

File tree

2 files changed

+36
-19
lines changed

2 files changed

+36
-19
lines changed

browsergym/core/src/browsergym/core/env.py

Lines changed: 22 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -427,7 +427,7 @@ def step(self, action: str) -> tuple[dict[str, Any], float, bool, bool, dict[str
427427
return self.post_step(info)
428428

429429
def post_step(
430-
self, info: dict[str, Any]
430+
self, info: dict[str, Any], validate: bool = True
431431
) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
432432
"""
433433
Post step method, called after executing the action.
@@ -452,20 +452,27 @@ def post_step(
452452
# wait for the network to idle before extracting the observation, reward etc.
453453
self._wait_dom_loaded()
454454

455-
# after the action is executed, the active page might have changed
456-
# perform a safety check
457-
self._active_page_check()
458-
logger.debug("Active page checked")
459-
460-
# if asked, wait for user message
461-
self._wait_for_user_message()
462-
logger.debug("User message done")
463-
464-
logger.debug("Initiating task validation")
465-
# extract reward, done, user_message, info (task-specific)
466-
reward, done, user_message, task_info = self._task_validate()
467-
info["task_info"] = task_info
468-
logger.debug("Task validation done")
455+
if validate:
456+
# after the action is executed, the active page might have changed
457+
# perform a safety check
458+
self._active_page_check()
459+
logger.debug("Active page checked")
460+
461+
# if asked, wait for user message
462+
self._wait_for_user_message()
463+
logger.debug("User message done")
464+
465+
logger.debug("Initiating task validation")
466+
# extract reward, done, user_message, info (task-specific)
467+
reward, done, user_message, task_info = self._task_validate()
468+
info["task_info"] = task_info
469+
logger.debug("Task validation done")
470+
else:
471+
reward = 0
472+
done = False
473+
user_message = None
474+
info["task_info"] = {}
475+
logger.debug("Task validation skipped")
469476

470477
# add any user message sent by the task to the chat
471478
if user_message:

browsergym/core/src/browsergym/utils/mcp_server.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import re
55
from collections.abc import AsyncIterator
66
from contextlib import asynccontextmanager
7-
from dataclasses import dataclass
7+
from dataclasses import dataclass, field
88
from typing import Callable
99

1010
import gymnasium as gym
@@ -20,6 +20,7 @@ class BgymConfig:
2020
timeout_ms: int = 10000
2121
record_video_dir: str | None = None
2222
demo_mode: HighLevelActionSet.DemoMode = "default"
23+
validate_actions: list[str] = field(default_factory=list)
2324

2425

2526
@dataclass
@@ -77,6 +78,13 @@ def get_cli_args():
7778
choices=ACTION_SUBSETS.keys(),
7879
help="Subset of actions to use",
7980
)
81+
parser.add_argument(
82+
"--validate_actions",
83+
type=str,
84+
nargs="+",
85+
default=["click", "goto"],
86+
help="Names of actions for which validation should be performed",
87+
)
8088
args, _ = parser.parse_known_args()
8189
return args
8290

@@ -88,6 +96,7 @@ def get_cli_args():
8896
timeout_ms=args.timeout_ms,
8997
record_video_dir=args.record_video_dir,
9098
demo_mode=args.demo_mode,
99+
validate_actions=args.validate_actions,
91100
)
92101

93102

@@ -117,7 +126,7 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
117126
mcp = FastMCP("BrowserGym", lifespan=app_lifespan)
118127

119128

120-
def fn_wrapper(func: Callable):
129+
def fn_wrapper(func: Callable, validate: bool = True):
121130
async def decorator(*args, **kwargs):
122131
"""
123132
Decorator to execute function from the action space in the context of the gym.
@@ -157,7 +166,7 @@ async def decorator(*args, **kwargs):
157166
if match:
158167
info["action_exec_timeout"] = float(match.groups()[0]) / 1000
159168

160-
results = await asyncio.to_thread(gym.post_step, info)
169+
results = await asyncio.to_thread(gym.post_step, info, validate)
161170
return results
162171

163172
decorator.__wrapped__ = func # type: ignore
@@ -167,7 +176,8 @@ async def decorator(*args, **kwargs):
167176

168177

169178
for fn in ACTION_SUBSETS[args.subset]:
170-
mcp.add_tool(fn_wrapper(fn))
179+
validate =fn.__name__ in config.validate_actions
180+
mcp.add_tool(fn_wrapper(fn, validate))
171181

172182
if __name__ == "__main__":
173183
mcp.run(transport="stdio")

0 commit comments

Comments
 (0)