|
2 | 2 | import logging
|
3 | 3 | import warnings
|
4 | 4 | from collections.abc import Awaitable, Callable
|
5 |
| -from typing import Any |
| 5 | +from typing import Any, Self |
6 | 6 |
|
7 | 7 | from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
8 | 8 | from pydantic import AnyUrl
|
9 | 9 |
|
10 | 10 | from mcp_python.server import types
|
11 |
| -from mcp_python.server.session import ServerSession |
| 11 | +from mcp_python.server.session import ServerSession, SessionInitializationOptions |
12 | 12 | from mcp_python.server.stdio import stdio_server as stdio_server
|
13 | 13 | from mcp_python.shared.context import RequestContext
|
14 | 14 | from mcp_python.shared.session import RequestResponder
|
|
32 | 32 | ReadResourceResult,
|
33 | 33 | Resource,
|
34 | 34 | ResourceReference,
|
| 35 | + ServerCapabilities, |
35 | 36 | ServerResult,
|
36 | 37 | SetLevelRequest,
|
37 | 38 | SubscribeRequest,
|
|
40 | 41 |
|
41 | 42 | logger = logging.getLogger(__name__)
|
42 | 43 |
|
43 |
| - |
44 | 44 | request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
|
45 | 45 | "request_ctx"
|
46 | 46 | )
|
47 | 47 |
|
| 48 | +def pkg_version(package: str) -> str: |
| 49 | + try: |
| 50 | + from importlib.metadata import version |
| 51 | + return version(package) |
| 52 | + except Exception: |
| 53 | + return "unknown" |
| 54 | + |
| 55 | +class InitializationOptions(SessionInitializationOptions): |
| 56 | + """Information about a server provided as initialization options when a new session is started.""" |
| 57 | + |
| 58 | + @classmethod |
| 59 | + def from_server(cls, server: "Server") -> Self: |
| 60 | + return cls( |
| 61 | + server_name=server.name, |
| 62 | + server_version=pkg_version("mcp_python"), |
| 63 | + capabilities=server.get_capabilities() |
| 64 | + ) |
48 | 65 |
|
49 | 66 | class Server:
|
50 | 67 | def __init__(self, name: str):
|
@@ -276,13 +293,26 @@ async def handler(req: CompleteRequest):
|
276 | 293 |
|
277 | 294 | return decorator
|
278 | 295 |
|
| 296 | + def get_capabilities(self) -> ServerCapabilities: |
| 297 | + """Convert existing handlers to a ServerCapabilities object.""" |
| 298 | + def get_capability(req_type: type) -> dict[str, Any] | None: |
| 299 | + return {} if req_type in self.request_handlers else None |
| 300 | + |
| 301 | + return ServerCapabilities( |
| 302 | + prompts=get_capability(ListPromptsRequest), |
| 303 | + resources=get_capability(ListResourcesRequest), |
| 304 | + tools=get_capability(ListPromptsRequest), |
| 305 | + logging=get_capability(SetLevelRequest) |
| 306 | + ) |
| 307 | + |
279 | 308 | async def run(
|
280 | 309 | self,
|
281 | 310 | read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
282 | 311 | write_stream: MemoryObjectSendStream[JSONRPCMessage],
|
| 312 | + initialization_options: InitializationOptions |
283 | 313 | ):
|
284 | 314 | with warnings.catch_warnings(record=True) as w:
|
285 |
| - async with ServerSession(read_stream, write_stream) as session: |
| 315 | + async with ServerSession(read_stream, write_stream, initialization_options) as session: |
286 | 316 | async for message in session.incoming_messages:
|
287 | 317 | logger.debug(f"Received message: {message}")
|
288 | 318 |
|
|
0 commit comments