Skip to content

Commit 91bb2cf

Browse files
committed
validates gateways sse URL, if not valid raises exception
Signed-off-by: Satya <[email protected]>
1 parent 2c22355 commit 91bb2cf

File tree

2 files changed

+53
-24
lines changed

2 files changed

+53
-24
lines changed

mcpgateway/services/gateway_service.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,37 @@ def __init__(self) -> None:
244244
else:
245245
self._redis_client = None
246246

247+
async def _validate_gateway_url(self, url: str, headers: dict, timeout=5):
248+
"""
249+
Validate if the given URL is a live Server-Sent Events (SSE) endpoint.
250+
251+
This function performs a GET request followed by a HEAD request to the provided URL
252+
to ensure the endpoint is reachable and returns a valid `Content-Type` header indicating
253+
Server-Sent Events (`text/event-stream`).
254+
255+
Args:
256+
url (str): The full URL of the endpoint to validate.
257+
headers (dict): Headers to be included in the requests (e.g., Authorization).
258+
timeout (int, optional): Timeout in seconds for both requests. Defaults to 5.
259+
260+
Returns:
261+
bool: True if the endpoint is reachable and supports SSE (Content-Type is
262+
'text/event-stream'), otherwise False.
263+
"""
264+
async with httpx.AsyncClient() as client:
265+
timeout = httpx.Timeout(timeout)
266+
try:
267+
async with client.stream("GET", url, headers=headers, timeout=timeout) as response:
268+
response.raise_for_status()
269+
response_head = await client.request("HEAD", url, headers=headers, timeout=timeout)
270+
response.raise_for_status()
271+
content_type = response_head.headers.get("Content-Type", "")
272+
if "text/event-stream" in content_type.lower():
273+
return True
274+
return False
275+
except Exception:
276+
return False
277+
247278
async def initialize(self) -> None:
248279
"""Initialize the service and start health check if this instance is the leader.
249280
@@ -830,13 +861,11 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
830861

831862
# Update last seen timestamp
832863
gateway.last_seen = datetime.now(timezone.utc)
833-
834-
if "error" in result:
835-
raise GatewayError(f"Gateway error: {result['error'].get('message')}")
836-
return result.get("result")
837-
838-
except Exception as e:
839-
raise GatewayConnectionError(f"Failed to forward request to {gateway.name}: {str(e)}")
864+
except Exception:
865+
raise GatewayConnectionError(f"Failed to forward request to {gateway.name}")
866+
if "error" in result:
867+
raise GatewayError(f"Gateway error: {result['error'].get('message')}")
868+
return result.get("result")
840869

841870
async def _handle_gateway_failure(self, gateway: str) -> None:
842871
"""Tracks and handles gateway failures during health checks.
@@ -1158,21 +1187,23 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
11581187
# Store the context managers so they stay alive
11591188
decoded_auth = decode_auth(authentication)
11601189

1161-
# Use async with for both sse_client and ClientSession
1162-
async with sse_client(url=server_url, headers=decoded_auth) as streams:
1163-
async with ClientSession(*streams) as session:
1164-
# Initialize the session
1165-
response = await session.initialize()
1166-
capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
1190+
if await self._validate_gateway_url(url=server_url, headers=decoded_auth):
1191+
# Use async with for both sse_client and ClientSession
1192+
async with sse_client(url=server_url, headers=decoded_auth) as streams:
1193+
async with ClientSession(*streams) as session:
1194+
# Initialize the session
1195+
response = await session.initialize()
1196+
capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
11671197

1168-
response = await session.list_tools()
1169-
tools = response.tools
1170-
tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
1198+
response = await session.list_tools()
1199+
tools = response.tools
1200+
tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
11711201

1172-
tools = [ToolCreate.model_validate(tool) for tool in tools]
1173-
logger.info(f"{tools[0]=}")
1202+
tools = [ToolCreate.model_validate(tool) for tool in tools]
1203+
logger.info(f"{tools[0]=}")
11741204

1175-
return capabilities, tools
1205+
return capabilities, tools
1206+
raise GatewayConnectionError(f"Failed to initialize gateway at {url}")
11761207

11771208
async def connect_to_streamablehttp_server(server_url: str, authentication: Optional[Dict[str, str]] = None):
11781209
"""
@@ -1217,8 +1248,8 @@ async def connect_to_streamablehttp_server(server_url: str, authentication: Opti
12171248
capabilities, tools = await connect_to_streamablehttp_server(url, authentication)
12181249

12191250
return capabilities, tools
1220-
except Exception as e:
1221-
raise GatewayConnectionError(f"Failed to initialize gateway at {url}: {str(e)}")
1251+
except Exception:
1252+
raise GatewayConnectionError(f"Failed to initialize gateway at {url}")
12221253

12231254
def _get_gateways(self, include_inactive: bool = True) -> list[DbGateway]:
12241255
"""Sync function for database operations (runs in thread).

mcpgateway/validators.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,7 @@ class SecurityValidator:
5252
"""Configurable validation with MCP-compliant limits"""
5353

5454
# Configurable patterns (from settings)
55-
DANGEROUS_HTML_PATTERN = (
56-
settings.validation_dangerous_html_pattern
57-
) # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|</*(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)>'
55+
DANGEROUS_HTML_PATTERN = settings.validation_dangerous_html_pattern # Default: '<(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)\b|</*(script|iframe|object|embed|link|meta|base|form|img|svg|video|audio|source|track|area|map|canvas|applet|frame|frameset|html|head|body|style)>'
5856
DANGEROUS_JS_PATTERN = settings.validation_dangerous_js_pattern # Default: javascript:|vbscript:|on\w+\s*=|data:.*script
5957
ALLOWED_URL_SCHEMES = settings.validation_allowed_url_schemes # Default: ["http://", "https://", "ws://", "wss://"]
6058

0 commit comments

Comments
 (0)