diff --git a/environments/mcp_envs/environments/http_mcp_env/README.md b/environments/mcp_envs/environments/http_mcp_env/README.md new file mode 100644 index 000000000..bc25bb4cd --- /dev/null +++ b/environments/mcp_envs/environments/http_mcp_env/README.md @@ -0,0 +1,51 @@ +# http-mcp-env + +> Replace the placeholders below, then remove this callout. + +### Overview +- **Environment ID**: `http-mcp-env` +- **Short description**: +- **Tags**: + +### Datasets +- **Primary dataset(s)**: +- **Source links**: +- **Split sizes**: + +### Task +- **Type**: +- **Parser**: +- **Rubric overview**: + +### Quickstart +Run an evaluation with default settings: + +```bash +uv run vf-eval http-mcp-env +``` + +Configure model and sampling: + +```bash +uv run vf-eval http-mcp-env -m gpt-4.1-mini -n 20 -r 3 -t 1024 -T 0.7 -a '{"key": "value"}' # env-specific args as JSON +``` + +Notes: +- Use `-a` / `--env-args` to pass environment-specific configuration as a JSON object. + +### Environment Arguments +Document any supported environment arguments and their meaning. Example: + +| Arg | Type | Default | Description | +| --- | ---- | ------- | ----------- | +| `foo` | str | `"bar"` | What this controls | +| `max_examples` | int | `-1` | Limit on dataset size (use -1 for all) | + +### Metrics +Summarize key metrics your rubric emits and how they’re interpreted. 
+ +| Metric | Meaning | +| ------ | ------- | +| `reward` | Main scalar reward (weighted sum of criteria) | +| `accuracy` | Exact match on target answer | + diff --git a/environments/mcp_envs/environments/http_mcp_env/http_mcp_env.py b/environments/mcp_envs/environments/http_mcp_env/http_mcp_env.py new file mode 100644 index 000000000..a6196a402 --- /dev/null +++ b/environments/mcp_envs/environments/http_mcp_env/http_mcp_env.py @@ -0,0 +1,59 @@ +import verifiers as vf +from datasets import Dataset +import os +from dotenv import load_dotenv +from urllib.parse import urlencode + +load_dotenv() + + +def get_remote_url(): + API_KEY = os.getenv("SMITHERY_API_KEY") + PROFILE = os.getenv("SMITHERY_PROFILE") + base_url = "https://server.smithery.ai/@smithery-ai/fetch/mcp" + params = {"api_key": API_KEY, "profile": PROFILE} + smithery_url = f"{base_url}?{urlencode(params)}" + return smithery_url + + +def load_environment(**kwargs): + remote_url = get_remote_url() + ds = Dataset.from_dict( + { + "question": [ + "Find out what Prime Intellect's newest announcement was from their website, give me the headline in 2 words. 
Their url is primeintellect.ai", + ], + "answer": ["ENVIRONMENTS HUB"], # Or whatever the actual top result is + } + ) + + rub = vf.JudgeRubric(judge_model="gpt-4.1-mini") + + async def judge_reward(judge, prompt, completion, answer, state): + judge_response = await judge(prompt, completion, answer, state) + return 1.0 if "yes" in judge_response.lower() else 0.0 + + rub.add_reward_func(judge_reward, weight=1.0) + + env = vf.MCPEnv( + mcp_servers=[ + { + "name": "fetch-mcp-server-http", + "command": "node", # Not used for HTTP transport, but required by MCPServerConfig + "args": [], + } + ], + transport_type="http", + http_urls={ + "fetch-mcp-server-http": remote_url + }, + connection_scope="session", # Reuse connection across rollouts + http_timeout=60.0, + http_max_retries=3, + dataset=ds, + rubric=rub, + max_turns=10, + **kwargs + ) + + return env diff --git a/environments/mcp_envs/environments/http_mcp_env/pyproject.toml b/environments/mcp_envs/environments/http_mcp_env/pyproject.toml new file mode 100644 index 000000000..f4b622cd0 --- /dev/null +++ b/environments/mcp_envs/environments/http_mcp_env/pyproject.toml @@ -0,0 +1,23 @@ +[project] +name = "http-mcp-env" +description = "Your environment description here" +tags = ["placeholder-tag", "train", "eval"] +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "verifiers", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["http_mcp_env.py", "pyproject.toml"] + +[tool.verifiers.eval] +num_examples = 5 +rollouts_per_example = 3 + +[tool.uv.sources] +verifiers = { path = "../../../../" } diff --git a/environments/mcp_envs/environments/sandbox_mcp_env/README.md b/environments/mcp_envs/environments/sandbox_mcp_env/README.md new file mode 100644 index 000000000..99a309c41 --- /dev/null +++ b/environments/mcp_envs/environments/sandbox_mcp_env/README.md @@ -0,0 +1,57 @@ +# sandbox-mcp-env + +1. 
Define the MCP server or servers you want to use +2. Setup state will + 1. Start by creating a sandbox for the rollout and exposing a port + 2. Then create transport(s) for the mcp servers which provide the interface for using the server + 3. It will run any necessary commands required for the mcp server + 4. Run the server in StreamableHTTP mode + 5. Finally register the MCP server's available tools +3. Rollout proceeds and agent can make mcp tool calls that are safe to interact within the sandbox + +### Overview +- **Environment ID**: `sandbox-mcp-env` +- **Short description**: MCPEnv via sandboxed streaming http MCP servers +- **Tags**: mcp, sandbox + +### Datasets +- **Primary dataset(s)**: NA +- **Source links**: NA +- **Split sizes**: NA + +### Task +- **Type**: tool use +- **Parser**: NA +- **Rubric overview**: NA + +### Quickstart +Run an evaluation with default settings: + +```bash +uv run vf-eval sandbox-mcp-env -n 1 -r 1 +``` + +Configure model and sampling: + +```bash +uv run vf-eval sandbox-mcp-env -m gpt-4.1-mini -n 20 -r 3 -t 1024 -T 0.7 -a '{"key": "value"}' # env-specific args as JSON +``` + +Notes: +- Use `-a` / `--env-args` to pass environment-specific configuration as a JSON object. 
+ +### Environment Arguments +Demo + +| Arg | Type | Default | Description | +| --- | ---- | ------- | ----------- | +| `max_examples` | int | `-1` | Limit on dataset size (use -1 for all) | + +### Metrics +Demo + +| Metric | Meaning | +| ------ | ------- | +| `reward` | Main scalar reward (weighted sum of criteria) | +| `accuracy` | Exact match on target answer | + diff --git a/environments/mcp_envs/environments/sandbox_mcp_env/pyproject.toml b/environments/mcp_envs/environments/sandbox_mcp_env/pyproject.toml new file mode 100644 index 000000000..9236e2a1f --- /dev/null +++ b/environments/mcp_envs/environments/sandbox_mcp_env/pyproject.toml @@ -0,0 +1,24 @@ +[project] +name = "sandbox-mcp-env" +description = "Your environment description here" +tags = ["placeholder-tag", "train", "eval"] +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "verifiers>=0.1.7.post0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["sandbox_mcp_env.py", "pyproject.toml"] + +[tool.verifiers.eval] +num_examples = 5 +rollouts_per_example = 3 + +[tool.uv.sources] +verifiers = { path = "../../../.." 
} + diff --git a/environments/mcp_envs/environments/sandbox_mcp_env/sandbox_mcp_env.py b/environments/mcp_envs/environments/sandbox_mcp_env/sandbox_mcp_env.py new file mode 100644 index 000000000..e4e8a81b9 --- /dev/null +++ b/environments/mcp_envs/environments/sandbox_mcp_env/sandbox_mcp_env.py @@ -0,0 +1,63 @@ +import verifiers as vf +from verifiers.envs.mcp.mcp_utils.models import MCPServerConfig +from datasets import Dataset + + +def load_environment(**kwargs): + ds = Dataset.from_dict( + { + "question": [ + "Check out what tools are available and try one that looks interesting to you", + ], + "answer": ["Hello World"], + } + ) + + rubric = vf.JudgeRubric(judge_model="gpt-4.1-mini") + + async def judge_reward(judge, prompt, completion, answer, state): + judge_response = await judge(prompt, completion, answer, state) + return 1.0 if "yes" in judge_response.lower() else 0.0 + + rubric.add_reward_func(judge_reward, weight=1.0) + + env = vf.MCPEnv( + mcp_servers=[ + MCPServerConfig( + name="everything-mcp", + command="npx", + args=[ + "@modelcontextprotocol/server-everything", + "streamableHttp", + ], + env={ + "PORT": "8000", + }, + setup_commands=[ + "apt update", + "apt upgrade -y", + "apt install -y git curl", + "curl -fsSL https://deb.nodesource.com/setup_lts.x | bash -", + "apt-get install -y nodejs", + "npm install -g @modelcontextprotocol/server-everything@latest", + ], + ) + ], + transport_type="sandbox", + connection_scope="rollout", # Each rollout gets its own sandbox + # Sandbox configuration + sandbox_image="python:3.11-slim", + sandbox_start_command="tail -f /dev/null", + sandbox_cpu_cores=1, + sandbox_memory_gb=2, + sandbox_disk_size_gb=5, + sandbox_timeout_minutes=15, + sandbox_port_to_expose=8000, # Port the MCP server listens on + # Standard env options + dataset=ds, + rubric=rubric, + max_turns=10, + **kwargs + ) + + return env diff --git a/environments/mcp_envs/environments/stdio_mcp_env/README.md 
b/environments/mcp_envs/environments/stdio_mcp_env/README.md new file mode 100644 index 000000000..8118e09a0 --- /dev/null +++ b/environments/mcp_envs/environments/stdio_mcp_env/README.md @@ -0,0 +1,51 @@ +# stdio-mcp-env + +> Replace the placeholders below, then remove this callout. + +### Overview +- **Environment ID**: `stdio-mcp-env` +- **Short description**: +- **Tags**: + +### Datasets +- **Primary dataset(s)**: +- **Source links**: +- **Split sizes**: + +### Task +- **Type**: +- **Parser**: +- **Rubric overview**: + +### Quickstart +Run an evaluation with default settings: + +```bash +uv run vf-eval stdio-mcp-env +``` + +Configure model and sampling: + +```bash +uv run vf-eval stdio-mcp-env -m gpt-4.1-mini -n 20 -r 3 -t 1024 -T 0.7 -a '{"key": "value"}' # env-specific args as JSON +``` + +Notes: +- Use `-a` / `--env-args` to pass environment-specific configuration as a JSON object. + +### Environment Arguments +Document any supported environment arguments and their meaning. Example: + +| Arg | Type | Default | Description | +| --- | ---- | ------- | ----------- | +| `foo` | str | `"bar"` | What this controls | +| `max_examples` | int | `-1` | Limit on dataset size (use -1 for all) | + +### Metrics +Summarize key metrics your rubric emits and how they’re interpreted. 
+ +| Metric | Meaning | +| ------ | ------- | +| `reward` | Main scalar reward (weighted sum of criteria) | +| `accuracy` | Exact match on target answer | + diff --git a/environments/mcp_envs/environments/stdio_mcp_env/pyproject.toml b/environments/mcp_envs/environments/stdio_mcp_env/pyproject.toml new file mode 100644 index 000000000..9c4d5ff81 --- /dev/null +++ b/environments/mcp_envs/environments/stdio_mcp_env/pyproject.toml @@ -0,0 +1,23 @@ +[project] +name = "stdio-mcp-env" +description = "Your environment description here" +tags = ["placeholder-tag", "train", "eval"] +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "verifiers>=0.1.7.post0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +include = ["stdio_mcp_env.py", "pyproject.toml"] + +[tool.verifiers.eval] +num_examples = 5 +rollouts_per_example = 3 + +[tool.uv.sources] +verifiers = { path = "../../../../" } diff --git a/environments/mcp_envs/environments/stdio_mcp_env/stdio_mcp_env.py b/environments/mcp_envs/environments/stdio_mcp_env/stdio_mcp_env.py new file mode 100644 index 000000000..f1f7e7c6f --- /dev/null +++ b/environments/mcp_envs/environments/stdio_mcp_env/stdio_mcp_env.py @@ -0,0 +1,31 @@ +import verifiers as vf +from datasets import Dataset + +def load_environment(**kwargs): + ds = Dataset.from_dict( + { + "question": [ + "Find out what Prime Intellect's newest announcement was from their website, give me the headline in 2 words. 
Their url is primeintellect.ai", + ], + "answer": ["ENVIRONMENTS HUB"], + } + ) + + rub = vf.JudgeRubric(judge_model="gpt-4.1-mini") + + async def judge_reward(judge, prompt, completion, answer, state): + judge_response = await judge(prompt, completion, answer, state) + return 1.0 if "yes" in judge_response.lower() else 0.0 + + env = vf.MCPEnv( + mcp_servers=[ + {"name": "fetch", "command": "uvx", "args": ["mcp-server-fetch"]} + ], + transport_type="stdio", + connection_scope="rollout", + dataset=ds, + rubric=rub, + max_turns=10 + ) + + return env diff --git a/public/mcp-envs.png b/public/mcp-envs.png new file mode 100644 index 000000000..bdd64cfe7 Binary files /dev/null and b/public/mcp-envs.png differ diff --git a/verifiers/__init__.py b/verifiers/__init__.py index bc73f058b..84fc50a65 100644 --- a/verifiers/__init__.py +++ b/verifiers/__init__.py @@ -12,6 +12,7 @@ from .envs.singleturn_env import SingleTurnEnv from .envs.stateful_tool_env import StatefulToolEnv from .envs.tool_env import ToolEnv +from .envs.mcp_env import MCPEnv from .parsers.maybe_think_parser import MaybeThinkParser from .parsers.parser import Parser from .parsers.think_parser import ThinkParser @@ -85,6 +86,7 @@ def setup_logging( "SandboxEnv", "StatefulToolEnv", "ToolEnv", + "MCPEnv", "EnvGroup", "extract_boxed_answer", "extract_hash_answer", @@ -116,6 +118,7 @@ def setup_logging( "PythonEnv": "verifiers.envs.python_env:PythonEnv", "ReasoningGymEnv": "verifiers.envs.reasoninggym_env:ReasoningGymEnv", "TextArenaEnv": "verifiers.envs.textarena_env:TextArenaEnv", + "MCPEnv": "verifiers.envs.mcp_env:MCPEnv", } @@ -137,6 +140,7 @@ def __getattr__(name: str): from .envs.reasoninggym_env import ReasoningGymEnv # noqa: F401 from .envs.sandbox_env import SandboxEnv # noqa: F401 from .envs.textarena_env import TextArenaEnv # noqa: F401 + from .envs.mcp_env import MCPEnv # noqa: F401 from .rl.trainer import ( # noqa: F401 GRPOConfig, GRPOTrainer, diff --git a/verifiers/envs/mcp_env.py 
b/verifiers/envs/mcp_env.py new file mode 100644 index 000000000..f2300d3a4 --- /dev/null +++ b/verifiers/envs/mcp_env.py @@ -0,0 +1,190 @@ +from typing import Callable, Dict, List, Literal, Optional + +from verifiers.utils.mcp_utils.transports.base import MCPTransport +from verifiers.utils.mcp_utils.models import MCPServerConfig, MCPTransportConfig +from verifiers.utils.mcp_utils.mcp_tool_wrapper import MCPToolWrapper +from verifiers.utils.mcp_utils.mcp_env_utils import validate_config, create_transport + +import verifiers as vf +from verifiers import Messages, State + +class MCPEnv(vf.StatefulToolEnv): + def __init__( + self, + mcp_servers: List[MCPServerConfig | dict], + # transport configuration + transport_type: Literal["stdio", "http", "sandbox"] = "stdio", + # use new connections for each rollout or keep connections alive for the entire session + connection_scope: Literal["rollout", "session"] = "rollout", + # http specific + http_urls: Optional[Dict[str, str]] = None, + http_timeout: float = 30.0, + http_max_retries: int = 3, + # sandbox specific + sandbox_image: str = "python:3.11-slim", + sandbox_start_command: str = "tail -f /dev/null", + sandbox_cpu_cores: int = 1, + sandbox_memory_gb: int = 2, + sandbox_disk_size_gb: int = 5, + sandbox_timeout_minutes: int = 60, + sandbox_environment_vars: Optional[Dict[str, str]] = None, + sandbox_port_to_expose: Optional[int] = None, + # standard options + max_turns: int = 10, + error_formatter: Callable[[Exception], str] = lambda e: f"Error: {str(e)}", + **kwargs + ): + self.mcp_servers = [ + MCPServerConfig(**s) if isinstance(s, dict) else s + for s in mcp_servers + ] + + self.transport_type = transport_type + self.connection_scope = connection_scope + self.http_urls = http_urls or {} + self.http_timeout = http_timeout + self.http_max_retries = http_max_retries + + self.sandbox_image = sandbox_image + self.sandbox_start_command = sandbox_start_command + self.sandbox_cpu_cores = sandbox_cpu_cores + 
self.sandbox_memory_gb = sandbox_memory_gb + self.sandbox_disk_size_gb = sandbox_disk_size_gb + self.sandbox_timeout_minutes = sandbox_timeout_minutes + self.sandbox_environment_vars = sandbox_environment_vars or {} + self.sandbox_port_to_expose = sandbox_port_to_expose + + validate_config(transport_type, self.mcp_servers, self.connection_scope) + + self.session_transports: Dict[str, MCPTransport] = {} + + super().__init__( + tools=[], + max_turns=max_turns, + error_formatter=error_formatter, + **kwargs + ) + + async def register_tools_from_transport( + self, + server_name: str, + transport: MCPTransport + ): + for tool in transport.tools.values(): + wrapper = MCPToolWrapper(server_name, tool, transport) + self.tools.append(wrapper) + oai_tool = wrapper.to_oai_tool() + if self.oai_tools is None: + self.oai_tools = [] + self.oai_tools.append(oai_tool) + tool_name = wrapper.__name__ + self.tool_map[tool_name] = wrapper + if not hasattr(self, 'skipped_args'): + self.skipped_args = {} + self.skipped_args[tool_name] = [] + + async def setup_session_connections(self): + for server_config in self.mcp_servers: + transport_config = MCPTransportConfig( + server_config=server_config, + transport_type=self.transport_type, + http_urls=self.http_urls, + http_timeout=self.http_timeout, + http_max_retries=self.http_max_retries, + sandbox_image=self.sandbox_image, + sandbox_start_command=self.sandbox_start_command, + sandbox_cpu_cores=self.sandbox_cpu_cores, + sandbox_memory_gb=self.sandbox_memory_gb, + sandbox_disk_size_gb=self.sandbox_disk_size_gb, + sandbox_timeout_minutes=self.sandbox_timeout_minutes, + sandbox_environment_vars=self.sandbox_environment_vars, + sandbox_port_to_expose=self.sandbox_port_to_expose, + ) + transport = await create_transport(transport_config) + if self.transport_type == "sandbox": + await transport.create_sandbox() + await transport.run_setup_commands() + await transport.start_mcp_server() + await transport.expose_port() + + await transport.connect() 
+ self.session_transports[server_config.name] = transport + await self.register_tools_from_transport(server_config.name, transport) + + async def setup_state(self, state: State, **kwargs) -> State: + state = await super().setup_state(state, **kwargs) + + if self.connection_scope == "session": + if not self.session_transports: + await self.setup_session_connections() + + elif self.connection_scope == "rollout": + rollout_transports = {} + + for server_config in self.mcp_servers: + transport_config = MCPTransportConfig( + server_config=server_config, + transport_type=self.transport_type, + http_urls=self.http_urls, + http_timeout=self.http_timeout, + http_max_retries=self.http_max_retries, + sandbox_image=self.sandbox_image, + sandbox_start_command=self.sandbox_start_command, + sandbox_cpu_cores=self.sandbox_cpu_cores, + sandbox_memory_gb=self.sandbox_memory_gb, + sandbox_disk_size_gb=self.sandbox_disk_size_gb, + sandbox_timeout_minutes=self.sandbox_timeout_minutes, + sandbox_environment_vars=self.sandbox_environment_vars, + sandbox_port_to_expose=self.sandbox_port_to_expose, + ) + transport = await create_transport(transport_config) + + if self.transport_type == "sandbox": + await transport.create_sandbox() + await transport.run_setup_commands() + await transport.start_mcp_server() + await transport.expose_port() + + await transport.connect() + rollout_transports[server_config.name] = transport + await self.register_tools_from_transport(server_config.name, transport) + + state["mcp_transports"] = rollout_transports + + if self.oai_tools: + state["info"]["oai_tools"] = self.oai_tools + + return state + + def update_tool_args( + self, + tool_name: str, + tool_args: dict, + messages: Messages, + state: State, + **kwargs, + ) -> dict: + return tool_args + + async def is_completed( + self, + messages: Messages, + state: State, + **kwargs + ) -> bool: + completed = await super().is_completed(messages, state, **kwargs) + if completed and self.connection_scope == "rollout": 
+ await self._cleanup_rollout_transports(state) + + return completed + + async def _cleanup_rollout_transports(self, state: State): + rollout_transports = state.get("mcp_transports", {}) + for transport in rollout_transports.values(): + await transport.disconnect() + state.pop("mcp_transports", None) + + async def cleanup(self): + for transport in self.session_transports.values(): + await transport.disconnect() + self.session_transports.clear() diff --git a/verifiers/utils/mcp_utils/README.md b/verifiers/utils/mcp_utils/README.md new file mode 100644 index 000000000..03daf4b65 --- /dev/null +++ b/verifiers/utils/mcp_utils/README.md @@ -0,0 +1,34 @@ +# Verifiers MCP Environments + +![MCP Environment Overview](../../../public/mcp-envs.png) + +The MCP Env abstraction aims to allow easily implementing Verifiers environments that use MCP servers and their corresponding tools instead of manually implementing tools. MCP itself supports a few different paths for server builders depending on their use case and the MCP env should enable many of these setups. + +The transports that MCP supports include: + +- stdio +- http +- sse (deprecated) + +## Scenarios + +- stdio + - Small scale, stateless, local use cases can be done easily with the stdio transport. Stdio is meant to be a local communication so server and client are run locally with the server being run on a background process. +- http + - Any scale, stateless, remote use cases in which the MCP you want to use is provided via some MCP server provider. In this case instead of running the server via a command, you just have to provide the URL and all rollouts can make their connections to the remote server. + - example: instead of running your own MCP Server you can use a provider like smithery or many companies run their own remote MCP servers that you can connect to. 
+- sandbox + - Any scale, stateful use cases are enabled by the "sandbox" transport option which relies on the Streaming HTTP transport communication in which each sandbox will run its own MCP HTTP Server and each rollout can connect to their corresponding MCP via URL. + - example: say you have a sandbox setup that includes setting up some filesystem or database. using the sandbox transport will allow each rollout to have its own sandbox it can connect to via MCP HTTP Server and perform actions that change the state of the sandbox. + + +## Concerns + +Note that MCP-based environment implementations are reliant on the MCP Server developer's implementation. Some MCP developers have provided ways to run their server with the different methods, some only provide a method to run stdio via uvx/node, and some don't provide either but might require something like cloning a repo and running a file manually. These all affect whether or not they will work as-is in an environment implementation, so it is something to think about. If you are the MCP server developer yourself obviously you have the control to enable any setup you choose. + +When it comes to running servers via HTTP it is also up to you to handle any authentication concerns as a running server may be available to anyone to connect to. + + +## TODO + +- a stdio-http bridge that will provide a wrapper function in case an MCP server doesn't provide an easy way to run the server in HTTP mode, you can still use the stdio version that is then made available by an HTTP wrapper. 
diff --git a/verifiers/utils/mcp_utils/mcp_env_utils.py b/verifiers/utils/mcp_utils/mcp_env_utils.py new file mode 100644 index 000000000..b7361bc09 --- /dev/null +++ b/verifiers/utils/mcp_utils/mcp_env_utils.py @@ -0,0 +1,79 @@ +from typing import Optional, Union +from verifiers.envs.mcp.mcp_utils.models import MCPServerConfig, MCPTransportConfig +from verifiers.envs.mcp.transports.stdio import StdioTransport +from verifiers.envs.mcp.transports.streaming_http import StreamingHTTPTransport +from verifiers.envs.mcp.transports.sandbox import SandboxTransport + +async def create_transport( + transport_config: MCPTransportConfig, +) -> Union[StdioTransport, StreamingHTTPTransport, SandboxTransport]: + if transport_config.transport_type == "stdio": + if not transport_config.server_config.command: + raise ValueError(f"'command' required for stdio transport: {transport_config.server_config.name}") + return StdioTransport(transport_config.server_config) + + elif transport_config.transport_type == "http": + url = get_server_url(transport_config.server_config, transport_config.http_urls or {}) + return StreamingHTTPTransport( + transport_config.server_config, + url=url, + timeout=transport_config.http_timeout, + max_retries=transport_config.http_max_retries + ) + + elif transport_config.transport_type == "sandbox": + from prime_sandboxes import CreateSandboxRequest + if not transport_config.server_config.command: + raise ValueError(f"'command' required for sandbox transport: {transport_config.server_config.name}") + + env_vars = {**(transport_config.sandbox_environment_vars or {}), **(transport_config.server_config.env or {})} + + sandbox_request = CreateSandboxRequest( + name=f"mcp-{transport_config.server_config.name}", + docker_image=transport_config.sandbox_image, + start_command=transport_config.sandbox_start_command, + cpu_cores=transport_config.sandbox_cpu_cores, + memory_gb=transport_config.sandbox_memory_gb, + disk_size_gb=transport_config.sandbox_disk_size_gb, + 
timeout_minutes=transport_config.sandbox_timeout_minutes, + environment_vars=env_vars, + ) + + return SandboxTransport( + transport_config.server_config, + sandbox_request=sandbox_request, + port_to_expose=transport_config.sandbox_port_to_expose, + timeout=transport_config.http_timeout, + max_retries=transport_config.http_max_retries + ) + + else: + raise ValueError(f"Unknown transport type: {transport_config.transport_type}") + + +def get_server_url(server_config: MCPServerConfig, http_urls: dict) -> str: + """Get URL for a server, checking config.url first, then http_urls dict.""" + if server_config.url: + return server_config.url + if server_config.name in http_urls: + return http_urls[server_config.name] + raise ValueError( + f"No URL found for server '{server_config.name}'. " + f"Provide either 'url' in server config or add to 'http_urls' dict." + ) + +def validate_config(transport_type: str, servers: list[MCPServerConfig], connection_scope: str) -> None: + if transport_type == "stdio": + missing = [s.name for s in servers if not s.command] + if missing: + raise ValueError(f"'command' required for stdio. Missing: {missing}") + + elif transport_type == "http": + missing = [s.name for s in servers if not s.url] + if missing: + raise ValueError(f"'url' required for http. Missing: {missing}") + + elif transport_type == "sandbox": + missing = [s.name for s in servers if not s.command] + if missing: + raise ValueError(f"'command' required for sandbox. 
Missing: {missing}") \ No newline at end of file diff --git a/verifiers/utils/mcp_utils/mcp_tool_wrapper.py b/verifiers/utils/mcp_utils/mcp_tool_wrapper.py new file mode 100644 index 000000000..4685b9754 --- /dev/null +++ b/verifiers/utils/mcp_utils/mcp_tool_wrapper.py @@ -0,0 +1,99 @@ +from typing import Any +import copy +from mcp.types import Tool + +from verifiers.utils.mcp_utils.transports.base import MCPTransport + + +class MCPToolWrapper: + def __init__(self, server_name: str, tool: Tool, server_connection: MCPTransport): + self.server_name = server_name + self.tool = tool + self.server_connection = server_connection + + self.__name__ = tool.name + self.__doc__ = tool.description or "" + + self.__annotations__ = self._build_annotations() + + def _build_annotations(self) -> dict: + annotations = {} + + if self.tool.inputSchema: + properties = self.tool.inputSchema.get("properties", {}) + + for param_name, param_spec in properties.items(): + param_type = param_spec.get("type", "string") + if param_type == "string": + annotations[param_name] = str + elif param_type == "integer": + annotations[param_name] = int + elif param_type == "number": + annotations[param_name] = float + elif param_type == "boolean": + annotations[param_name] = bool + elif param_type == "array": + annotations[param_name] = list + elif param_type == "object": + annotations[param_name] = dict + else: + annotations[param_name] = Any + + annotations["return"] = str + return annotations + + async def __call__(self, **kwargs): + return await self.server_connection.call_tool(self.tool.name, kwargs) + + def _remove_additional_properties(self, schema: dict) -> dict: + """ + Recursively remove additionalProperties from schema to comply with OpenAI strict mode. 
+ """ + if not isinstance(schema, dict): + return schema + + # Create a copy to avoid modifying the original + schema = dict(schema) + + # Remove additionalProperties at this level + schema.pop("additionalProperties", None) + + # Recursively process nested objects + if "properties" in schema and isinstance(schema["properties"], dict): + schema["properties"] = { + key: self._remove_additional_properties(value) + for key, value in schema["properties"].items() + } + + # Handle arrays + if "items" in schema: + schema["items"] = self._remove_additional_properties(schema["items"]) + + # Handle anyOf, oneOf, allOf + for key in ["anyOf", "oneOf", "allOf"]: + if key in schema and isinstance(schema[key], list): + schema[key] = [ + self._remove_additional_properties(sub_schema) + for sub_schema in schema[key] + ] + + return schema + + def to_oai_tool(self) -> dict: + # Get the input schema and ensure it's OpenAI-compatible + parameters = self.tool.inputSchema or {"type": "object", "properties": {}} + + # Deep copy to avoid modifying the original + parameters = copy.deepcopy(parameters) + + # Remove additionalProperties to comply with OpenAI strict schema + parameters = self._remove_additional_properties(parameters) + + return { + "type": "function", + "function": { + "name": self.__name__, + "description": self.__doc__ or "", + "parameters": parameters, + }, + } diff --git a/verifiers/utils/mcp_utils/models.py b/verifiers/utils/mcp_utils/models.py new file mode 100644 index 000000000..daab86758 --- /dev/null +++ b/verifiers/utils/mcp_utils/models.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass, field +from typing import Dict, List, Literal, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from prime_sandboxes import AsyncSandboxClient + + +@dataclass +class MCPServerConfig: + name: str + command: Optional[str] = None + args: Optional[List[str]] = None + env: Optional[Dict[str, str]] = None + url: Optional[str] = None # For HTTP transport - can specify here instead of 
http_urls dict + setup_commands: Optional[List[str]] = None # Commands to run before starting the server (for sandbox transport) + + +@dataclass +class MCPTransportConfig: + server_config: MCPServerConfig + transport_type: Literal["stdio", "http", "sandbox"] + # http specific + http_urls: Optional[Dict[str, str]] = None + http_timeout: float = 30.0 + http_max_retries: int = 3 + # sandbox specific + sandbox_client: Optional["AsyncSandboxClient"] = None + sandbox_image: str = "python:3.11-slim" + sandbox_start_command: str = "tail -f /dev/null" + sandbox_environment_vars: Optional[Dict[str, str]] = None + sandbox_cpu_cores: int = 1 + sandbox_memory_gb: int = 2 + sandbox_disk_size_gb: int = 5 + sandbox_timeout_minutes: int = 60 + sandbox_port_to_expose: Optional[int] = 8000 diff --git a/verifiers/utils/mcp_utils/transports/base.py b/verifiers/utils/mcp_utils/transports/base.py new file mode 100644 index 000000000..95d91b2ef --- /dev/null +++ b/verifiers/utils/mcp_utils/transports/base.py @@ -0,0 +1,28 @@ +from abc import ABC, abstractmethod +from typing import Dict +from mcp.types import Tool + +class MCPTransport(ABC): + """Base transport for MCP connections.""" + + @abstractmethod + async def connect(self) -> Dict[str, Tool]: + """Connect and return available tools.""" + pass + + @abstractmethod + async def call_tool(self, tool_name: str, arguments: dict) -> str: + """Execute a tool call""" + pass + + @abstractmethod + async def disconnect(self) -> None: + """Cleanup connection""" + pass + + @abstractmethod + async def is_connected(self) -> bool: + """Check connection status""" + pass + + diff --git a/verifiers/utils/mcp_utils/transports/sandbox.py b/verifiers/utils/mcp_utils/transports/sandbox.py new file mode 100644 index 000000000..06139d71f --- /dev/null +++ b/verifiers/utils/mcp_utils/transports/sandbox.py @@ -0,0 +1,115 @@ +import asyncio +import atexit +import signal +from typing import Dict, Optional +from mcp.types import Tool +from prime_sandboxes import 
SandboxClient, APIClient + +from verifiers.utils.mcp_utils.transports.streaming_http import StreamingHTTPTransport +from verifiers.utils.mcp_utils.models import MCPServerConfig + +_active_sandboxes = set() +_cleanup_registered = False + +def _register_cleanup(): + global _cleanup_registered + if _cleanup_registered: + return + + def cleanup(): + if not _active_sandboxes: + return + client = SandboxClient(APIClient()) + for sandbox_id in list(_active_sandboxes): + try: + client.delete(sandbox_id) + except Exception: + pass + _active_sandboxes.clear() + + atexit.register(cleanup) + signal.signal(signal.SIGINT, lambda s, f: (cleanup(), signal.default_int_handler(s, f))) + signal.signal(signal.SIGTERM, lambda s, f: (cleanup(), exit(143))) + _cleanup_registered = True + + + +class SandboxTransport(StreamingHTTPTransport): + _async_client: Optional["AsyncSandboxClient"] = None + + @classmethod + def get_client(cls) -> "AsyncSandboxClient": + if cls._async_client is None: + from prime_sandboxes import AsyncSandboxClient + cls._async_client = AsyncSandboxClient() + _register_cleanup() + return cls._async_client + + def __init__( + self, + config: MCPServerConfig, + sandbox_request, + port_to_expose: Optional[int] = None, + **kwargs + ): + self.sandbox_request = sandbox_request + self.sandbox_id: Optional[str] = None + self.port_to_expose: Optional[int] = port_to_expose + super().__init__(config, url="", **kwargs) + + async def create_sandbox(self) -> str: + client = self.get_client() + sandbox = await client.create(self.sandbox_request) + self.sandbox_id = sandbox.id + await client.wait_for_creation(self.sandbox_id) + _active_sandboxes.add(self.sandbox_id) + return self.sandbox_id + + async def run_command(self, command: str) -> str: + if not self.sandbox_id: + raise RuntimeError("Sandbox not created yet") + client = self.get_client() + result = await client.execute_command(self.sandbox_id, command) + return result + + async def run_setup_commands(self) -> None: + if not 
self.config.setup_commands: + return + for cmd in self.config.setup_commands: + await self.run_command(cmd) + + async def start_mcp_server(self) -> None: + """Start the MCP server process in the sandbox.""" + cmd = f"{self.config.command} {' '.join(self.config.args or [])}" + bg_cmd = f"nohup {cmd} > /tmp/mcp.log 2>&1 &" + await self.run_command(bg_cmd) + await asyncio.sleep(3) + + async def expose_port(self) -> str: + client = self.get_client() + exposed = await client.expose( + self.sandbox_id, + self.port_to_expose + ) + self.url = f"{exposed.url}/mcp" + # TODO: remove this when we have a better way to wait for the port to be exposed + await asyncio.sleep(10) + return self.url + + async def connect(self) -> Dict[str, Tool]: + return await super().connect() + + async def disconnect(self) -> None: + try: + await super().disconnect() + finally: + if self.sandbox_id: + client = self.get_client() + await client.delete(self.sandbox_id) + _active_sandboxes.discard(self.sandbox_id) + self.sandbox_id = None + + async def delete_all_sandboxes(self) -> None: + client = self.get_client() + await client.bulk_delete(list(_active_sandboxes)) + _active_sandboxes.clear() diff --git a/verifiers/utils/mcp_utils/transports/stdio.py b/verifiers/utils/mcp_utils/transports/stdio.py new file mode 100644 index 000000000..fd9a862f5 --- /dev/null +++ b/verifiers/utils/mcp_utils/transports/stdio.py @@ -0,0 +1,103 @@ +import asyncio +from typing import Dict, Optional + +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client +from mcp.types import TextContent, Tool + +from verifiers.utils.mcp_utils.transports.base import MCPTransport +from verifiers.utils.mcp_utils.models import MCPServerConfig + + +class StdioTransport(MCPTransport): + def __init__(self, config: MCPServerConfig): + """Initialize the stdio transport with MCP server configuration.""" + self.config = config + self.session: Optional[ClientSession] = None + self.tools: Dict[str, Tool] = 
{}
        # Background task that owns the stdio_client/ClientSession context
        # managers (they must enter and exit in the same task).
        self._connection_task: Optional[asyncio.Task] = None
        self._ready = asyncio.Event()
        self._error: Optional[Exception] = None

    async def connect(self) -> Dict[str, Tool]:
        """Connect by creating a background task that manages the MCP server lifecycle.

        Blocks until the background task signals readiness, then returns the
        discovered tools (or re-raises the connection error).
        """
        self._connection_task = asyncio.create_task(self._maintain_connection())
        await self._ready.wait()

        if self._error:
            raise self._error

        return self.tools

    async def _maintain_connection(self):
        """Background task that maintains the stdio connection."""
        try:
            server_params = StdioServerParameters(
                command=self.config.command,
                args=self.config.args or [],
                env=self.config.env,
            )

            # Context managers stay within this single task
            async with stdio_client(server_params) as (read, write):
                async with ClientSession(read, write) as session:
                    self.session = session
                    await session.initialize()

                    tools_response = await session.list_tools()
                    self.tools = {tool.name: tool for tool in tools_response.tools}

                    self._ready.set()

                    # Keep alive until cancelled
                    try:
                        while True:
                            await asyncio.sleep(1)
                    except asyncio.CancelledError:
                        pass  # Normal shutdown

        except asyncio.CancelledError:
            # NOTE(review): if cancellation lands before _ready.set(), a
            # concurrent connect() call would wait forever — confirm callers
            # only cancel via disconnect() after connect() returned.
            pass  # Normal cancellation during shutdown
        except Exception as e:
            self._error = e
            self._ready.set()
        finally:
            self.session = None
            self.tools = {}

    async def call_tool(self, tool_name: str, arguments: dict) -> str:
        """Execute a tool call through the MCP session.

        Returns the concatenated text of all TextContent items; non-text
        content is stringified.
        """
        if not self.session:
            raise RuntimeError(f"Server '{self.config.name}' not connected")

        result = await self.session.call_tool(tool_name, arguments=arguments)

        if result.content:
            text_parts = []
            for content_item in result.content:
                if isinstance(content_item, TextContent):
                    text_parts.append(content_item.text)
                else:
                    text_parts.append(str(content_item))
            return "\n".join(text_parts)

        return "No result returned from tool"

    async def disconnect(self) -> None:
        """Disconnect by cancelling the background connection task."""
        if self._connection_task and not self._connection_task.done():
            self._connection_task.cancel()
            try:
                await self._connection_task
            except asyncio.CancelledError:
                pass

        self.session = None
        self.tools = {}
        self._ready.clear()

    async def is_connected(self) -> bool:
        """Check if the transport is currently connected."""
        return self.session is not None and (
            self._connection_task is not None and not self._connection_task.done()
        )
diff --git a/verifiers/utils/mcp_utils/transports/streaming_http.py b/verifiers/utils/mcp_utils/transports/streaming_http.py
new file mode 100644
index 000000000..d2692eb70
--- /dev/null
+++ b/verifiers/utils/mcp_utils/transports/streaming_http.py
@@ -0,0 +1,117 @@
import asyncio
from typing import Dict, Optional

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import TextContent, Tool

from verifiers.utils.mcp_utils.transports.base import MCPTransport
from verifiers.utils.mcp_utils.models import MCPServerConfig


class StreamingHTTPTransport(MCPTransport):
    def __init__(
        self,
        config: MCPServerConfig,
        url: str,
        timeout: float = 30.0,
        max_retries: int = 3
    ):
        """Initialize the Streamable HTTP transport.

        Args:
            config: server description (used for its name in error messages)
            url: MCP endpoint URL
            timeout: per-request timeout in seconds
            max_retries: connection attempts before giving up
        """
        self.config = config
        self.url = url
        self.timeout = timeout
        self.max_retries = max_retries
        self.session: Optional[ClientSession] = None
        self.tools: Dict[str, Tool] = {}
        self._connection_task: Optional[asyncio.Task] = None
        self._ready = asyncio.Event()
        self._error: Optional[Exception] = None

    async def connect(self) -> Dict[str, Tool]:
        """Connect by creating a background task that manages the HTTP connection."""
        self._connection_task = asyncio.create_task(self._maintain_connection())
        await self._ready.wait()

        if self._error:
            raise self._error

        return self.tools

    async def _maintain_connection(self):
        """Background task that maintains the Streamable HTTP connection."""
        # NOTE(review): last_error is assigned but never read after the loop
        # (the final ConnectionError formats `e` directly) — candidate for
        # removal.
        last_error = None

        for attempt in range(self.max_retries):
            try:
                async with streamablehttp_client(self.url) as (read, write, _):
                    async with ClientSession(read, write) as session:
                        self.session = session
                        await session.initialize()

                        tools_response = await session.list_tools()
                        self.tools = {tool.name: tool for tool in tools_response.tools}

                        self._ready.set()

                        # Keep alive until cancelled or the stream drops.
                        while True:
                            await asyncio.sleep(1)

            except asyncio.CancelledError:
                # Normal cancellation during shutdown
                break
            except Exception as e:
                last_error = e
                if attempt < self.max_retries - 1:
                    # Linear backoff: 1s, 2s, 3s, ...
                    await asyncio.sleep(1.0 * (attempt + 1))
                    continue
                else:
                    # Final attempt failed
                    self._error = ConnectionError(
                        f"Failed to connect to MCP server at {self.url} after {self.max_retries} attempts. "
                        f"Ensure the server is running and accessible. Last error: {str(e)}"
                    )
                    self._ready.set()
                    break

        # Cleanup
        self.session = None
        self.tools = {}

    async def call_tool(self, tool_name: str, arguments: dict) -> str:
        """Execute a tool call through the MCP session.

        Returns the concatenated text of all TextContent items; non-text
        content is stringified.
        """
        if not self.session:
            raise RuntimeError(f"Server '{self.config.name}' not connected")

        result = await self.session.call_tool(tool_name, arguments=arguments)

        if result.content:
            text_parts = []
            for content_item in result.content:
                if isinstance(content_item, TextContent):
                    text_parts.append(content_item.text)
                else:
                    text_parts.append(str(content_item))
            return "\n".join(text_parts)

        return "No result returned from tool"

    async def disconnect(self) -> None:
        """Disconnect by cancelling the background connection task."""
        if self._connection_task and not self._connection_task.done():
            self._connection_task.cancel()
            try:
                await self._connection_task
            except asyncio.CancelledError:
                pass

        self.session = None
        self.tools = {}
        self._ready.clear()

    async def is_connected(self) -> bool:
        """Check if the transport is currently connected."""
        return
self.session is not None and ( + self._connection_task is not None and not self._connection_task.done() + ) diff --git a/verifiers/utils/mcp_utils/transports/synthetic_transport.py b/verifiers/utils/mcp_utils/transports/synthetic_transport.py new file mode 100644 index 000000000..e7615de63 --- /dev/null +++ b/verifiers/utils/mcp_utils/transports/synthetic_transport.py @@ -0,0 +1,397 @@ +""" +Synthetic MCP Transport for testing with fake backends. + +Instead of connecting to real MCP servers (like Zapier/Airtable), +this transport intercepts tool calls and routes them to in-memory handlers. +""" + +import json +from typing import Dict, Callable, Any, Optional +from dataclasses import dataclass, field +from mcp.types import Tool + +# Import the base transport interface +import sys +sys.path.insert(0, '/mnt/user-data/uploads') + +from verifiers.envs.mcp.transports.base import MCPTransport + + +class SyntheticTransport(MCPTransport): + """ + A mock transport that intercepts tool calls and routes them to synthetic handlers. + + Usage: + # Define your synthetic data + data = { + "candidates": [ + {"id": "1", "name": "Alice", "status": "active"}, + {"id": "2", "name": "Bob", "status": "interviewing"}, + ] + } + + # Define handlers for each tool + handlers = { + "search_candidates": lambda data, args: json.dumps(data["candidates"]), + "get_candidate": lambda data, args: json.dumps( + next((c for c in data["candidates"] if c["id"] == args["id"]), None) + ), + } + + # Create transport + transport = SyntheticTransport( + tools=create_tool_definitions(), + handlers=handlers, + data=data + ) + """ + + def __init__( + self, + tools: Dict[str, Tool], + handlers: Dict[str, Callable[[dict, dict], str]], + data: Optional[dict] = None, + name: str = "synthetic" + ): + """ + Args: + tools: Dict mapping tool names to MCP Tool definitions + handlers: Dict mapping tool names to handler functions. + Each handler receives (data, arguments) and returns a string. 
+ data: The synthetic data store that handlers can read/write + name: Name for this transport (for logging) + """ + self._tools = tools + self.handlers = handlers + self.data = data if data is not None else {} + self.name = name + self._connected = False + + @property + def tools(self) -> Dict[str, Tool]: + return self._tools + + async def connect(self) -> Dict[str, Tool]: + """No-op connect - we're already 'connected' to our in-memory data.""" + self._connected = True + return self._tools + + async def call_tool(self, tool_name: str, arguments: dict) -> str: + """Route the tool call to the appropriate handler.""" + if not self._connected: + raise RuntimeError(f"Transport '{self.name}' not connected") + + if tool_name not in self.handlers: + raise ValueError( + f"No handler registered for tool '{tool_name}'. " + f"Available handlers: {list(self.handlers.keys())}" + ) + + handler = self.handlers[tool_name] + + # Call handler - it can be sync or async + result = handler(self.data, arguments) + + # If handler is async, await it + if hasattr(result, '__await__'): + result = await result + + return str(result) + + async def disconnect(self) -> None: + """No-op disconnect.""" + self._connected = False + + async def is_connected(self) -> bool: + return self._connected + + +# ============================================================================= +# Helper to create Tool definitions programmatically +# ============================================================================= + +def create_tool( + name: str, + description: str, + parameters: Dict[str, Any], + required: Optional[list] = None +) -> Tool: + """ + Helper to create an MCP Tool definition. 

    Args:
        name: Tool name
        description: Tool description
        parameters: Dict of parameter_name -> {"type": ..., "description": ...}
        required: List of required parameter names
    """
    properties = {}
    for param_name, param_spec in parameters.items():
        if isinstance(param_spec, str):
            # Simple type string, e.g. "string" -> {"type": "string"}
            properties[param_name] = {"type": param_spec}
        else:
            # Full spec dict
            properties[param_name] = param_spec

    input_schema = {
        "type": "object",
        "properties": properties,
    }
    if required:
        input_schema["required"] = required

    return Tool(
        name=name,
        description=description,
        inputSchema=input_schema
    )


# =============================================================================
# Example: Synthetic Airtable-like backend
# =============================================================================

class SyntheticAirtable:
    """
    A synthetic Airtable-like data store with common operations.

    Usage:
        airtable = SyntheticAirtable()

        # Add some tables with data
        airtable.add_table("Candidates", [
            {"id": "rec1", "Name": "Alice", "Status": "Active", "Role": "Engineer"},
            {"id": "rec2", "Name": "Bob", "Status": "Interviewing", "Role": "Designer"},
        ])

        # Get tools and handlers for use with SyntheticTransport
        tools = airtable.get_tools()
        handlers = airtable.get_handlers()

        transport = SyntheticTransport(tools, handlers, airtable.data)
    """

    def __init__(self):
        self.data: Dict[str, list] = {}  # table_name -> list of records
        self._record_counter = 0  # monotonically increasing id suffix

    def add_table(self, table_name: str, records: list):
        """Add a table with initial records.

        NOTE: the caller's list and dicts are stored (and mutated) by
        reference — missing "id" keys are written into the caller's dicts.
        """
        # Ensure each record has an id
        for record in records:
            if "id" not in record:
                self._record_counter += 1
                record["id"] = f"rec{self._record_counter}"
        self.data[table_name] = records

    def get_tools(self) -> Dict[str, Tool]:
        """Return MCP Tool definitions for Airtable-like operations."""
        return {
            "list_records": create_tool(
                name="list_records",
                description="List all records in a table",
                parameters={
                    "table_name": {
                        "type": "string",
                        "description": "Name of the table to list records from"
                    },
                    "max_records": {
                        "type": "integer",
                        "description": "Maximum number of records to return"
                    }
                },
                required=["table_name"]
            ),
            "search_records": create_tool(
                name="search_records",
                description="Search for records matching a query string",
                parameters={
                    "table_name": {
                        "type": "string",
                        "description": "Name of the table to search"
                    },
                    "query": {
                        "type": "string",
                        "description": "Search query - matches against all fields"
                    },
                    "field": {
                        "type": "string",
                        "description": "Optional: specific field to search in"
                    }
                },
                required=["table_name", "query"]
            ),
            "get_record": create_tool(
                name="get_record",
                description="Get a specific record by ID",
                parameters={
                    "table_name": {
                        "type": "string",
                        "description": "Name of the table"
                    },
                    "record_id": {
                        "type": "string",
                        "description": "ID of the record to retrieve"
                    }
                },
                required=["table_name", "record_id"]
            ),
            "create_record": create_tool(
                name="create_record",
                description="Create a new record in a table",
                parameters={
                    "table_name": {
                        "type": "string",
                        "description": "Name of the table"
                    },
                    "fields": {
                        "type": "object",
                        "description": "Record fields as key-value pairs"
                    }
                },
                required=["table_name", "fields"]
            ),
            "update_record": create_tool(
                name="update_record",
                description="Update an existing record",
                parameters={
                    "table_name": {
                        "type": "string",
                        "description": "Name of the table"
                    },
                    "record_id": {
                        "type": "string",
                        "description": "ID of the record to update"
                    },
                    "fields": {
                        "type": "object",
                        "description": "Fields to update as key-value pairs"
                    }
                },
                required=["table_name", "record_id", "fields"]
            ),
            "count_records": create_tool(
                name="count_records",
                description="Count records in a table, optionally matching a
query", + parameters={ + "table_name": { + "type": "string", + "description": "Name of the table" + }, + "query": { + "type": "string", + "description": "Optional: only count records matching this query" + } + }, + required=["table_name"] + ), + } + + def get_handlers(self) -> Dict[str, Callable]: + """Return handler functions for each tool.""" + return { + "list_records": self._handle_list_records, + "search_records": self._handle_search_records, + "get_record": self._handle_get_record, + "create_record": self._handle_create_record, + "update_record": self._handle_update_record, + "count_records": self._handle_count_records, + } + + def _handle_list_records(self, data: dict, args: dict) -> str: + table_name = args.get("table_name") + max_records = args.get("max_records") + + if table_name not in data: + return json.dumps({"error": f"Table '{table_name}' not found"}) + + records = data[table_name] + if max_records: + records = records[:max_records] + + return json.dumps({"records": records, "total": len(data[table_name])}) + + def _handle_search_records(self, data: dict, args: dict) -> str: + table_name = args.get("table_name") + query = args.get("query", "").lower() + field = args.get("field") + + if table_name not in data: + return json.dumps({"error": f"Table '{table_name}' not found"}) + + results = [] + for record in data[table_name]: + if field: + # Search specific field + value = str(record.get(field, "")).lower() + if query in value: + results.append(record) + else: + # Search all fields + for value in record.values(): + if query in str(value).lower(): + results.append(record) + break + + return json.dumps({"records": results, "count": len(results)}) + + def _handle_get_record(self, data: dict, args: dict) -> str: + table_name = args.get("table_name") + record_id = args.get("record_id") + + if table_name not in data: + return json.dumps({"error": f"Table '{table_name}' not found"}) + + for record in data[table_name]: + if record.get("id") == record_id: + 
return json.dumps({"record": record}) + + return json.dumps({"error": f"Record '{record_id}' not found"}) + + def _handle_create_record(self, data: dict, args: dict) -> str: + table_name = args.get("table_name") + fields = args.get("fields", {}) + + if table_name not in data: + return json.dumps({"error": f"Table '{table_name}' not found"}) + + self._record_counter += 1 + new_record = {"id": f"rec{self._record_counter}", **fields} + data[table_name].append(new_record) + + return json.dumps({"record": new_record, "created": True}) + + def _handle_update_record(self, data: dict, args: dict) -> str: + table_name = args.get("table_name") + record_id = args.get("record_id") + fields = args.get("fields", {}) + + if table_name not in data: + return json.dumps({"error": f"Table '{table_name}' not found"}) + + for record in data[table_name]: + if record.get("id") == record_id: + record.update(fields) + return json.dumps({"record": record, "updated": True}) + + return json.dumps({"error": f"Record '{record_id}' not found"}) + + def _handle_count_records(self, data: dict, args: dict) -> str: + table_name = args.get("table_name") + query = args.get("query", "").lower() if args.get("query") else None + + if table_name not in data: + return json.dumps({"error": f"Table '{table_name}' not found"}) + + if not query: + count = len(data[table_name]) + else: + count = 0 + for record in data[table_name]: + for value in record.values(): + if query in str(value).lower(): + count += 1 + break + + return json.dumps({"count": count})