Skip to content

Commit ae9eae2

Browse files
committed
updated validate_gateway_url - progress
Signed-off-by: Satya <[email protected]>
1 parent 8e02710 commit ae9eae2

File tree

4 files changed

+189
-31
lines changed

4 files changed

+189
-31
lines changed

.env.example

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,13 @@ MAX_PROMPT_SIZE=102400
240240
# Timeout for rendering prompt templates (in seconds)
241241
PROMPT_RENDER_TIMEOUT=10
242242

243+
#####################################
244+
# Gateway Validation
245+
#####################################
246+
247+
# Timeout for gateway validation (in seconds)
248+
GATEWAY_VALIDATION_TIMEOUT=5
249+
243250
#####################################
244251
# Health Checks
245252
#####################################

mcpgateway/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ def _parse_federation_peers(cls, v):
280280
health_check_timeout: int = 10 # seconds
281281
unhealthy_threshold: int = 5 # after this many failures, mark as Offline
282282

283+
# Validation Gateway URL
284+
gateway_validation_timeout: int = 5 # seconds
285+
283286
filelock_name: str = "gateway_service_leader.lock"
284287

285288
# Default Roots

mcpgateway/services/gateway_service.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -244,47 +244,54 @@ def __init__(self) -> None:
244244
else:
245245
self._redis_client = None
246246

247-
async def _validate_gateway_url(self, url: str, headers: dict, transport_type: str, timeout=5):
247+
async def _validate_gateway_url(self, url: str, headers: dict, transport_type: str, timeout: Optional[int] = None):
248248
"""
249249
Validate if the given URL is a live Server-Sent Events (SSE) endpoint.
250250
251-
This function performs a GET request followed by a HEAD request to the provided URL
252-
to ensure the endpoint is reachable and returns a valid `Content-Type` header indicating
253-
Server-Sent Events (`text/event-stream`).
254-
255251
Args:
256252
url (str): The full URL of the endpoint to validate.
257253
headers (dict): Headers to be included in the requests (e.g., Authorization).
258254
transport_type (str): SSE or STREAMABLEHTTP
259-
timeout (int, optional): Timeout in seconds for both requests. Defaults to 5.
255+
timeout (int, optional): Timeout in seconds. Defaults to settings.gateway_validation_timeout.
260256
261257
Returns:
262-
bool: True if the endpoint is reachable and supports SSE (Content-Type is
263-
'text/event-stream'), otherwise False.
258+
bool: True if the endpoint is reachable and supports SSE/StreamableHTTP, otherwise False.
264259
"""
265-
async with httpx.AsyncClient() as client:
266-
timeout = httpx.Timeout(timeout)
267-
try:
268-
async with client.stream("GET", url, headers=headers, timeout=timeout) as response:
269-
response_headers = dict(response.headers)
270-
location = response_headers.get("location")
271-
content_type = response_headers.get("content-type")
272-
if transport_type == "STREAMABLEHTTP":
273-
if location:
274-
async with client.stream("GET", location, headers=headers, timeout=timeout) as response_redirect:
275-
response_headers = dict(response_redirect.headers)
276-
mcp_session_id = response_headers.get("mcp-session-id")
277-
content_type = response_headers.get("content-type")
278-
if mcp_session_id is not None and mcp_session_id != "":
279-
if content_type is not None and content_type != "" and "application/json" in content_type:
280-
return True
281-
282-
elif transport_type == "SSE":
283-
if content_type is not None and content_type != "" and "text/event-stream" in content_type:
284-
return True
260+
if timeout is None:
261+
timeout = settings.gateway_validation_timeout
262+
validation_client = ResilientHttpClient(client_args={"timeout": settings.gateway_validation_timeout, "verify": not settings.skip_ssl_verify})
263+
try:
264+
async with validation_client.client.stream("GET", url, headers=headers, timeout=timeout) as response:
265+
response_headers = dict(response.headers)
266+
location = response_headers.get("location")
267+
content_type = response_headers.get("content-type")
268+
if response.status_code in (401, 403):
269+
logger.debug(f"Authentication failed for {url} with status {response.status_code}")
285270
return False
286-
except Exception:
271+
272+
if transport_type == "STREAMABLEHTTP":
273+
if location:
274+
async with validation_client.client.stream("GET", location, headers=headers, timeout=timeout) as response_redirect:
275+
response_headers = dict(response_redirect.headers)
276+
mcp_session_id = response_headers.get("mcp-session-id")
277+
content_type = response_headers.get("content-type")
278+
if response_redirect.status_code in (401, 403):
279+
logger.debug(f"Authentication failed at redirect location {location}")
280+
return False
281+
if mcp_session_id is not None and mcp_session_id != "":
282+
if content_type is not None and content_type != "" and "application/json" in content_type:
283+
return True
284+
285+
elif transport_type == "SSE":
286+
if content_type is not None and content_type != "" and "text/event-stream" in content_type:
287+
return True
287288
return False
289+
except Exception as e:
290+
print(str(e))
291+
logger.debug(f"Gateway validation failed for {url}: {str(e)}", exc_info=True)
292+
return False
293+
finally:
294+
await validation_client.aclose()
288295

