99from dataclasses import dataclass , field
1010from typing import Any , TypeVar
1111
12- from pydantic import ValidationError
12+ import msgspec
1313
1414from .queue import ResponseQueueProtocol
1515from .schema import (
@@ -75,10 +75,18 @@ class MCPServer:
7575
7676 _tools : dict [str , Callable ] = field (default_factory = dict )
7777 _tools_list : dict [str , ToolInfo ] = field (default_factory = dict )
78+ _tool_arg_models : dict [str , type ] = field (default_factory = dict )
7879
7980 def register_tool (self , name : str , callable : Callable , tool_info : ToolInfo ) -> None :
8081 self ._tools_list [name ] = tool_info
8182 self ._tools [name ] = callable
83+ # Track arg model for validation separately from serializable ToolInfo
84+ try :
85+ from .utils import inspect_callable
86+
87+ self ._tool_arg_models [name ] = inspect_callable (callable ).arg_model
88+ except Exception :
89+ pass
8290
8391 def tool (self , name : str | None = None ) -> Callable :
8492 """Register a tool"""
@@ -256,7 +264,7 @@ def get_list_tools(
256264 paginated_tools , next_page = get_page_of_items (tools , page , page_size )
257265
258266 return ListToolsResult (
259- tools = [ tool . model_dump ( by_alias = True ) for tool in paginated_tools ] ,
267+ tools = paginated_tools ,
260268 nextCursor = next_page ,
261269 )
262270
@@ -341,13 +349,15 @@ def get_completions(
341349 def _handle_initialize (self , request : MCPRequest ) -> MCPResponse :
342350 """Handle initialize method."""
343351 return MCPResponse (
352+ jsonrpc = "2.0" ,
344353 id = request .id ,
345354 result = self .get_capabilities (),
346355 )
347356
348357 def _handle_ping (self , request : MCPRequest ) -> MCPResponse :
349358 """Handle ping method."""
350359 return MCPResponse (
360+ jsonrpc = "2.0" ,
351361 id = request .id ,
352362 result = {},
353363 )
@@ -358,6 +368,7 @@ def _handle_completion_complete(self, request: MCPRequest) -> MCPResponse:
358368 arg_name = request .params ["argument" ]["name" ]
359369 value = request .params ["argument" ]["value" ]
360370 return MCPResponse (
371+ jsonrpc = "2.0" ,
361372 id = request .id ,
362373 result = {"completion" : self .get_completions (prompt_name , arg_name , value )},
363374 )
@@ -366,6 +377,7 @@ def _handle_prompts_list(self, request: MCPRequest) -> MCPResponse:
366377 """Handle prompts/list method."""
367378 page = int (request .params .get ("cursor" , "1" ))
368379 return MCPResponse (
380+ jsonrpc = "2.0" ,
369381 id = request .id ,
370382 result = self .get_list_prompts (page = page ),
371383 )
@@ -377,13 +389,15 @@ def _handle_prompts_get(self, request: MCPRequest) -> MCPResponse:
377389 prompt = self ._prompts [name ]
378390 except KeyError :
379391 return MCPResponse (
392+ jsonrpc = "2.0" ,
380393 id = request .id ,
381394 error = ErrorResponse (
382395 code = 400 ,
383396 message = "Prompt not found" ,
384397 ),
385398 )
386399 return MCPResponse (
400+ jsonrpc = "2.0" ,
387401 id = request .id ,
388402 result = prompt (** request .params ["arguments" ]),
389403 )
@@ -392,6 +406,7 @@ def _handle_tools_list(self, request: MCPRequest) -> MCPResponse:
392406 """Handle tools/list method."""
393407 page = int (request .params .get ("cursor" , "1" ))
394408 return MCPResponse (
409+ jsonrpc = "2.0" ,
395410 id = request .id ,
396411 result = self .get_list_tools (page = page ),
397412 )
@@ -400,6 +415,7 @@ def _handle_resources_list(self, request: MCPRequest) -> MCPResponse:
400415 """Handle resources/list method."""
401416 page = int (request .params .get ("cursor" , "1" ))
402417 return MCPResponse (
418+ jsonrpc = "2.0" ,
403419 id = request .id ,
404420 result = self .get_list_resources (page = page ),
405421 )
@@ -408,6 +424,7 @@ def _handle_resources_templates_list(self, request: MCPRequest) -> MCPResponse:
408424 """Handle resources/templates/list method."""
409425 page = int (request .params .get ("cursor" , "1" ))
410426 return MCPResponse (
427+ jsonrpc = "2.0" ,
411428 id = request .id ,
412429 result = self .get_list_resource_templates (page = page ),
413430 )
@@ -419,54 +436,48 @@ def _handle_tools_call(self, request: MCPRequest) -> MCPResponse:
419436
420437 try :
421438 callable = self ._tools [tool_name ]
422- arg_model = self ._tools_list [tool_name ]. arg_model
423- args = arg_model ( ** kwargs )
424- result = callable (** dict (args ))
439+ arg_model = self ._tool_arg_models [tool_name ]
440+ args = msgspec . convert ( kwargs , arg_model )
441+ result = callable (** msgspec . to_builtins (args ))
425442 if isinstance (result , dict ):
426443 result = CallToolResult (
427- content = [
428- TextContent (
429- text = json .dumps (result ),
430- type = "text" ,
431- )
432- ],
444+ content = [TextContent (text = json .dumps (result ))],
433445 is_error = False ,
434446 )
435447 elif isinstance (result , str ):
436448 result = CallToolResult (
437- content = [
438- TextContent (
439- text = result ,
440- type = "text" ,
441- )
442- ],
449+ content = [TextContent (text = result )],
443450 is_error = False ,
444451 )
445452 elif isinstance (result , CallToolResult ):
446453 result = result
447454 else :
448455 logger .error ("Invalid tool result type: %s" , type (result ))
449456 return MCPResponse (
457+ jsonrpc = "2.0" ,
450458 id = request .id ,
451459 error = ErrorResponse (
452460 code = 400 ,
453461 message = "Invalid tool result type" ,
454462 ),
455463 )
456464 return MCPResponse (
465+ jsonrpc = "2.0" ,
457466 id = request .id ,
458467 result = result ,
459468 )
460469 except KeyError :
461470 return MCPResponse (
471+ jsonrpc = "2.0" ,
462472 id = request .id ,
463473 error = ErrorResponse (
464474 code = - 32601 ,
465475 message = "Tool not found" ,
466476 ),
467477 )
468- except ValidationError as e :
478+ except Exception as e :
469479 return MCPResponse (
480+ jsonrpc = "2.0" ,
470481 id = request .id ,
471482 error = ErrorResponse (
472483 code = - 32602 ,
@@ -476,6 +487,7 @@ def _handle_tools_call(self, request: MCPRequest) -> MCPResponse:
476487 except Exception as e :
477488 logger .error (f"Error in tool { tool_name } : { e } " )
478489 return MCPResponse (
490+ jsonrpc = "2.0" ,
479491 id = request .id ,
480492 error = ErrorResponse (
481493 code = - 32603 ,
@@ -501,18 +513,18 @@ def _handle_message(
501513 message_id = message ["id" ]
502514 except KeyError :
503515 return MCPResponse (
516+ jsonrpc = "2.0" ,
504517 id = 0 ,
505518 error = ErrorResponse (
506519 code = - 32600 ,
507520 message = "Missing message id" ,
508521 ),
509522 )
510523 try :
511- mcp_request = MCPRequest .model_validate (
512- {** message , "session_id" : session_id },
513- )
514- except ValidationError as e :
524+ mcp_request = msgspec .convert ({** message }, MCPRequest )
525+ except Exception as e :
515526 return MCPResponse (
527+ jsonrpc = "2.0" ,
516528 id = 0 ,
517529 error = ErrorResponse (
518530 code = - 32600 ,
@@ -545,6 +557,7 @@ def _handle_message(
545557 return handler (mcp_request )
546558 else :
547559 return MCPResponse (
560+ jsonrpc = "2.0" ,
548561 id = message_id ,
549562 error = ErrorResponse (
550563 code = - 32601 ,
@@ -570,5 +583,6 @@ def get_page_of_items(
570583 start_idx = (page - 1 ) * page_size
571584 end_idx = start_idx + page_size
572585 page_items = items [start_idx :end_idx ]
573- next_page = str (page + 1 ) if len (items ) > end_idx else None
586+ # Use None to indicate no next page (not UNSET) for consistency with tests
587+ next_page = str (page + 1 ) if len (items ) > end_idx else msgspec .UNSET
574588 return page_items , next_page
0 commit comments