Skip to content

Commit 8e02710

Browse files
committed
updated _validate_gateway_url to handle both sse & streamablehttp invalid gateway URL's
Signed-off-by: Satya <[email protected]>
1 parent 91bb2cf commit 8e02710

File tree

2 files changed

+41
-27
lines changed

2 files changed

+41
-27
lines changed

mcpgateway/services/gateway_service.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def __init__(self) -> None:
244244
else:
245245
self._redis_client = None
246246

247-
async def _validate_gateway_url(self, url: str, headers: dict, timeout=5):
247+
async def _validate_gateway_url(self, url: str, headers: dict, transport_type: str, timeout=5):
248248
"""
249249
Validate if the given URL is a live Server-Sent Events (SSE) endpoint.
250250
@@ -255,6 +255,7 @@ async def _validate_gateway_url(self, url: str, headers: dict, timeout=5):
255255
Args:
256256
url (str): The full URL of the endpoint to validate.
257257
headers (dict): Headers to be included in the requests (e.g., Authorization).
258+
transport_type (str): SSE or STREAMABLEHTTP
258259
timeout (int, optional): Timeout in seconds for both requests. Defaults to 5.
259260
260261
Returns:
@@ -265,12 +266,22 @@ async def _validate_gateway_url(self, url: str, headers: dict, timeout=5):
265266
timeout = httpx.Timeout(timeout)
266267
try:
267268
async with client.stream("GET", url, headers=headers, timeout=timeout) as response:
268-
response.raise_for_status()
269-
response_head = await client.request("HEAD", url, headers=headers, timeout=timeout)
270-
response.raise_for_status()
271-
content_type = response_head.headers.get("Content-Type", "")
272-
if "text/event-stream" in content_type.lower():
273-
return True
269+
response_headers = dict(response.headers)
270+
location = response_headers.get("location")
271+
content_type = response_headers.get("content-type")
272+
if transport_type == "STREAMABLEHTTP":
273+
if location:
274+
async with client.stream("GET", location, headers=headers, timeout=timeout) as response_redirect:
275+
response_headers = dict(response_redirect.headers)
276+
mcp_session_id = response_headers.get("mcp-session-id")
277+
content_type = response_headers.get("content-type")
278+
if mcp_session_id is not None and mcp_session_id != "":
279+
if content_type is not None and content_type != "" and "application/json" in content_type:
280+
return True
281+
282+
elif transport_type == "SSE":
283+
if content_type is not None and content_type != "" and "text/event-stream" in content_type:
284+
return True
274285
return False
275286
except Exception:
276287
return False
@@ -1187,7 +1198,7 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
11871198
# Store the context managers so they stay alive
11881199
decoded_auth = decode_auth(authentication)
11891200

1190-
if await self._validate_gateway_url(url=server_url, headers=decoded_auth):
1201+
if await self._validate_gateway_url(url=server_url, headers=decoded_auth, transport_type="SSE"):
11911202
# Use async with for both sse_client and ClientSession
11921203
async with sse_client(url=server_url, headers=decoded_auth) as streams:
11931204
async with ClientSession(*streams) as session:
@@ -1220,25 +1231,26 @@ async def connect_to_streamablehttp_server(server_url: str, authentication: Opti
12201231
authentication = {}
12211232
# Store the context managers so they stay alive
12221233
decoded_auth = decode_auth(authentication)
1234+
if await self._validate_gateway_url(url=server_url, headers=decoded_auth, transport_type="STREAMABLEHTTP"):
1235+
# Use async with for both streamablehttp_client and ClientSession
1236+
async with streamablehttp_client(url=server_url, headers=decoded_auth) as (read_stream, write_stream, _get_session_id):
1237+
async with ClientSession(read_stream, write_stream) as session:
1238+
# Initialize the session
1239+
response = await session.initialize()
1240+
# if get_session_id:
1241+
# session_id = get_session_id()
1242+
# if session_id:
1243+
# print(f"Session ID: {session_id}")
1244+
capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
1245+
response = await session.list_tools()
1246+
tools = response.tools
1247+
tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
1248+
tools = [ToolCreate.model_validate(tool) for tool in tools]
1249+
for tool in tools:
1250+
tool.request_type = "STREAMABLEHTTP"
12231251

1224-
# Use async with for both streamablehttp_client and ClientSession
1225-
async with streamablehttp_client(url=server_url, headers=decoded_auth) as (read_stream, write_stream, _get_session_id):
1226-
async with ClientSession(read_stream, write_stream) as session:
1227-
# Initialize the session
1228-
response = await session.initialize()
1229-
# if get_session_id:
1230-
# session_id = get_session_id()
1231-
# if session_id:
1232-
# print(f"Session ID: {session_id}")
1233-
capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
1234-
response = await session.list_tools()
1235-
tools = response.tools
1236-
tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
1237-
tools = [ToolCreate.model_validate(tool) for tool in tools]
1238-
for tool in tools:
1239-
tool.request_type = "STREAMABLEHTTP"
1240-
1241-
return capabilities, tools
1252+
return capabilities, tools
1253+
raise GatewayConnectionError(f"Failed to initialize gateway at {url}")
12421254

12431255
capabilities = {}
12441256
tools = []

mcpgateway/validators.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ class SecurityValidator:
5252
"""Configurable validation with MCP-compliant limits"""
5353

5454
# Configurable patterns (from settings)
55-
DANGEROUS_HTML_PATTERN = settings.validation_dangerous_html_pattern # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|</*(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)>'
55+
DANGEROUS_HTML_PATTERN = (
56+
settings.validation_dangerous_html_pattern
57+
) # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|</*(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)>'
5658
DANGEROUS_JS_PATTERN = settings.validation_dangerous_js_pattern # Default: javascript:|vbscript:|on\w+\s*=|data:.*script
5759
ALLOWED_URL_SCHEMES = settings.validation_allowed_url_schemes # Default: ["http://", "https://", "ws://", "wss://"]
5860

0 commit comments

Comments
 (0)