@@ -244,6 +244,57 @@ def __init__(self) -> None:
244
244
else :
245
245
self ._redis_client = None
246
246
247
+ async def _validate_gateway_url (self , url : str , headers : dict , transport_type : str , timeout : Optional [int ] = None ):
248
+ """
249
+ Validate if the given URL is a live Server-Sent Events (SSE) endpoint.
250
+
251
+ Args:
252
+ url (str): The full URL of the endpoint to validate.
253
+ headers (dict): Headers to be included in the requests (e.g., Authorization).
254
+ transport_type (str): SSE or STREAMABLEHTTP
255
+ timeout (int, optional): Timeout in seconds. Defaults to settings.gateway_validation_timeout.
256
+
257
+ Returns:
258
+ bool: True if the endpoint is reachable and supports SSE/StreamableHTTP, otherwise False.
259
+ """
260
+ if timeout is None :
261
+ timeout = settings .gateway_validation_timeout
262
+ validation_client = ResilientHttpClient (client_args = {"timeout" : settings .gateway_validation_timeout , "verify" : not settings .skip_ssl_verify })
263
+ try :
264
+ async with validation_client .client .stream ("GET" , url , headers = headers , timeout = timeout ) as response :
265
+ response_headers = dict (response .headers )
266
+ location = response_headers .get ("location" )
267
+ content_type = response_headers .get ("content-type" )
268
+ if response .status_code in (401 , 403 ):
269
+ logger .debug (f"Authentication failed for { url } with status { response .status_code } " )
270
+ return False
271
+
272
+ if transport_type == "STREAMABLEHTTP" :
273
+ if location :
274
+ async with validation_client .client .stream ("GET" , location , headers = headers , timeout = timeout ) as response_redirect :
275
+ response_headers = dict (response_redirect .headers )
276
+ mcp_session_id = response_headers .get ("mcp-session-id" )
277
+ content_type = response_headers .get ("content-type" )
278
+ if response_redirect .status_code in (401 , 403 ):
279
+ logger .debug (f"Authentication failed at redirect location { location } " )
280
+ return False
281
+ if mcp_session_id is not None and mcp_session_id != "" :
282
+ if content_type is not None and content_type != "" and "application/json" in content_type :
283
+ return True
284
+
285
+ elif transport_type == "SSE" :
286
+ if content_type is not None and content_type != "" and "text/event-stream" in content_type :
287
+ return True
288
+ return False
289
+ except httpx .UnsupportedProtocol as e :
290
+ logger .debug (f"Gateway URL Unsupported Protocol for { url } : { str (e )} " , exc_info = True )
291
+ return False
292
+ except Exception as e :
293
+ logger .debug (f"Gateway validation failed for { url } : { str (e )} " , exc_info = True )
294
+ return False
295
+ finally :
296
+ await validation_client .aclose ()
297
+
247
298
async def initialize (self ) -> None :
248
299
"""Initialize the service and start health check if this instance is the leader.
249
300
@@ -844,13 +895,11 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
844
895
845
896
# Update last seen timestamp
846
897
gateway .last_seen = datetime .now (timezone .utc )
847
-
848
- if "error" in result :
849
- raise GatewayError (f"Gateway error: { result ['error' ].get ('message' )} " )
850
- return result .get ("result" )
851
-
852
- except Exception as e :
853
- raise GatewayConnectionError (f"Failed to forward request to { gateway .name } : { str (e )} " )
898
+ except Exception :
899
+ raise GatewayConnectionError (f"Failed to forward request to { gateway .name } " )
900
+ if "error" in result :
901
+ raise GatewayError (f"Gateway error: { result ['error' ].get ('message' )} " )
902
+ return result .get ("result" )
854
903
855
904
async def _handle_gateway_failure (self , gateway : str ) -> None :
856
905
"""Tracks and handles gateway failures during health checks.
@@ -1115,9 +1164,10 @@ async def _initialize_gateway(self, url: str, authentication: Optional[Dict[str,
1115
1164
>>> import asyncio
1116
1165
>>> async def test_params():
1117
1166
... try:
1118
- ... await service._initialize_gateway("invalid://url ")
1167
+ ... await service._initialize_gateway("hello// ")
1119
1168
... except Exception as e:
1120
- ... return "Failed" in str(e) or "GatewayConnectionError" in str(type(e).__name__)
1169
+ ... return isinstance(e, GatewayConnectionError) or "Failed" in str(e)
1170
+
1121
1171
>>> asyncio.run(test_params())
1122
1172
True
1123
1173
@@ -1172,21 +1222,23 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
1172
1222
# Store the context managers so they stay alive
1173
1223
decoded_auth = decode_auth (authentication )
1174
1224
1175
- # Use async with for both sse_client and ClientSession
1176
- async with sse_client (url = server_url , headers = decoded_auth ) as streams :
1177
- async with ClientSession (* streams ) as session :
1178
- # Initialize the session
1179
- response = await session .initialize ()
1180
- capabilities = response .capabilities .model_dump (by_alias = True , exclude_none = True )
1225
+ if await self ._validate_gateway_url (url = server_url , headers = decoded_auth , transport_type = "SSE" ):
1226
+ # Use async with for both sse_client and ClientSession
1227
+ async with sse_client (url = server_url , headers = decoded_auth ) as streams :
1228
+ async with ClientSession (* streams ) as session :
1229
+ # Initialize the session
1230
+ response = await session .initialize ()
1231
+ capabilities = response .capabilities .model_dump (by_alias = True , exclude_none = True )
1181
1232
1182
- response = await session .list_tools ()
1183
- tools = response .tools
1184
- tools = [tool .model_dump (by_alias = True , exclude_none = True ) for tool in tools ]
1233
+ response = await session .list_tools ()
1234
+ tools = response .tools
1235
+ tools = [tool .model_dump (by_alias = True , exclude_none = True ) for tool in tools ]
1185
1236
1186
- tools = [ToolCreate .model_validate (tool ) for tool in tools ]
1187
- logger .info (f"{ tools [0 ]= } " )
1237
+ tools = [ToolCreate .model_validate (tool ) for tool in tools ]
1238
+ logger .info (f"{ tools [0 ]= } " )
1188
1239
1189
- return capabilities , tools
1240
+ return capabilities , tools
1241
+ raise GatewayConnectionError (f"Failed to initialize gateway at { url } " )
1190
1242
1191
1243
async def connect_to_streamablehttp_server (server_url : str , authentication : Optional [Dict [str , str ]] = None ):
1192
1244
"""
@@ -1203,25 +1255,26 @@ async def connect_to_streamablehttp_server(server_url: str, authentication: Opti
1203
1255
authentication = {}
1204
1256
# Store the context managers so they stay alive
1205
1257
decoded_auth = decode_auth (authentication )
1206
-
1207
- # Use async with for both streamablehttp_client and ClientSession
1208
- async with streamablehttp_client (url = server_url , headers = decoded_auth ) as (read_stream , write_stream , _get_session_id ):
1209
- async with ClientSession (read_stream , write_stream ) as session :
1210
- # Initialize the session
1211
- response = await session .initialize ()
1212
- # if get_session_id:
1213
- # session_id = get_session_id()
1214
- # if session_id:
1215
- # print(f"Session ID: {session_id}")
1216
- capabilities = response .capabilities .model_dump (by_alias = True , exclude_none = True )
1217
- response = await session .list_tools ()
1218
- tools = response .tools
1219
- tools = [tool .model_dump (by_alias = True , exclude_none = True ) for tool in tools ]
1220
- tools = [ToolCreate .model_validate (tool ) for tool in tools ]
1221
- for tool in tools :
1222
- tool .request_type = "STREAMABLEHTTP"
1223
-
1224
- return capabilities , tools
1258
+ if await self ._validate_gateway_url (url = server_url , headers = decoded_auth , transport_type = "STREAMABLEHTTP" ):
1259
+ # Use async with for both streamablehttp_client and ClientSession
1260
+ async with streamablehttp_client (url = server_url , headers = decoded_auth ) as (read_stream , write_stream , _get_session_id ):
1261
+ async with ClientSession (read_stream , write_stream ) as session :
1262
+ # Initialize the session
1263
+ response = await session .initialize ()
1264
+ # if get_session_id:
1265
+ # session_id = get_session_id()
1266
+ # if session_id:
1267
+ # print(f"Session ID: {session_id}")
1268
+ capabilities = response .capabilities .model_dump (by_alias = True , exclude_none = True )
1269
+ response = await session .list_tools ()
1270
+ tools = response .tools
1271
+ tools = [tool .model_dump (by_alias = True , exclude_none = True ) for tool in tools ]
1272
+ tools = [ToolCreate .model_validate (tool ) for tool in tools ]
1273
+ for tool in tools :
1274
+ tool .request_type = "STREAMABLEHTTP"
1275
+
1276
+ return capabilities , tools
1277
+ raise GatewayConnectionError (f"Failed to initialize gateway at { url } " )
1225
1278
1226
1279
capabilities = {}
1227
1280
tools = []
@@ -1232,7 +1285,8 @@ async def connect_to_streamablehttp_server(server_url: str, authentication: Opti
1232
1285
1233
1286
return capabilities , tools
1234
1287
except Exception as e :
1235
- raise GatewayConnectionError (f"Failed to initialize gateway at { url } : { str (e )} " )
1288
+ logger .debug (f"Gateway initialization failed for { url } : { str (e )} " , exc_info = True )
1289
+ raise GatewayConnectionError (f"Failed to initialize gateway at { url } " )
1236
1290
1237
1291
def _get_gateways (self , include_inactive : bool = True ) -> list [DbGateway ]:
1238
1292
"""Sync function for database operations (runs in thread).
0 commit comments