Skip to content

Commit 21ed548

Browse files
first stab at converting to msgspec
1 parent 1544471 commit 21ed548

File tree

11 files changed

+375
-303
lines changed

11 files changed

+375
-303
lines changed

.envrc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
export VIRTUAL_ENV=.venv
2+
export VIRTUAL_ENV_PROMPT=$(basename "$PWD")
3+
layout python3

.pre-commit-config.yaml

Lines changed: 0 additions & 7 deletions
This file was deleted.

pyproject.toml

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
"Programming Language :: Python :: 3.11",
2323
]
2424
dependencies = [
25-
"pydantic>=2.0.0",
25+
"msgspec>=0.18.0",
2626
]
2727

2828
[project.urls]
@@ -31,44 +31,15 @@ Repository = "https://github.com/fulfilio/mcp-utils.git"
3131

3232
[project.optional-dependencies]
3333
dev = [
34-
"pre-commit>=3.6.0",
35-
"ruff>=0.3.0",
3634
"pytest>=8.0.0",
37-
"pytest-cov>=4.1.0",
3835
"bump-my-version>=0.15.0",
3936
]
4037

4138
[tool.hatch.build.targets.wheel]
4239
packages = ["src/mcp_utils"]
4340

44-
[tool.ruff]
45-
target-version = "py310"
46-
line-length = 88
47-
extend-include = ["*.ipynb"]
48-
49-
[tool.ruff.lint]
50-
select = [
51-
"E", # pycodestyle errors
52-
"W", # pycodestyle warnings
53-
"F", # pyflakes
54-
"I", # isort
55-
"B", # flake8-bugbear
56-
"C4", # flake8-comprehensions
57-
"UP", # pyupgrade
58-
]
59-
ignore = []
60-
61-
[tool.ruff.format]
62-
quote-style = "double"
63-
indent-style = "space"
64-
skip-magic-trailing-comma = false
65-
line-ending = "auto"
66-
67-
[tool.ruff.lint.isort]
68-
known-first-party = ["mcp_utils"]
6941

7042
[tool.pytest.ini_options]
71-
addopts = "--cov=mcp_utils --cov-report=term-missing"
7243
testpaths = ["tests"]
7344
python_files = ["test_*.py"]
7445

src/mcp_utils/core.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dataclasses import dataclass, field
1010
from typing import Any, TypeVar
1111

12-
from pydantic import ValidationError
12+
import msgspec
1313

