24
24
from filelock import FileLock , Timeout
25
25
from mcp import ClientSession
26
26
from mcp .client .sse import sse_client
27
+ from mcp .client .streamable_http import streamablehttp_client
27
28
from sqlalchemy import select
28
29
from sqlalchemy .orm import Session
29
30
@@ -186,7 +187,7 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
186
187
auth_type = getattr (gateway , "auth_type" , None )
187
188
auth_value = getattr (gateway , "auth_value" , {})
188
189
189
- capabilities , tools = await self ._initialize_gateway (str (gateway .url ), auth_value )
190
+ capabilities , tools = await self ._initialize_gateway (str (gateway .url ), auth_value , gateway . transport )
190
191
191
192
all_names = [td .name for td in tools ]
192
193
@@ -217,6 +218,7 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
217
218
name = gateway .name ,
218
219
url = str (gateway .url ),
219
220
description = gateway .description ,
221
+ transport = gateway .transport ,
220
222
capabilities = capabilities ,
221
223
last_seen = datetime .now (timezone .utc ),
222
224
auth_type = auth_type ,
@@ -311,6 +313,8 @@ async def update_gateway(self, db: Session, gateway_id: int, gateway_update: Gat
311
313
gateway .url = str (gateway_update .url )
312
314
if gateway_update .description is not None :
313
315
gateway .description = gateway_update .description
316
+ if gateway_update .transport is not None :
317
+ gateway .transport = gateway_update .transport
314
318
315
319
if getattr (gateway , "auth_type" , None ) is not None :
316
320
gateway .auth_type = gateway_update .auth_type
@@ -322,7 +326,7 @@ async def update_gateway(self, db: Session, gateway_id: int, gateway_update: Gat
322
326
# Try to reinitialize connection if URL changed
323
327
if gateway_update .url is not None :
324
328
try :
325
- capabilities , _ = await self ._initialize_gateway (gateway .url , gateway .auth_value )
329
+ capabilities , _ = await self ._initialize_gateway (gateway .url , gateway .auth_value , gateway . transport )
326
330
gateway .capabilities = capabilities
327
331
gateway .last_seen = datetime .utcnow ()
328
332
@@ -399,7 +403,7 @@ async def toggle_gateway_status(self, db: Session, gateway_id: int, activate: bo
399
403
self ._active_gateways .add (gateway .url )
400
404
# Try to initialize if activating
401
405
try :
402
- capabilities , tools = await self ._initialize_gateway (gateway .url , gateway .auth_value )
406
+ capabilities , tools = await self ._initialize_gateway (gateway .url , gateway .auth_value , gateway . transport )
403
407
gateway .capabilities = capabilities .dict ()
404
408
gateway .last_seen = datetime .utcnow ()
405
409
except Exception as e :
@@ -571,14 +575,20 @@ async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool:
571
575
headers = decode_auth (auth_data )
572
576
573
577
# Perform the GET and raise on 4xx/5xx
574
- async with client .stream ("GET" , gateway .url , headers = headers ) as response :
575
- # This will raise immediately if status is 4xx/5xx
576
- response .raise_for_status ()
578
+ if (gateway .transport ).lower () == "sse" :
579
+ async with client .stream ("GET" , gateway .url , headers = headers ) as response :
580
+ # This will raise immediately if status is 4xx/5xx
581
+ response .raise_for_status ()
582
+ elif (gateway .transport ).lower () == "streamablehttp" :
583
+ async with streamablehttp_client (url = gateway .url , headers = headers , timeout = 5 ) as (read_stream , write_stream , get_session_id ):
584
+ async with ClientSession (read_stream , write_stream ) as session :
585
+ # Initialize the session
586
+ response = await session .initialize ()
577
587
578
588
# Mark successful check
579
589
gateway .last_seen = datetime .utcnow ()
580
590
581
- except Exception :
591
+ except Exception as e :
582
592
await self ._handle_gateway_failure (gateway )
583
593
584
594
# All gateways passed
@@ -629,7 +639,7 @@ async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
629
639
finally :
630
640
self ._event_subscribers .remove (queue )
631
641
632
- async def _initialize_gateway (self , url : str , authentication : Optional [Dict [str , str ]] = None ) -> Any :
642
+ async def _initialize_gateway (self , url : str , authentication : Optional [Dict [str , str ]] = None , transport : str = "sse" ) -> Any :
633
643
"""Initialize connection to a gateway and retrieve its capabilities.
634
644
635
645
Args:
@@ -676,7 +686,45 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
676
686
677
687
return capabilities , tools
678
688
679
- capabilities , tools = await connect_to_sse_server (url , authentication )
689
+ async def connect_to_streamablehttp_server (server_url : str , authentication : Optional [Dict [str , str ]] = None ):
690
+ """
691
+ Connect to an MCP server running with Streamable HTTP transport
692
+
693
+ Args:
694
+ server_url: URL to connect to the server
695
+ authentication: Authentication headers for connection to URL
696
+
697
+ Returns:
698
+ list, list: List of capabilities and tools
699
+ """
700
+ if authentication is None :
701
+ authentication = {}
702
+ # Store the context managers so they stay alive
703
+ decoded_auth = decode_auth (authentication )
704
+
705
+ # Use async with for both streamablehttp_client and ClientSession
706
+ async with streamablehttp_client (url = server_url , headers = decoded_auth ) as (read_stream , write_stream , get_session_id ):
707
+ async with ClientSession (read_stream , write_stream ) as session :
708
+ # Initialize the session
709
+ response = await session .initialize ()
710
+ # if get_session_id:
711
+ # session_id = get_session_id()
712
+ # if session_id:
713
+ # print(f"Session ID: {session_id}")
714
+ capabilities = response .capabilities .model_dump (by_alias = True , exclude_none = True )
715
+ response = await session .list_tools ()
716
+ tools = response .tools
717
+ tools = [tool .model_dump (by_alias = True , exclude_none = True ) for tool in tools ]
718
+ tools = [ToolCreate .model_validate (tool ) for tool in tools ]
719
+ for tool in tools :
720
+ tool .request_type = "STREAMABLEHTTP"
721
+
722
+ return capabilities , tools
723
+
724
+ if transport .lower () == "sse" :
725
+ capabilities , tools = await connect_to_sse_server (url , authentication )
726
+ elif transport .lower () == "streamablehttp" :
727
+ capabilities , tools = await connect_to_streamablehttp_server (url , authentication )
680
728
681
729
return capabilities , tools
682
730
except Exception as e :
0 commit comments