44import re
55from collections .abc import AsyncIterator
66from contextlib import asynccontextmanager
7- from dataclasses import dataclass
7+ from dataclasses import dataclass , field
88from typing import Callable
99
1010import gymnasium as gym
@@ -20,6 +20,7 @@ class BgymConfig:
2020 timeout_ms : int = 10000
2121 record_video_dir : str | None = None
2222 demo_mode : HighLevelActionSet .DemoMode = "default"
23+ validate_actions : list [str ] = field (default_factory = list )
2324
2425
2526@dataclass
@@ -77,6 +78,13 @@ def get_cli_args():
7778 choices = ACTION_SUBSETS .keys (),
7879 help = "Subset of actions to use" ,
7980 )
81+ parser .add_argument (
82+ "--validate_actions" ,
83+ type = str ,
84+ nargs = "+" ,
85+ default = ["click" , "goto" ],
86+ help = "Names of actions for which validation should be performed" ,
87+ )
8088 args , _ = parser .parse_known_args ()
8189 return args
8290
@@ -88,6 +96,7 @@ def get_cli_args():
8896 timeout_ms = args .timeout_ms ,
8997 record_video_dir = args .record_video_dir ,
9098 demo_mode = args .demo_mode ,
99+ validate_actions = args .validate_actions ,
91100)
92101
93102
@@ -117,7 +126,7 @@ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
117126mcp = FastMCP ("BrowserGym" , lifespan = app_lifespan )
118127
119128
120- def fn_wrapper (func : Callable ):
129+ def fn_wrapper (func : Callable , validate : bool = True ):
121130 async def decorator (* args , ** kwargs ):
122131 """
123132 Decorator to execute function from the action space in the context of the gym.
@@ -157,7 +166,7 @@ async def decorator(*args, **kwargs):
157166 if match :
158167 info ["action_exec_timeout" ] = float (match .groups ()[0 ]) / 1000
159168
160- results = await asyncio .to_thread (gym .post_step , info )
169+ results = await asyncio .to_thread (gym .post_step , info , validate )
161170 return results
162171
163172 decorator .__wrapped__ = func # type: ignore
@@ -167,7 +176,8 @@ async def decorator(*args, **kwargs):
167176
168177
169178for fn in ACTION_SUBSETS [args .subset ]:
170- mcp .add_tool (fn_wrapper (fn ))
179+ validate = fn .__name__ in config .validate_actions
180+ mcp .add_tool (fn_wrapper (fn , validate ))
171181
172182if __name__ == "__main__" :
173183 mcp .run (transport = "stdio" )
0 commit comments