Skip to content

Commit 13e447a

Browse files
authored
Merge pull request #600 from TS0713/handle-invalid-gateway-url
handle-invalid-gateway-url
2 parents 45c897d + 6d69cc2 commit 13e447a

File tree

5 files changed

+306
-43
lines changed

5 files changed

+306
-43
lines changed

.env.example

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,13 @@ MAX_PROMPT_SIZE=102400
240240
# Timeout for rendering prompt templates (in seconds)
241241
PROMPT_RENDER_TIMEOUT=10
242242

243+
#####################################
244+
# Gateway Validation
245+
#####################################
246+
247+
# Timeout for gateway validation (in seconds)
248+
GATEWAY_VALIDATION_TIMEOUT=5
249+
243250
#####################################
244251
# Health Checks
245252
#####################################

mcpgateway/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ def _parse_federation_peers(cls, v):
280280
health_check_timeout: int = 10 # seconds
281281
unhealthy_threshold: int = 5 # after this many failures, mark as Offline
282282

283+
# Validation Gateway URL
284+
gateway_validation_timeout: int = 5 # seconds
285+
283286
filelock_name: str = "gateway_service_leader.lock"
284287

285288
# Default Roots

mcpgateway/services/gateway_service.py

Lines changed: 95 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,57 @@ def __init__(self) -> None:
244244
else:
245245
self._redis_client = None
246246

