Skip to content

Commit bf5a44e

Browse files
committed
added streamablehttp mcp servers support
1 parent 29309e1 commit bf5a44e

File tree

2 files changed

+81
-10
lines changed

2 files changed

+81
-10
lines changed

mcpgateway/services/gateway_service.py

Lines changed: 57 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from filelock import FileLock, Timeout
2525
from mcp import ClientSession
2626
from mcp.client.sse import sse_client
27+
from mcp.client.streamable_http import streamablehttp_client
2728
from sqlalchemy import select
2829
from sqlalchemy.orm import Session
2930

@@ -186,7 +187,7 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
186187
auth_type = getattr(gateway, "auth_type", None)
187188
auth_value = getattr(gateway, "auth_value", {})
188189

189-
capabilities, tools = await self._initialize_gateway(str(gateway.url), auth_value)
190+
capabilities, tools = await self._initialize_gateway(str(gateway.url), auth_value, gateway.transport)
190191

191192
all_names = [td.name for td in tools]
192193

@@ -217,6 +218,7 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
217218
name=gateway.name,
218219
url=str(gateway.url),
219220
description=gateway.description,
221+
transport=gateway.transport,
220222
capabilities=capabilities,
221223
last_seen=datetime.now(timezone.utc),
222224
auth_type=auth_type,
@@ -311,6 +313,8 @@ async def update_gateway(self, db: Session, gateway_id: int, gateway_update: Gat
311313
gateway.url = str(gateway_update.url)
312314
if gateway_update.description is not None:
313315
gateway.description = gateway_update.description
316+
if gateway_update.transport is not None:
317+
gateway.transport = gateway_update.transport
314318

315319
if getattr(gateway, "auth_type", None) is not None:
316320
gateway.auth_type = gateway_update.auth_type
@@ -322,7 +326,7 @@ async def update_gateway(self, db: Session, gateway_id: int, gateway_update: Gat
322326
# Try to reinitialize connection if URL changed
323327
if gateway_update.url is not None:
324328
try:
325-
capabilities, _ = await self._initialize_gateway(gateway.url, gateway.auth_value)
329+
capabilities, _ = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport)
326330
gateway.capabilities = capabilities
327331
gateway.last_seen = datetime.utcnow()
328332

@@ -399,7 +403,7 @@ async def toggle_gateway_status(self, db: Session, gateway_id: int, activate: bo
399403
self._active_gateways.add(gateway.url)
400404
# Try to initialize if activating
401405
try:
402-
capabilities, tools = await self._initialize_gateway(gateway.url, gateway.auth_value)
406+
capabilities, tools = await self._initialize_gateway(gateway.url, gateway.auth_value, gateway.transport)
403407
gateway.capabilities = capabilities.dict()
404408
gateway.last_seen = datetime.utcnow()
405409
except Exception as e:
@@ -571,14 +575,20 @@ async def check_health_of_gateways(self, gateways: List[DbGateway]) -> bool:
571575
headers = decode_auth(auth_data)
572576

573577
# Perform the GET and raise on 4xx/5xx
574-
async with client.stream("GET", gateway.url, headers=headers) as response:
575-
# This will raise immediately if status is 4xx/5xx
576-
response.raise_for_status()
578+
if (gateway.transport).lower() == "sse":
579+
async with client.stream("GET", gateway.url, headers=headers) as response:
580+
# This will raise immediately if status is 4xx/5xx
581+
response.raise_for_status()
582+
elif (gateway.transport).lower() == "streamablehttp":
583+
async with streamablehttp_client(url=gateway.url, headers=headers, timeout=5) as (read_stream, write_stream, get_session_id):
584+
async with ClientSession(read_stream, write_stream) as session:
585+
# Initialize the session
586+
response = await session.initialize()
577587

578588
# Mark successful check
579589
gateway.last_seen = datetime.utcnow()
580590

581-
except Exception:
591+
except Exception as e:
582592
await self._handle_gateway_failure(gateway)
583593

584594
# All gateways passed
@@ -629,7 +639,7 @@ async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
629639
finally:
630640
self._event_subscribers.remove(queue)
631641

632-
async def _initialize_gateway(self, url: str, authentication: Optional[Dict[str, str]] = None) -> Any:
642+
async def _initialize_gateway(self, url: str, authentication: Optional[Dict[str, str]] = None, transport: str= "sse") -> Any:
633643
"""Initialize connection to a gateway and retrieve its capabilities.
634644
635645
Args:
@@ -676,7 +686,45 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
676686

677687
return capabilities, tools
678688

679-
capabilities, tools = await connect_to_sse_server(url, authentication)
689+
async def connect_to_streamablehttp_server(server_url: str, authentication: Optional[Dict[str, str]] = None):
690+
"""
691+
Connect to an MCP server running with Streamable HTTP transport
692+
693+
Args:
694+
server_url: URL to connect to the server
695+
authentication: Authentication headers for connection to URL
696+
697+
Returns:
698+
list, list: List of capabilities and tools
699+
"""
700+
if authentication is None:
701+
authentication = {}
702+
# Store the context managers so they stay alive
703+
decoded_auth = decode_auth(authentication)
704+
705+
# Use async with for both streamablehttp_client and ClientSession
706+
async with streamablehttp_client(url=server_url, headers=decoded_auth) as (read_stream, write_stream, get_session_id):
707+
async with ClientSession(read_stream, write_stream) as session:
708+
# Initialize the session
709+
response = await session.initialize()
710+
# if get_session_id:
711+
# session_id = get_session_id()
712+
# if session_id:
713+
# print(f"Session ID: {session_id}")
714+
capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
715+
response = await session.list_tools()
716+
tools = response.tools
717+
tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
718+
tools = [ToolCreate.model_validate(tool) for tool in tools]
719+
for tool in tools:
720+
tool.request_type = "STREAMABLEHTTP"
721+
722+
return capabilities, tools
723+
724+
if transport.lower() == "sse":
725+
capabilities, tools = await connect_to_sse_server(url, authentication)
726+
elif transport.lower() == "streamablehttp":
727+
capabilities, tools = await connect_to_streamablehttp_server(url, authentication)
680728

681729
return capabilities, tools
682730
except Exception as e:

mcpgateway/services/tool_service.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import httpx
2626
from mcp import ClientSession
2727
from mcp.client.sse import sse_client
28+
from mcp.client.streamable_http import streamablehttp_client
2829
from sqlalchemy import delete, func, not_, select
2930
from sqlalchemy.exc import IntegrityError
3031
from sqlalchemy.orm import Session
@@ -515,6 +516,7 @@ async def invoke_tool(self, db: Session, name: str, arguments: Dict[str, Any]) -
515516

516517
success = True
517518
elif tool.integration_type == "MCP":
519+
transport = tool.request_type.lower()
518520
gateway = db.execute(select(DbGateway).where(DbGateway.id == tool.gateway_id).where(DbGateway.is_active)).scalar_one_or_none()
519521
if gateway.auth_type == "bearer":
520522
headers = decode_auth(gateway.auth_value)
@@ -539,10 +541,31 @@ async def connect_to_sse_server(server_url: str) -> str:
539541
tool_call_result = await session.call_tool(name, arguments)
540542
return tool_call_result
541543

544+
async def connect_to_streamablehttp_server(server_url: str) -> str:
545+
"""
546+
Connect to an MCP server running with Streamable HTTP transport
547+
548+
Args:
549+
server_url (str): MCP Server URL
550+
551+
Returns:
552+
str: Result of tool call
553+
"""
554+
# Use async with directly to manage the context
555+
async with streamablehttp_client(url=server_url, headers=headers) as (read_stream, write_stream, get_session_id):
556+
async with ClientSession(read_stream, write_stream) as session:
557+
# Initialize the session
558+
await session.initialize()
559+
tool_call_result = await session.call_tool(name, arguments)
560+
return tool_call_result
561+
542562
tool_gateway_id = tool.gateway_id
543563
tool_gateway = db.execute(select(DbGateway).where(DbGateway.id == tool_gateway_id).where(DbGateway.is_active)).scalar_one_or_none()
544564

545-
tool_call_result = await connect_to_sse_server(tool_gateway.url)
565+
if transport == "sse":
566+
tool_call_result = await connect_to_sse_server(tool_gateway.url)
567+
elif transport == "streamablehttp":
568+
tool_call_result = await connect_to_streamablehttp_server(tool_gateway.url)
546569
content = tool_call_result.model_dump(by_alias=True).get("content", [])
547570

548571
success = True

0 commit comments

Comments
 (0)