22import logging
33from typing import Optional , TextIO , cast
44
5+ import requests
56from mcp import ClientSession , InitializeResult , StdioServerParameters , stdio_client
67from mcp .client .sse import sse_client
8+ from mcp .client .streamable_http import streamablehttp_client
79
810from mcpm .core .schema import ServerConfig , SSEServerConfig , STDIOServerConfig
911
@@ -21,6 +23,11 @@ def _sse_transport_context(server_config: ServerConfig):
2123 return sse_client (server_config .url , headers = server_config .headers )
2224
2325
26+ def _streamable_http_transport_context (server_config : ServerConfig ):
27+ server_config = cast (SSEServerConfig , server_config )
28+ return streamablehttp_client (server_config .url , headers = server_config .headers )
29+
30+
2431class ServerConnection :
2532 def __init__ (self , server_config : ServerConfig , errlog : TextIO ) -> None :
2633 self .session : Optional [ClientSession ] = None
@@ -31,14 +38,21 @@ def __init__(self, server_config: ServerConfig, errlog: TextIO) -> None:
3138 self ._shutdown_event = asyncio .Event ()
3239 self ._errlog = errlog
3340
34- self ._transport_context_factory = (
35- lambda config : _stdio_transport_context (config , errlog = self ._errlog )
36- if isinstance (config , STDIOServerConfig )
37- else _sse_transport_context (config )
38- )
39-
4041 self ._server_task = asyncio .create_task (self ._server_lifespan_cycle ())
4142
43+ def _transport_context_factory (self , server_config : ServerConfig ):
44+ if isinstance (server_config , STDIOServerConfig ):
45+ return _stdio_transport_context (server_config , self ._errlog )
46+ elif isinstance (server_config , SSEServerConfig ):
47+ r = requests .head (server_config .url )
48+ if r .status_code != 200 :
49+ return _streamable_http_transport_context (server_config )
50+ if r .headers .get ("connection" ) == "keep-alive" and r .headers .get ("content-type" , "" ).startswith (
51+ "text/event-stream"
52+ ):
53+ return _sse_transport_context (server_config )
54+ return _streamable_http_transport_context (server_config )
55+
4256 def healthy (self ) -> bool :
4357 return self .session is not None and self ._initialized
4458
@@ -56,7 +70,7 @@ async def wait_for_shutdown_request(self):
5670
5771 async def _server_lifespan_cycle (self ):
5872 try :
59- async with self ._transport_context_factory (self .server_config ) as (read , write ):
73+ async with self ._transport_context_factory (self .server_config ) as (read , write , * _ ):
6074 async with ClientSession (read , write ) as session :
6175 self .session_initialized_response = await session .initialize ()
6276
@@ -68,5 +82,8 @@ async def _server_lifespan_cycle(self):
6882 await self .wait_for_shutdown_request ()
6983 except Exception as e :
7084 logger .error (f"Failed to connect to server { self .server_config .name } : { e } " )
85+ import traceback
86+
87+ traceback .print_exc ()
7188 self ._initialized_event .set ()
7289 self ._shutdown_event .set ()
0 commit comments