Skip to content

Commit f9f5465

Browse files
committed
Add client_provided_arg_model to FuncMetadata
1 parent babb477 commit f9f5465

File tree

5 files changed

+171
-12
lines changed

5 files changed

+171
-12
lines changed

src/mcp/server/fastmcp/tools/base.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
from pydantic import BaseModel, Field
88

99
from mcp.server.fastmcp.exceptions import ToolError
10-
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
10+
from mcp.server.fastmcp.utilities.func_metadata import (
11+
FuncMetadata,
12+
filter_args_by_arg_model,
13+
func_metadata,
14+
)
1115

1216
if TYPE_CHECKING:
1317
from mcp.server.fastmcp.server import Context
@@ -85,10 +89,15 @@ async def run(
8589
return await self.fn_metadata.call_fn_with_arg_validation(
8690
self.fn,
8791
self.is_async,
88-
arguments,
89-
{self.context_kwarg: context}
90-
if self.context_kwarg is not None
91-
else None,
92+
filter_args_by_arg_model(arguments, self.fn_metadata.arg_model),
93+
filter_args_by_arg_model(
94+
arguments, self.fn_metadata.client_provided_arg_model
95+
)
96+
| (
97+
{self.context_kwarg: context}
98+
if self.context_kwarg is not None
99+
else {}
100+
),
92101
)
93102
except Exception as e:
94103
raise ToolError(f"Error executing tool {self.name}: {e}") from e

src/mcp/server/fastmcp/utilities/func_metadata.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@
1818
logger = get_logger(__name__)
1919

2020

21+
class ClientProvidedArg:
22+
"""A class to annotate an argument that is to be provided by client at call
23+
time and to be skipped from JSON schema generation."""
24+
25+
def __init__(self):
26+
pass
27+
28+
2129
class ArgModelBase(BaseModel):
2230
"""A model representing the arguments to a function."""
2331

@@ -36,8 +44,42 @@ def model_dump_one_level(self) -> dict[str, Any]:
3644
)
3745

3846