247+
async def _validate_gateway_url(self, url: str, headers: dict, transport_type: str, timeout: Optional[int] = None):
248+
"""
249+
Validate if the given URL is a live Server-Sent Events (SSE) endpoint.
250+
251+
Args:
252+
url (str): The full URL of the endpoint to validate.
253+
headers (dict): Headers to be included in the requests (e.g., Authorization).
254+
transport_type (str): SSE or STREAMABLEHTTP
255+
timeout (int, optional): Timeout in seconds. Defaults to settings.gateway_validation_timeout.
256+
257+
Returns:
258+
bool: True if the endpoint is reachable and supports SSE/StreamableHTTP, otherwise False.
259+
"""
260+
if timeout is None:
261+
timeout = settings.gateway_validation_timeout
262+
validation_client = ResilientHttpClient(client_args={"timeout": settings.gateway_validation_timeout, "verify": not settings.skip_ssl_verify})
263+
try:
264+
async with validation_client.client.stream("GET", url, headers=headers, timeout=timeout) as response:
265+
response_headers = dict(response.headers)
266+
location = response_headers.get("location")
267+
content_type = response_headers.get("content-type")
268+
if response.status_code in (401, 403):
269+
logger.debug(f"Authentication failed for {url} with status {response.status_code}")
270+
return False
271+
272+
if transport_type == "STREAMABLEHTTP":
273+
if location:
274+
async with validation_client.client.stream("GET", location, headers=headers, timeout=timeout) as response_redirect:
275+
response_headers = dict(response_redirect.headers)
276+
mcp_session_id = response_headers.get("mcp-session-id")
277+
content_type = response_headers.get("content-type")
278+
if response_redirect.status_code in (401, 403):
279+
logger.debug(f"Authentication failed at redirect location {location}")
280+
return False
281+
if mcp_session_id is not None and mcp_session_id != "":
282+
if content_type is not None and content_type != "" and "application/json" in content_type:
283+
return True
284+
285+
elif transport_type == "SSE":
286+
if content_type is not None and content_type != "" and "text/event-stream" in content_type:
287+
return True
288+
return False
289+
except httpx.UnsupportedProtocol as e:
290+
logger.debug(f"Gateway URL Unsupported Protocol for {url}: {str(e)}", exc_info=True)
291+
return False
292+
except Exception as e:
293+
logger.debug(f"Gateway validation failed for {url}: {str(e)}", exc_info=True)
294+
return False
295+
finally:
296+
await validation_client.aclose()
297+
247298
async def initialize(self) -> None:
248299
"""Initialize the service and start health check if this instance is the leader.
249300
@@ -844,13 +895,11 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
844895

845896
# Update last seen timestamp
846897
gateway.last_seen = datetime.now(timezone.utc)
847-
848-
if "error" in result:
849-
raise GatewayError(f"Gateway error: {result['error'].get('message')}")
850-
return result.get("result")
851-
852-
except Exception as e:
853-
raise GatewayConnectionError(f"Failed to forward request to {gateway.name}: {str(e)}")
898+
except Exception:
899+
raise GatewayConnectionError(f"Failed to forward request to {gateway.name}")
900+
if "error" in result:
901+
raise GatewayError(f"Gateway error: {result['error'].get('message')}")
902+
return result.get("result")
854903

855904
async def _handle_gateway_failure(self, gateway: str) -> None:
856905
"""Tracks and handles gateway failures during health checks.
@@ -1115,9 +1164,10 @@ async def _initialize_gateway(self, url: str, authentication: Optional[Dict[str,
11151164
>>> import asyncio
11161165
>>> async def test_params():
11171166
... try:
1118-
... await service._initialize_gateway("invalid://url")
1167+
... await service._initialize_gateway("hello//")
11191168
... except Exception as e:
1120-
... return "Failed" in str(e) or "GatewayConnectionError" in str(type(e).__name__)
1169+
... return isinstance(e, GatewayConnectionError) or "Failed" in str(e)
1170+
11211171
>>> asyncio.run(test_params())
11221172
True
11231173
@@ -1172,21 +1222,23 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
11721222
# Store the context managers so they stay alive
11731223
decoded_auth = decode_auth(authentication)
11741224

1175-
# Use async with for both sse_client and ClientSession
1176-
async with sse_client(url=server_url, headers=decoded_auth) as streams:
1177-
async with ClientSession(*streams) as session:
1178-
# Initialize the session
1179-
response = await session.initialize()
1180-
capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
1225+
if await self._validate_gateway_url(url=server_url, headers=decoded_auth, transport_type="SSE"):
1226+
# Use async with for both sse_client and ClientSession
1227+
async with sse_client(url=server_url, headers=decoded_auth) as streams:
1228+
async with ClientSession(*streams) as session:
1229+
# Initialize the session
1230+
response = await session.initialize()
1231+
capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
11811232

1182-
response = await session.list_tools()
1183-
tools = response.tools
1184-
tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
1233+
response = await session.list_tools()
1234+
tools = response.tools
1235+
tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
11851236

1186-
tools = [ToolCreate.model_validate(tool) for tool in tools]
1187-
logger.info(f"{tools[0]=}")
1237+
tools = [ToolCreate.model_validate(tool) for tool in tools]
1238+
logger.info(f"{tools[0]=}")
11881239

1189-
return capabilities, tools
1240+
return capabilities, tools
1241+
raise GatewayConnectionError(f"Failed to initialize gateway at {url}")
11901242

11911243
async def connect_to_streamablehttp_server(server_url: str, authentication: Optional[Dict[str, str]] = None):
11921244
"""
@@ -1203,25 +1255,26 @@ async def connect_to_streamablehttp_server(server_url: str, authentication: Opti
12031255
authentication = {}
12041256
# Store the context managers so they stay alive
12051257
decoded_auth = decode_auth(authentication)
1206-
1207-
# Use async with for both streamablehttp_client and ClientSession
1208-
async with streamablehttp_client(url=server_url, headers=decoded_auth) as (read_stream, write_stream, _get_session_id):
1209-
async with ClientSession(read_stream, write_stream) as session:
1210-
# Initialize the session
1211-
response = await session.initialize()
1212-
# if get_session_id:
1213-
# session_id = get_session_id()
1214-
# if session_id:
1215-
# print(f"Session ID: {session_id}")
1216-
capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
1217-
response = await session.list_tools()
1218-
tools = response.tools
1219-
tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
1220-
tools = [ToolCreate.model_validate(tool) for tool in tools]
1221-
for tool in tools:
1222-
tool.request_type = "STREAMABLEHTTP"
1223-
1224-
return capabilities, tools
1258+
if await self._validate_gateway_url(url=server_url, headers=decoded_auth, transport_type="STREAMABLEHTTP"):
1259+
# Use async with for both streamablehttp_client and ClientSession
1260+
async with streamablehttp_client(url=server_url, headers=decoded_auth) as (read_stream, write_stream, _get_session_id):
1261+
async with ClientSession(read_stream, write_stream) as session:
1262+
# Initialize the session
1263+
response = await session.initialize()
1264+
# if get_session_id:
1265+
# session_id = get_session_id()
1266+
# if session_id:
1267+
# print(f"Session ID: {session_id}")
1268+
capabilities = response.capabilities.model_dump(by_alias=True, exclude_none=True)
1269+
response = await session.list_tools()
1270+
tools = response.tools
1271+
tools = [tool.model_dump(by_alias=True, exclude_none=True) for tool in tools]
1272+
tools = [ToolCreate.model_validate(tool) for tool in tools]
1273+
for tool in tools:
1274+
tool.request_type = "STREAMABLEHTTP"
1275+
1276+
return capabilities, tools
1277+
raise GatewayConnectionError(f"Failed to initialize gateway at {url}")
12251278

12261279
capabilities = {}
12271280
tools = []
@@ -1232,7 +1285,8 @@ async def connect_to_streamablehttp_server(server_url: str, authentication: Opti
12321285

12331286
return capabilities, tools
12341287
except Exception as e:
1235-
raise GatewayConnectionError(f"Failed to initialize gateway at {url}: {str(e)}")
1288+
logger.debug(f"Gateway initialization failed for {url}: {str(e)}", exc_info=True)
1289+
raise GatewayConnectionError(f"Failed to initialize gateway at {url}")
12361290

12371291
def _get_gateways(self, include_inactive: bool = True) -> list[DbGateway]:
12381292
"""Sync function for database operations (runs in thread).

mcpgateway/utils/verify_credentials.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ async def verify_jwt_token(token: str) -> dict:
145145

146146
# Log warning for non-expiring tokens
147147
if "exp" not in unverified:
148-
logger.warning("JWT token without expiration accepted. " "Consider enabling REQUIRE_TOKEN_EXPIRATION for better security. " f"Token sub: {unverified.get('sub', 'unknown')}")
148+
logger.warning(f"JWT token without expiration accepted. Consider enabling REQUIRE_TOKEN_EXPIRATION for better security. Token sub: {unverified.get('sub', 'unknown')}")
149149

150150
# Full validation
151151
options = {}

0 commit comments

Comments
 (0)