|
13 | 13 | from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder |
14 | 14 | from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS |
15 | 15 |
|
16 | | - |
17 | | -class ToolOutputValidator: |
18 | | - async def validate( |
19 | | - self, request: types.CallToolRequest, result: types.CallToolResult |
20 | | - ) -> bool: |
21 | | - raise RuntimeError("Not implemented") |
22 | | - |
23 | | - |
24 | 16 | DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") |
25 | 17 |
|
26 | 18 |
|
@@ -54,6 +46,12 @@ async def __call__( |
54 | 46 | ) -> None: ... |
55 | 47 |
|
56 | 48 |
|
| 49 | +class ToolOutputValidationFnT(Protocol): |
| 50 | + async def __call__( |
| 51 | + self, request: types.CallToolRequest, result: types.CallToolResult |
| 52 | + ) -> bool: ... |
| 53 | + |
| 54 | + |
57 | 55 | async def _default_message_handler( |
58 | 56 | message: RequestResponder[types.ServerRequest, types.ClientResult] |
59 | 57 | | types.ServerNotification |
@@ -89,13 +87,13 @@ async def _default_logging_callback( |
89 | 87 |
|
90 | 88 | ToolOutputValidatorProvider: TypeAlias = Callable[ |
91 | 89 | ..., |
92 | | - Awaitable[ToolOutputValidator], |
| 90 | + Awaitable[ToolOutputValidationFnT], |
93 | 91 | ] |
94 | 92 |
|
95 | 93 |
|
96 | 94 | # this bag of spanners is required in order to |
97 | 95 | # enable the client session to be parsed to the validator |
98 | | -async def _python_circularity_hell(arg: Any) -> ToolOutputValidator: |
| 96 | +async def _python_circularity_hell(arg: Any) -> ToolOutputValidationFnT: |
99 | 97 | # in any sane version of the universe this should never happen |
100 | 98 | # of course in any sane programming language class circularity |
101 | 99 | # dependencies shouldn't be this hard to manage |
@@ -327,7 +325,7 @@ async def call_tool( |
327 | 325 | ) |
328 | 326 |
|
329 | 327 | if validate_result: |
330 | | - valid = await self._tool_output_validator.validate(request, result) |
| 328 | + valid = await self._tool_output_validator(request, result) |
331 | 329 |
|
332 | 330 | if not valid: |
333 | 331 | raise RuntimeError("Server responded with invalid result: " f"{result}") |
@@ -451,15 +449,15 @@ async def _received_notification( |
451 | 449 | pass |
452 | 450 |
|
453 | 451 |
|
454 | | -class SimpleCachingToolOutputValidator(ToolOutputValidator): |
| 452 | +class SimpleCachingToolOutputValidator(ToolOutputValidationFnT): |
455 | 453 | _schema_cache: dict[str, dict[str, Any] | bool] |
456 | 454 |
|
457 | 455 | def __init__(self, session: ClientSession): |
458 | 456 | self._session = session |
459 | 457 | self._schema_cache = {} |
460 | 458 | self._refresh_cache = True |
461 | 459 |
|
462 | | - async def validate( |
| 460 | + async def __call__( |
463 | 461 | self, request: types.CallToolRequest, result: types.CallToolResult |
464 | 462 | ) -> bool: |
465 | 463 | if result.isError: |
@@ -508,7 +506,7 @@ async def _refresh_schema_cache(self): |
508 | 506 |
|
509 | 507 | async def _escape_from_circular_python_hell( |
510 | 508 | session: ClientSession, |
511 | | -) -> ToolOutputValidator: |
| 509 | +) -> ToolOutputValidationFnT: |
512 | 510 | return SimpleCachingToolOutputValidator(session) |
513 | 511 |
|
514 | 512 |
|
|
0 commit comments