289296
async def initialize(self) -> None:
290297
"""Initialize the service and start health check if this instance is the leader.
@@ -1260,7 +1267,8 @@ async def connect_to_streamablehttp_server(server_url: str, authentication: Opti
12601267
capabilities, tools = await connect_to_streamablehttp_server(url, authentication)
12611268

12621269
return capabilities, tools
1263-
except Exception:
1270+
except Exception as e:
1271+
logger.debug(f"Gateway initialization failed for {url}: {str(e)}", exc_info=True)
12641272
raise GatewayConnectionError(f"Failed to initialize gateway at {url}")
12651273

12661274
def _get_gateways(self, include_inactive: bool = True) -> list[DbGateway]:

tests/unit/mcpgateway/services/test_gateway_service.py

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from unittest.mock import AsyncMock, MagicMock, Mock
2121

2222
# Third-Party
23+
import httpx
2324
import pytest
2425

2526
# First-Party
@@ -144,7 +145,6 @@ class TestGatewayService:
144145
# ────────────────────────────────────────────────────────────────────
145146
# REGISTER
146147
# ────────────────────────────────────────────────────────────────────
147-
148148
@pytest.mark.asyncio
149149
async def test_register_gateway(self, gateway_service, test_db):
150150
"""Successful gateway registration populates DB and returns data."""
@@ -231,6 +231,146 @@ async def test_register_gateway_connection_error(self, gateway_service, test_db)
231231

232232
assert "Failed to connect" in str(exc_info.value)
233233

234+
# ────────────────────────────────────────────────────────────────────
235+
# Validate Gateway URL Timeout
236+
# ────────────────────────────────────────────────────────────────────
237+
@pytest.mark.asyncio
238+
async def test_gateway_validate_timeout(self, gateway_service, monkeypatch):
239+
# creating a mock with a timeout error
240+
mock_stream = AsyncMock(side_effect=httpx.ReadTimeout("Timeout"))
241+
242+
mock_aclose = AsyncMock()
243+
244+
# Step 3: Mock client with .stream and .aclose
245+
mock_client_instance = MagicMock()
246+
mock_client_instance.stream = mock_stream
247+
mock_client_instance.aclose = mock_aclose
248+
249+
mock_http_client = MagicMock()
250+
mock_http_client.client = mock_client_instance
251+
mock_http_client.aclose = mock_aclose
252+
253+
monkeypatch.setattr("mcpgateway.services.gateway_service.ResilientHttpClient", MagicMock(return_value=mock_http_client))
254+
255+
result = await gateway_service._validate_gateway_url(url="http://example.com", headers={}, transport_type="SSE", timeout=2)
256+
257+
assert result is False
258+
259+
# ────────────────────────────────────────────────────────────────────
260+
# Validate Gateway URL SSL Verification
261+
# ────────────────────────────────────────────────────────────────────
262+
@pytest.mark.asyncio
263+
async def test_ssl_verification_bypass(self, gateway_service, monkeypatch):
264+
# TODO
265+
pass
266+
267+
# ────────────────────────────────────────────────────────────────────
268+
# Validate Gateway URL Auth Failure
269+
# ────────────────────────────────────────────────────────────────────
270+
@pytest.mark.asyncio
271+
async def test_validate_auth_failure(self, gateway_service, monkeypatch):
272+
# Mock the response object to be returned inside the async with block
273+
response_mock = MagicMock()
274+
response_mock.status_code = 401
275+
response_mock.headers = {"content-type": "text/event-stream"}
276+
277+
# Create an async context manager mock that returns response_mock
278+
stream_context = MagicMock()
279+
stream_context.__aenter__ = AsyncMock(return_value=response_mock)
280+
stream_context.__aexit__ = AsyncMock(return_value=None)
281+
282+
# Mock the AsyncClient to return this context manager from .stream()
283+
client_mock = MagicMock()
284+
client_mock.stream = AsyncMock(return_value=stream_context)
285+
client_mock.aclose = AsyncMock()
286+
287+
# Mock ResilientHttpClient to return this client
288+
resilient_client_mock = MagicMock()
289+
resilient_client_mock.client = client_mock
290+
resilient_client_mock.aclose = AsyncMock()
291+
292+
monkeypatch.setattr("mcpgateway.services.gateway_service.ResilientHttpClient", MagicMock(return_value=resilient_client_mock))
293+
294+
# Run the method
295+
result = await gateway_service._validate_gateway_url(url="http://example.com", headers={}, transport_type="SSE")
296+
297+
# Expect False due to 401
298+
assert result is False
299+
300+
# ────────────────────────────────────────────────────────────────────
301+
# Validate Gateway URL Connection Error
302+
# ────────────────────────────────────────────────────────────────────
303+
@pytest.mark.asyncio
304+
async def test_validate_connectivity_failure(self, gateway_service, monkeypatch):
305+
# Create an async context manager mock that raises ConnectError
306+
stream_context = AsyncMock()
307+
stream_context.__aenter__.side_effect = httpx.ConnectError("connection error")
308+
stream_context.__aexit__.return_value = AsyncMock()
309+
310+
# Mock client with .stream() and .aclose()
311+
mock_client = MagicMock()
312+
mock_client.stream.return_value = stream_context
313+
mock_client.aclose = AsyncMock()
314+
315+
# Patch ResilientHttpClient to return this mock client
316+
resilient_client_mock = MagicMock()
317+
resilient_client_mock.client = mock_client
318+
resilient_client_mock.aclose = AsyncMock()
319+
320+
monkeypatch.setattr("mcpgateway.services.gateway_service.ResilientHttpClient", MagicMock(return_value=resilient_client_mock))
321+
322+
# Call the method and assert result
323+
result = await gateway_service._validate_gateway_url(url="http://example.com", headers={}, transport_type="SSE")
324+
325+
assert result is False
326+
327+
# ────────────────────────────────────────────────────────────────────
328+
# Validate Gateway URL Bulk Connections Validation
329+
# ────────────────────────────────────────────────────────────────────
330+
@pytest.mark.asyncio
331+
async def test_bulk_concurrent_validation(self, gateway_service, monkeypatch):
332+
# TODO
333+
pass
334+
335+
# ───────────────────────────────────────────────────────────────────────────
336+
# Validate Gateway - StreamableHTTP with mcp-session-id & redirected-url
337+
# ───────────────────────────────────────────────────────────────────────────
338+
@pytest.mark.asyncio
339+
async def test_streamablehttp_redirect(self, gateway_service, monkeypatch):
340+
# Mock first response (redirect)
341+
first_response = MagicMock()
342+
first_response.status_code = 200
343+
first_response.headers = {"Location": "http://sampleredirected.com"}
344+
345+
first_cm = AsyncMock()
346+
first_cm.__aenter__.return_value = first_response
347+
first_cm.__aexit__.return_value = None
348+
349+
# Mock redirected response (final)
350+
redirected_response = MagicMock()
351+
redirected_response.status_code = 200
352+
redirected_response.headers = {"Mcp-Session-Id": "sample123", "Content-Type": "application/json"}
353+
354+
second_cm = AsyncMock()
355+
second_cm.__aenter__.return_value = redirected_response
356+
second_cm.__aexit__.return_value = None
357+
358+
# Mock ResilientHttpClient client.stream to return redirect chain
359+
client_mock = MagicMock()
360+
client_mock.stream = AsyncMock(side_effect=[first_cm, second_cm])
361+
client_mock.aclose = AsyncMock()
362+
363+
resilient_http_mock = MagicMock()
364+
resilient_http_mock.client = client_mock
365+
resilient_http_mock.aclose = AsyncMock()
366+
367+
monkeypatch.setattr("mcpgateway.services.gateway_service.ResilientHttpClient", MagicMock(return_value=resilient_http_mock))
368+
369+
result = await gateway_service._validate_gateway_url(url="http://example.com", headers={}, transport_type="STREAMABLEHTTP")
370+
# TODO
371+
# assert result is True
372+
pass
373+
234374
# ────────────────────────────────────────────────────────────────────
235375
# LIST / GET
236376
# ────────────────────────────────────────────────────────────────────

0 commit comments

Comments
 (0)