@@ -244,7 +244,7 @@ def __init__(self) -> None:
244
244
else :
245
245
self ._redis_client = None
246
246
247
- async def _validate_gateway_url (self , url : str , headers : dict , timeout = 5 ):
247
+ async def _validate_gateway_url (self , url : str , headers : dict , transport_type : str , timeout = 5 ):
248
248
"""
249
249
Validate if the given URL is a live Server-Sent Events (SSE) endpoint.
250
250
@@ -255,6 +255,7 @@ async def _validate_gateway_url(self, url: str, headers: dict, timeout=5):
255
255
Args:
256
256
url (str): The full URL of the endpoint to validate.
257
257
headers (dict): Headers to be included in the requests (e.g., Authorization).
258
+ transport_type (str): SSE or STREAMABLEHTTP
258
259
timeout (int, optional): Timeout in seconds for both requests. Defaults to 5.
259
260
260
261
Returns:
@@ -265,12 +266,22 @@ async def _validate_gateway_url(self, url: str, headers: dict, timeout=5):
265
266
timeout = httpx .Timeout (timeout )
266
267
try :
267
268
async with client .stream ("GET" , url , headers = headers , timeout = timeout ) as response :
268
- response .raise_for_status ()
269
- response_head = await client .request ("HEAD" , url , headers = headers , timeout = timeout )
270
- response .raise_for_status ()
271
- content_type = response_head .headers .get ("Content-Type" , "" )
272
- if "text/event-stream" in content_type .lower ():
273
- return True
269
+ response_headers = dict (response .headers )
270
+ location = response_headers .get ("location" )
271
+ content_type = response_headers .get ("content-type" )
272
+ if transport_type == "STREAMABLEHTTP" :
273
+ if location :
274
+ async with client .stream ("GET" , location , headers = headers , timeout = timeout ) as response_redirect :
275
+ response_headers = dict (response_redirect .headers )
276
+ mcp_session_id = response_headers .get ("mcp-session-id" )
277
+ content_type = response_headers .get ("content-type" )
278
+ if mcp_session_id is not None and mcp_session_id != "" :
279
+ if content_type is not None and content_type != "" and "application/json" in content_type :
280
+ return True
281
+
282
+ elif transport_type == "SSE" :
283
+ if content_type is not None and content_type != "" and "text/event-stream" in content_type :
284
+ return True
274
285
return False
275
286
except Exception :
276
287
return False
@@ -1187,7 +1198,7 @@ async def connect_to_sse_server(server_url: str, authentication: Optional[Dict[s
1187
1198
# Store the context managers so they stay alive
1188
1199
decoded_auth = decode_auth (authentication )
1189
1200
1190
- if await self ._validate_gateway_url (url = server_url , headers = decoded_auth ):
1201
+ if await self ._validate_gateway_url (url = server_url , headers = decoded_auth , transport_type = "SSE" ):
1191
1202
# Use async with for both sse_client and ClientSession
1192
1203
async with sse_client (url = server_url , headers = decoded_auth ) as streams :
1193
1204
async with ClientSession (* streams ) as session :
@@ -1220,25 +1231,26 @@ async def connect_to_streamablehttp_server(server_url: str, authentication: Opti
1220
1231
authentication = {}
1221
1232
# Store the context managers so they stay alive
1222
1233
decoded_auth = decode_auth (authentication )
1234
+ if await self ._validate_gateway_url (url = server_url , headers = decoded_auth , transport_type = "STREAMABLEHTTP" ):
1235
+ # Use async with for both streamablehttp_client and ClientSession
1236
+ async with streamablehttp_client (url = server_url , headers = decoded_auth ) as (read_stream , write_stream , _get_session_id ):
1237
+ async with ClientSession (read_stream , write_stream ) as session :
1238
+ # Initialize the session
1239
+ response = await session .initialize ()
1240
+ # if get_session_id:
1241
+ # session_id = get_session_id()
1242
+ # if session_id:
1243
+ # print(f"Session ID: {session_id}")
1244
+ capabilities = response .capabilities .model_dump (by_alias = True , exclude_none = True )
1245
+ response = await session .list_tools ()
1246
+ tools = response .tools
1247
+ tools = [tool .model_dump (by_alias = True , exclude_none = True ) for tool in tools ]
1248
+ tools = [ToolCreate .model_validate (tool ) for tool in tools ]
1249
+ for tool in tools :
1250
+ tool .request_type = "STREAMABLEHTTP"
1223
1251
1224
- # Use async with for both streamablehttp_client and ClientSession
1225
- async with streamablehttp_client (url = server_url , headers = decoded_auth ) as (read_stream , write_stream , _get_session_id ):
1226
- async with ClientSession (read_stream , write_stream ) as session :
1227
- # Initialize the session
1228
- response = await session .initialize ()
1229
- # if get_session_id:
1230
- # session_id = get_session_id()
1231
- # if session_id:
1232
- # print(f"Session ID: {session_id}")
1233
- capabilities = response .capabilities .model_dump (by_alias = True , exclude_none = True )
1234
- response = await session .list_tools ()
1235
- tools = response .tools
1236
- tools = [tool .model_dump (by_alias = True , exclude_none = True ) for tool in tools ]
1237
- tools = [ToolCreate .model_validate (tool ) for tool in tools ]
1238
- for tool in tools :
1239
- tool .request_type = "STREAMABLEHTTP"
1240
-
1241
- return capabilities , tools
1252
+ return capabilities , tools
1253
+ raise GatewayConnectionError (f"Failed to initialize gateway at { url } " )
1242
1254
1243
1255
capabilities = {}
1244
1256
tools = []
0 commit comments