1414
from .queue import ResponseQueueProtocol
1515
from .schema import (
@@ -75,10 +75,18 @@ class MCPServer:
7575

7676
_tools: dict[str, Callable] = field(default_factory=dict)
7777
_tools_list: dict[str, ToolInfo] = field(default_factory=dict)
78+
_tool_arg_models: dict[str, type] = field(default_factory=dict)
7879

7980
def register_tool(self, name: str, callable: Callable, tool_info: ToolInfo) -> None:
8081
self._tools_list[name] = tool_info
8182
self._tools[name] = callable
83+
# Track arg model for validation separately from serializable ToolInfo
84+
try:
85+
from .utils import inspect_callable
86+
87+
self._tool_arg_models[name] = inspect_callable(callable).arg_model
88+
except Exception:
89+
pass
8290

8391
def tool(self, name: str | None = None) -> Callable:
8492
"""Register a tool"""
@@ -256,7 +264,7 @@ def get_list_tools(
256264
paginated_tools, next_page = get_page_of_items(tools, page, page_size)
257265

258266
return ListToolsResult(
259-
tools=[tool.model_dump(by_alias=True) for tool in paginated_tools],
267+
tools=paginated_tools,
260268
nextCursor=next_page,
261269
)
262270

@@ -341,13 +349,15 @@ def get_completions(
341349
def _handle_initialize(self, request: MCPRequest) -> MCPResponse:
342350
"""Handle initialize method."""
343351
return MCPResponse(
352+
jsonrpc="2.0",
344353
id=request.id,
345354
result=self.get_capabilities(),
346355
)
347356

348357
def _handle_ping(self, request: MCPRequest) -> MCPResponse:
349358
"""Handle ping method."""
350359
return MCPResponse(
360+
jsonrpc="2.0",
351361
id=request.id,
352362
result={},
353363
)
@@ -358,6 +368,7 @@ def _handle_completion_complete(self, request: MCPRequest) -> MCPResponse:
358368
arg_name = request.params["argument"]["name"]
359369
value = request.params["argument"]["value"]
360370
return MCPResponse(
371+
jsonrpc="2.0",
361372
id=request.id,
362373
result={"completion": self.get_completions(prompt_name, arg_name, value)},
363374
)
@@ -366,6 +377,7 @@ def _handle_prompts_list(self, request: MCPRequest) -> MCPResponse:
366377
"""Handle prompts/list method."""
367378
page = int(request.params.get("cursor", "1"))
368379
return MCPResponse(
380+
jsonrpc="2.0",
369381
id=request.id,
370382
result=self.get_list_prompts(page=page),
371383
)
@@ -377,13 +389,15 @@ def _handle_prompts_get(self, request: MCPRequest) -> MCPResponse:
377389
prompt = self._prompts[name]
378390
except KeyError:
379391
return MCPResponse(
392+
jsonrpc="2.0",
380393
id=request.id,
381394
error=ErrorResponse(
382395
code=400,
383396
message="Prompt not found",
384397
),
385398
)
386399
return MCPResponse(
400+
jsonrpc="2.0",
387401
id=request.id,
388402
result=prompt(**request.params["arguments"]),
389403
)
@@ -392,6 +406,7 @@ def _handle_tools_list(self, request: MCPRequest) -> MCPResponse:
392406
"""Handle tools/list method."""
393407
page = int(request.params.get("cursor", "1"))
394408
return MCPResponse(
409+
jsonrpc="2.0",
395410
id=request.id,
396411
result=self.get_list_tools(page=page),
397412
)
@@ -400,6 +415,7 @@ def _handle_resources_list(self, request: MCPRequest) -> MCPResponse:
400415
"""Handle resources/list method."""
401416
page = int(request.params.get("cursor", "1"))
402417
return MCPResponse(
418+
jsonrpc="2.0",
403419
id=request.id,
404420
result=self.get_list_resources(page=page),
405421
)
@@ -408,6 +424,7 @@ def _handle_resources_templates_list(self, request: MCPRequest) -> MCPResponse:
408424
"""Handle resources/templates/list method."""
409425
page = int(request.params.get("cursor", "1"))
410426
return MCPResponse(
427+
jsonrpc="2.0",
411428
id=request.id,
412429
result=self.get_list_resource_templates(page=page),
413430
)
@@ -419,54 +436,48 @@ def _handle_tools_call(self, request: MCPRequest) -> MCPResponse:
419436

420437
try:
421438
callable = self._tools[tool_name]
422-
arg_model = self._tools_list[tool_name].arg_model
423-
args = arg_model(**kwargs)
424-
result = callable(**dict(args))
439+
arg_model = self._tool_arg_models[tool_name]
440+
args = msgspec.convert(kwargs, arg_model)
441+
result = callable(**msgspec.to_builtins(args))
425442
if isinstance(result, dict):
426443
result = CallToolResult(
427-
content=[
428-
TextContent(
429-
text=json.dumps(result),
430-
type="text",
431-
)
432-
],
444+
content=[TextContent(text=json.dumps(result))],
433445
is_error=False,
434446
)
435447
elif isinstance(result, str):
436448
result = CallToolResult(
437-
content=[
438-
TextContent(
439-
text=result,
440-
type="text",
441-
)
442-
],
449+
content=[TextContent(text=result)],
443450
is_error=False,
444451
)
445452
elif isinstance(result, CallToolResult):
446453
result = result
447454
else:
448455
logger.error("Invalid tool result type: %s", type(result))
449456
return MCPResponse(
457+
jsonrpc="2.0",
450458
id=request.id,
451459
error=ErrorResponse(
452460
code=400,
453461
message="Invalid tool result type",
454462
),
455463
)
456464
return MCPResponse(
465+
jsonrpc="2.0",
457466
id=request.id,
458467
result=result,
459468
)
460469
except KeyError:
461470
return MCPResponse(
471+
jsonrpc="2.0",
462472
id=request.id,
463473
error=ErrorResponse(
464474
code=-32601,
465475
message="Tool not found",
466476
),
467477
)
468-
except ValidationError as e:
478+
except Exception as e:
469479
return MCPResponse(
480+
jsonrpc="2.0",
470481
id=request.id,
471482
error=ErrorResponse(
472483
code=-32602,
@@ -476,6 +487,7 @@ def _handle_tools_call(self, request: MCPRequest) -> MCPResponse:
476487
except Exception as e:
477488
logger.error(f"Error in tool {tool_name}: {e}")
478489
return MCPResponse(
490+
jsonrpc="2.0",
479491
id=request.id,
480492
error=ErrorResponse(
481493
code=-32603,
@@ -501,18 +513,18 @@ def _handle_message(
501513
message_id = message["id"]
502514
except KeyError:
503515
return MCPResponse(
516+
jsonrpc="2.0",
504517
id=0,
505518
error=ErrorResponse(
506519
code=-32600,
507520
message="Missing message id",
508521
),
509522
)
510523
try:
511-
mcp_request = MCPRequest.model_validate(
512-
{**message, "session_id": session_id},
513-
)
514-
except ValidationError as e:
524+
mcp_request = msgspec.convert({**message}, MCPRequest)
525+
except Exception as e:
515526
return MCPResponse(
527+
jsonrpc="2.0",
516528
id=0,
517529
error=ErrorResponse(
518530
code=-32600,
@@ -545,6 +557,7 @@ def _handle_message(
545557
return handler(mcp_request)
546558
else:
547559
return MCPResponse(
560+
jsonrpc="2.0",
548561
id=message_id,
549562
error=ErrorResponse(
550563
code=-32601,
@@ -570,5 +583,6 @@ def get_page_of_items(
570583
start_idx = (page - 1) * page_size
571584
end_idx = start_idx + page_size
572585
page_items = items[start_idx:end_idx]
573-
next_page = str(page + 1) if len(items) > end_idx else None
586+
# Use None to indicate no next page (not UNSET) for consistency with tests
587+
next_page = str(page + 1) if len(items) > end_idx else msgspec.UNSET
574588
return page_items, next_page

src/mcp_utils/queue.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import logging
66
from typing import Protocol
77

8+
import msgspec
9+
810
from .schema import MCPResponse
911

1012
logger = logging.getLogger("mcp_utils")
@@ -58,7 +60,7 @@ def push_response(
5860
response: The response to push
5961
"""
6062
queue_key = self._get_queue_key(session_id)
61-
value = response.model_dump_json(exclude_none=True)
63+
value = msgspec.json.encode(response).decode()
6264
logger.debug(f"Redis: Saving response for session: {session_id}: {value}")
6365
self.redis.rpush(queue_key, value)
6466

0 commit comments

Comments
 (0)