@@ -244,6 +244,37 @@ def __init__(self) -> None:
244
244
else :
245
245
self ._redis_client = None
246
246
247
+ async def _validate_gateway_url (self , url : str , headers : dict , timeout = 5 ):
248
+ """
249
+ Validate if the given URL is a live Server-Sent Events (SSE) endpoint.
250
+
251
+ This function performs a GET request followed by a HEAD request to the provided URL
252
+ to ensure the endpoint is reachable and returns a valid `Content-Type` header indicating
253
+ Server-Sent Events (`text/event-stream`).
254
+
255
+ Args:
256
+ url (str): The full URL of the endpoint to validate.
257
+ headers (dict): Headers to be included in the requests (e.g., Authorization).
258
+ timeout (int, optional): Timeout in seconds for both requests. Defaults to 5.
259
+
260
+ Returns:
261
+ bool: True if the endpoint is reachable and supports SSE (Content-Type is
262
+ 'text/event-stream'), otherwise False.
263
+ """
264
+ async with httpx .AsyncClient () as client :
265
+ timeout = httpx .Timeout (timeout )
266
+ try :
267
+ async with client .stream ("GET" , url , headers = headers , timeout = timeout ) as response :
268
+ response .raise_for_status ()
269
+ response_head = await client .request ("HEAD" , url , headers = headers , timeout = timeout )
270
+ response .raise_for_status ()
271
+ content_type = response_head .headers .get ("Content-Type" , "" )
272
+ if "text/event-stream" in content_type .lower ():
273
+ return True
274
+ return False
275
+ except Exception :
276
+ return False
277
+
247
278
async def initialize (self ) -> None :
248
279
"""Initialize the service and start health check if this instance is the leader.
249
280
@@ -830,13 +861,11 @@ async def forward_request(self, gateway: DbGateway, method: str, params: Optiona
830
861
831
862
# Update last seen timestamp
832
863
gateway .last_seen = datetime .now (timezone .utc )
833
-
834
- if "error" in result :
835
- raise GatewayError (f"Gateway error: { result ['error' ].get ('message' )} " )
836
- return result .get ("result" )
837
-
838
- except Exception as e :
839
- raise GatewayConnectionError (f"Failed to forward request to { gateway .name } : { str (e )} " )
864
+ except Exception :
865
+ raise GatewayConnectionError (f"Failed to forward request to { gateway .name } " )
866
+ if "error" in result :
867
+ raise GatewayError (f"Gateway error: { result ['error' ].get ('message' )} " )
868
+ return result .get ("result" )
840
869
841
870
async def _handle_gateway_failure (self , gateway : str ) -> None :
842
871
"""Tracks and handles gateway failures during health checks.
@@ -1158,21 +1187,23 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
1158
1187
# Store the context managers so they stay alive
1159
1188
decoded_auth = decode_auth (authentication )
1160
1189
1161
- # Use async with for both sse_client and ClientSession
1162
- async with sse_client (url = server_url , headers = decoded_auth ) as streams :
1163
- async with ClientSession (* streams ) as session :
1164
- # Initialize the session
1165
- response = await session .initialize ()
1166
- capabilities = response .capabilities .model_dump (by_alias = True , exclude_none = True )
1190
+ if await self ._validate_gateway_url (url = server_url , headers = decoded_auth ):
1191
+ # Use async with for both sse_client and ClientSession
1192
+ async with sse_client (url = server_url , headers = decoded_auth ) as streams :
1193
+ async with ClientSession (* streams ) as session :
1194
+ # Initialize the session
1195
+ response = await session .initialize ()
1196
+ capabilities = response .capabilities .model_dump (by_alias = True , exclude_none = True )
1167
1197
1168
- response = await session .list_tools ()
1169
- tools = response .tools
1170
- tools = [tool .model_dump (by_alias = True , exclude_none = True ) for tool in tools ]
1198
+ response = await session .list_tools ()
1199
+ tools = response .tools
1200
+ tools = [tool .model_dump (by_alias = True , exclude_none = True ) for tool in tools ]
1171
1201
1172
- tools = [ToolCreate .model_validate (tool ) for tool in tools ]
1173
- logger .info (f"{ tools [0 ]= } " )
1202
+ tools = [ToolCreate .model_validate (tool ) for tool in tools ]
1203
+ logger .info (f"{ tools [0 ]= } " )
1174
1204
1175
- return capabilities , tools
1205
+ return capabilities , tools
1206
+ raise GatewayConnectionError (f"Failed to initialize gateway at { url } " )
1176
1207
1177
1208
async def connect_to_streamablehttp_server (server_url : str , authentication : Optional [Dict [str , str ]] = None ):
1178
1209
"""
@@ -1217,8 +1248,8 @@ async def connect_to_streamablehttp_server(server_url: str, authentication: Opti
1217
1248
capabilities , tools = await connect_to_streamablehttp_server (url , authentication )
1218
1249
1219
1250
return capabilities , tools
1220
- except Exception as e :
1221
- raise GatewayConnectionError (f"Failed to initialize gateway at { url } : { str ( e ) } " )
1251
+ except Exception :
1252
+ raise GatewayConnectionError (f"Failed to initialize gateway at { url } " )
1222
1253
1223
1254
def _get_gateways (self , include_inactive : bool = True ) -> list [DbGateway ]:
1224
1255
"""Sync function for database operations (runs in thread).
0 commit comments