47+
def filter_args_by_arg_model(
48+
arguments: dict[str, Any], model_filter: type[ArgModelBase] | None = None
49+
) -> dict[str, Any]:
50+
"""Filter the arguments dictionary to only include keys that are present in
51+
`model_filter`."""
52+
if not model_filter:
53+
return arguments
54+
filtered_args: dict[str, Any] = {}
55+
for key in arguments.keys():
56+
if key in model_filter.model_fields.keys():
57+
filtered_args[key] = arguments[key]
58+
return filtered_args
59+
60+
3961
class FuncMetadata(BaseModel):
62+
"""Metadata about a function, including Pydantic models for argument validation.
63+
64+
This class manages the arguments required by a function, separating them into two
65+
categories:
66+
67+
* `arg_model`: A Pydantic model representing the function's standard arguments.
68+
These arguments will be included in the JSON schema when the tool is listed,
69+
allowing for automatic argument parsing. This defines the structure of the
70+
expected input.
71+
72+
* `client_provided_arg_model` (Optional): A Pydantic model representing arguments
73+
that need to be provided directly by the client and will not be included in the
74+
JSON schema.
75+
76+
"""
77+
4078
arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
79+
client_provided_arg_model: (
80+
Annotated[type[ArgModelBase], WithJsonSchema(None)] | None
81+
) = None
82+
4183
# We can add things in the future like
4284
# - Maybe some args are excluded from attempting to parse from JSON
4385
# - Maybe some args are special (like context) for dependency injection
@@ -127,7 +169,8 @@ def func_metadata(
127169
"""
128170
sig = _get_typed_signature(func)
129171
params = sig.parameters
130-
dynamic_pydantic_model_params: dict[str, Any] = {}
172+
dynamic_pydantic_arg_model_params: dict[str, Any] = {}
173+
dynamic_pydantic_client_provided_arg_model_params: dict[str, Any] = {}
131174
globalns = getattr(func, "__globals__", {})
132175
for param in params.values():
133176
if param.name.startswith("_"):
@@ -164,15 +207,34 @@ def func_metadata(
164207
if param.default is not inspect.Parameter.empty
165208
else PydanticUndefined,
166209
)
167-
dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)
168-
continue
210+
211+
# loop through annotations,
212+
# use ClientProvidedArg metadata to split the arguments
213+
if any(isinstance(m, ClientProvidedArg) for m in field_info.metadata):
214+
dynamic_pydantic_client_provided_arg_model_params[param.name] = (
215+
field_info.annotation,
216+
field_info,
217+
)
218+
else:
219+
dynamic_pydantic_arg_model_params[param.name] = (
220+
field_info.annotation,
221+
field_info,
222+
)
169223

170224
arguments_model = create_model(
171225
f"{func.__name__}Arguments",
172-
**dynamic_pydantic_model_params,
226+
**dynamic_pydantic_arg_model_params,
173227
__base__=ArgModelBase,
174228
)
175-
resp = FuncMetadata(arg_model=arguments_model)
229+
230+
provided_arguments_model = create_model(
231+
f"{func.__name__}ClientProvidedArguments",
232+
**dynamic_pydantic_client_provided_arg_model_params,
233+
__base__=ArgModelBase,
234+
)
235+
resp = FuncMetadata(
236+
arg_model=arguments_model, client_provided_arg_model=provided_arguments_model
237+
)
176238
return resp
177239

178240

tests/server/fastmcp/test_func_metadata.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
from pydantic import BaseModel, Field
66

7-
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
7+
from mcp.server.fastmcp.utilities.func_metadata import ClientProvidedArg, func_metadata
88

99

1010
class SomeInputModelA(BaseModel):
@@ -414,3 +414,38 @@ def func_with_str_and_int(a: str, b: int):
414414
result = meta.pre_parse_json({"a": "123", "b": 123})
415415
assert result["a"] == "123"
416416
assert result["b"] == 123
417+
418+
419+
def test_func_with_client_provided_args():
420+
"""Test that client-provided arguments are correctly parsed and validated"""
421+
422+
def func_with_client_provided_args(
423+
a: int,
424+
b: str,
425+
c: Annotated[int, ClientProvidedArg()],
426+
d: Annotated[str, ClientProvidedArg()],
427+
):
428+
return a, b, c, d
429+
430+
meta = func_metadata(func_with_client_provided_args)
431+
432+
# Test schema
433+
assert meta.arg_model.model_json_schema() == {
434+
"properties": {
435+
"a": {"title": "A", "type": "integer"},
436+
"b": {"title": "B", "type": "string"},
437+
},
438+
"required": ["a", "b"],
439+
"title": "func_with_client_provided_argsArguments",
440+
"type": "object",
441+
}
442+
assert meta.client_provided_arg_model is not None
443+
assert meta.client_provided_arg_model.model_json_schema() == {
444+
"properties": {
445+
"c": {"title": "C", "type": "integer"},
446+
"d": {"title": "D", "type": "string"},
447+
},
448+
"required": ["c", "d"],
449+
"title": "func_with_client_provided_argsClientProvidedArguments",
450+
"type": "object",
451+
}

tests/server/fastmcp/test_server.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import base64
22
from pathlib import Path
3-
from typing import TYPE_CHECKING
3+
from typing import TYPE_CHECKING, Annotated
44

55
import pytest
66
from pydantic import AnyUrl
77

88
from mcp.server.fastmcp import Context, FastMCP
99
from mcp.server.fastmcp.prompts.base import EmbeddedResource, Message, UserMessage
1010
from mcp.server.fastmcp.resources import FileResource, FunctionResource
11+
from mcp.server.fastmcp.utilities.func_metadata import ClientProvidedArg
1112
from mcp.server.fastmcp.utilities.types import Image
1213
from mcp.shared.exceptions import McpError
1314
from mcp.shared.memory import (
@@ -106,6 +107,12 @@ def tool_fn(x: int, y: int) -> int:
106107
return x + y
107108

108109

110+
def tool_with_client_provided_args_fn(
111+
x: int, y: Annotated[int, ClientProvidedArg()], z: str
112+
) -> str:
113+
return f"{x} + {y} = {z}"
114+
115+
109116
def error_tool_fn() -> None:
110117
raise ValueError("Test error")
111118

@@ -129,6 +136,13 @@ async def test_add_tool(self):
129136
mcp.add_tool(tool_fn)
130137
assert len(mcp._tool_manager.list_tools()) == 1
131138

139+
@pytest.mark.anyio
140+
async def test_add_tool_with_client_provided_arg(self):
141+
mcp = FastMCP()
142+
mcp.add_tool(tool_fn)
143+
mcp.add_tool(tool_with_client_provided_args_fn)
144+
assert len(mcp._tool_manager.list_tools()) == 2
145+
132146
@pytest.mark.anyio
133147
async def test_list_tools(self):
134148
mcp = FastMCP()

tests/server/fastmcp/test_tool_manager.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import json
22
import logging
3+
from typing import Annotated
34

45
import pytest
56
from pydantic import BaseModel
67

78
from mcp.server.fastmcp import Context, FastMCP
89
from mcp.server.fastmcp.exceptions import ToolError
910
from mcp.server.fastmcp.tools import ToolManager
11+
from mcp.server.fastmcp.utilities.func_metadata import ClientProvidedArg
1012
from mcp.server.session import ServerSessionT
1113
from mcp.shared.context import LifespanContextT
1214

@@ -220,6 +222,31 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]:
220222
)
221223
assert result == ["rex", "gertrude"]
222224

225+
@pytest.mark.anyio
226+
async def test_call_tool_with_client_provided_arg(self):
227+
class ClientArgModel(BaseModel):
228+
arg1: int
229+
arg2: str
230+
231+
def afunc(args: Annotated[dict, ClientProvidedArg()], ctx: Context) -> str:
232+
args_obj = ClientArgModel.model_validate(args)
233+
return f"{args_obj.arg1} {args_obj.arg2}"
234+
235+
manager = ToolManager()
236+
manager.add_tool(afunc)
237+
result = await manager.call_tool(
238+
"afunc",
239+
{"args": {"arg1": 3, "arg2": "apple"}},
240+
)
241+
assert result == "3 apple"
242+
243+
with pytest.raises(ToolError):
244+
# Raises an error because it misses the required args
245+
result = await manager.call_tool(
246+
"afunc",
247+
{},
248+
)
249+
223250

224251
class TestToolSchema:
225252
@pytest.mark.anyio
@@ -233,6 +260,18 @@ def something(a: int, ctx: Context) -> int:
233260
assert "Context" not in json.dumps(tool.parameters)
234261
assert "ctx" not in tool.fn_metadata.arg_model.model_fields
235262

263+
@pytest.mark.anyio
264+
async def test_client_provided_arg_excluded_from_schema(self):
265+
def something(a: int, b: Annotated[int, ClientProvidedArg()]) -> int:
266+
return a + b
267+
268+
manager = ToolManager()
269+
tool = manager.add_tool(something)
270+
assert "properties" in tool.parameters
271+
assert "b" not in tool.parameters["properties"]
272+
assert tool.fn_metadata.client_provided_arg_model is not None
273+
assert "b" in tool.fn_metadata.client_provided_arg_model.model_fields
274+
236275

237276
class TestContextHandling:
238277
"""Test context handling in the tool manager."""

0 commit comments

Comments
 (0)