Skip to content

Commit 6ccc204

Browse files
rakduttarakdutta1crivetimihai
authored
Duplicate Gateway Registration with Equivalent URLs Bypasses Uniqueness Check (#712)
* issue 649 Signed-off-by: RAKHI DUTTA <[email protected]> * Rebased and linted Signed-off-by: Mihai Criveti <[email protected]> --------- Signed-off-by: RAKHI DUTTA <[email protected]> Signed-off-by: Mihai Criveti <[email protected]> Co-authored-by: RAKHI DUTTA <[email protected]> Co-authored-by: Mihai Criveti <[email protected]>
1 parent 36898ef commit 6ccc204

File tree

3 files changed

+104
-58
lines changed

3 files changed

+104
-58
lines changed

mcpgateway/services/gateway_service.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,10 @@
4242
from datetime import datetime, timezone
4343
import logging
4444
import os
45+
import socket
4546
import tempfile
4647
from typing import Any, AsyncGenerator, Dict, List, Optional, Set, TYPE_CHECKING
48+
from urllib.parse import urlparse, urlunparse
4749
import uuid
4850

4951
# Third-Party
@@ -248,6 +250,33 @@ def __init__(self) -> None:
248250
else:
249251
self._redis_client = None
250252

253+
@staticmethod
254+
def normalize_url(url: str) -> str:
255+
"""
256+
Normalize a URL by resolving the hostname to its IP address.
257+
258+
Args:
259+
url (str): The URL to normalize.
260+
261+
Returns:
262+
str: The normalized URL with the hostname replaced by its IP address.
263+
264+
Examples:
265+
>>> GatewayService.normalize_url('http://localhost:8080/path')
266+
'http://127.0.0.1:8080/path'
267+
"""
268+
parsed = urlparse(url)
269+
hostname = parsed.hostname
270+
try:
271+
ip = socket.gethostbyname(hostname)
272+
except Exception:
273+
ip = hostname
274+
netloc = ip
275+
if parsed.port:
276+
netloc += f":{parsed.port}"
277+
normalized = parsed._replace(netloc=netloc)
278+
return urlunparse(normalized)
279+
251280
async def _validate_gateway_url(self, url: str, headers: dict, transport_type: str, timeout: Optional[int] = None):
252281
"""
253282
Validate if the given URL is a live Server-Sent Events (SSE) endpoint.
@@ -393,6 +422,9 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
393422
gateway_id=existing_gateway.id,
394423
)
395424

425+
# Normalize the gateway URL
426+
normalized_url = self.normalize_url(gateway.url)
427+
396428
auth_type = getattr(gateway, "auth_type", None)
397429
# Support multiple custom headers
398430
auth_value = getattr(gateway, "auth_value", {})
@@ -401,13 +433,13 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
401433
header_dict = {h["key"]: h["value"] for h in gateway.auth_headers if h.get("key")}
402434
auth_value = encode_auth(header_dict) # Encode the dict for consistency
403435

404-
capabilities, tools = await self._initialize_gateway(gateway.url, auth_value, gateway.transport)
436+
capabilities, tools = await self._initialize_gateway(normalized_url, auth_value, gateway.transport)
405437

406438
tools = [
407439
DbTool(
408440
original_name=tool.name,
409441
original_name_slug=slugify(tool.name),
410-
url=gateway.url,
442+
url=normalized_url,
411443
description=tool.description,
412444
integration_type="MCP", # Gateway-discovered tools are MCP type
413445
request_type=tool.request_type,
@@ -425,7 +457,7 @@ async def register_gateway(self, db: Session, gateway: GatewayCreate) -> Gateway
425457
db_gateway = DbGateway(
426458
name=gateway.name,
427459
slug=slugify(gateway.name),
428-
url=gateway.url,
460+
url=normalized_url,
429461
description=gateway.description,
430462
tags=gateway.tags,
431463
transport=gateway.transport,
@@ -566,7 +598,8 @@ async def update_gateway(self, db: Session, gateway_id: str, gateway_update: Gat
566598
gateway.name = gateway_update.name
567599
gateway.slug = slugify(gateway_update.name)
568600
if gateway_update.url is not None:
569-
gateway.url = gateway_update.url
601+
# Normalize the updated URL
602+
gateway.url = self.normalize_url(gateway_update.url)
570603
if gateway_update.description is not None:
571604
gateway.description = gateway_update.description
572605
if gateway_update.transport is not None:

tests/unit/mcpgateway/services/test_gateway_service.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import asyncio
2020
from datetime import datetime, timezone
2121
from unittest.mock import AsyncMock, MagicMock, Mock, patch
22+
import socket
2223

2324
# Third-Party
2425
import httpx
@@ -172,12 +173,13 @@ async def test_register_gateway(self, gateway_service, test_db, monkeypatch):
172173
)
173174
)
174175
gateway_service._notify_gateway_added = AsyncMock()
175-
176+
normalize_url = lambda url: f"http://{socket.gethostbyname(url)}/gateway"
177+
url = normalize_url("example.com")
176178
# Patch GatewayRead.model_validate to return a mock with .masked()
177179
mock_model = Mock()
178180
mock_model.masked.return_value = mock_model
179181
mock_model.name = "test_gateway"
180-
mock_model.url = "http://example.com/gateway"
182+
mock_model.url = url
181183
mock_model.description = "A test gateway"
182184

183185
monkeypatch.setattr(
@@ -187,7 +189,7 @@ async def test_register_gateway(self, gateway_service, test_db, monkeypatch):
187189

188190
gateway_create = GatewayCreate(
189191
name="test_gateway",
190-
url="http://example.com/gateway",
192+
url=url,
191193
description="A test gateway",
192194
)
193195

@@ -202,9 +204,10 @@ async def test_register_gateway(self, gateway_service, test_db, monkeypatch):
202204
# `result` is the same GatewayCreate instance because we stubbed
203205
# GatewayRead.model_validate → just check its fields:
204206
assert result.name == "test_gateway"
205-
assert result.url == "http://example.com/gateway"
207+
expected_url = url
208+
assert result.url == expected_url
206209
assert result.description == "A test gateway"
207-
210+
mock_model.url = expected_url
208211
@pytest.mark.asyncio
209212
async def test_register_gateway_name_conflict(self, gateway_service, mock_gateway, test_db):
210213
"""Trying to register a gateway whose *name* already exists raises a conflict error."""
@@ -229,7 +232,6 @@ async def test_register_gateway_name_conflict(self, gateway_service, mock_gatewa
229232
async def test_register_gateway_connection_error(self, gateway_service, test_db):
230233
"""Initial connection to the remote gateway fails and the error propagates."""
231234
test_db.execute = Mock(return_value=_make_execute_result(scalar=None))
232-
233235
# _initialize_gateway blows up before any DB work happens
234236
gateway_service._initialize_gateway = AsyncMock(side_effect=GatewayConnectionError("Failed to connect"))
235237

@@ -257,28 +259,39 @@ async def test_register_gateway_with_auth(self, gateway_service, test_db, monkey
257259
test_db.commit = Mock()
258260
test_db.refresh = Mock()
259261

262+
#url = f"http://{socket.gethostbyname('example.com')}/gateway"
263+
normalize_url = lambda url: f"http://{socket.gethostbyname(url)}/gateway"
264+
url = normalize_url("example.com")
265+
print(f"url:{url}")
260266
gateway_service._initialize_gateway = AsyncMock(
261267
return_value=(
262268
{
263-
"prompts": {"listChanged": True},
264269
"resources": {"listChanged": True},
265270
"tools": {"listChanged": True},
266271
},
267272
[],
268273
)
269274
)
275+
270276
gateway_service._notify_gateway_added = AsyncMock()
271277

272278
mock_model = Mock()
273279
mock_model.masked.return_value = mock_model
274280
mock_model.name = "auth_gateway"
281+
mock_model.url = url
275282

276283
monkeypatch.setattr(
277284
"mcpgateway.services.gateway_service.GatewayRead.model_validate",
278285
lambda x: mock_model,
279286
)
280287

281-
gateway_create = GatewayCreate(name="auth_gateway", url="http://example.com/gateway", description="Gateway with auth", auth_type="bearer", auth_token="test-token")
288+
gateway_create = GatewayCreate(
289+
name="auth_gateway",
290+
url=url,
291+
description="Gateway with auth",
292+
auth_type="bearer",
293+
auth_token="test-token"
294+
)
282295

283296
await gateway_service.register_gateway(test_db, gateway_create)
284297

@@ -973,16 +986,16 @@ async def test_update_gateway_url_change_with_tools(self, gateway_service, mock_
973986

974987
gateway_service._initialize_gateway = AsyncMock(return_value=({"tools": {"listChanged": True}}, new_tools))
975988
gateway_service._notify_gateway_updated = AsyncMock()
976-
977-
gateway_update = GatewayUpdate(url="http://example.com/new-url")
989+
url = GatewayService.normalize_url("http://example.com/new-url")
990+
gateway_update = GatewayUpdate(url=url)
978991

979992
mock_gateway_read = MagicMock()
980993
mock_gateway_read.masked.return_value = mock_gateway_read
981994

982995
with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=mock_gateway_read):
983996
await gateway_service.update_gateway(test_db, 1, gateway_update)
984997

985-
assert mock_gateway.url == "http://example.com/new-url"
998+
assert mock_gateway.url == url
986999
gateway_service._initialize_gateway.assert_called_once()
9871000
test_db.commit.assert_called_once()
9881001

@@ -997,8 +1010,8 @@ async def test_update_gateway_url_initialization_failure(self, gateway_service,
9971010
# Mock initialization failure
9981011
gateway_service._initialize_gateway = AsyncMock(side_effect=GatewayConnectionError("Connection failed"))
9991012
gateway_service._notify_gateway_updated = AsyncMock()
1000-
1001-
gateway_update = GatewayUpdate(url="http://example.com/bad-url")
1013+
url = GatewayService.normalize_url("http://example.com/bad-url")
1014+
gateway_update = GatewayUpdate(url=url)
10021015

10031016
mock_gateway_read = MagicMock()
10041017
mock_gateway_read.masked.return_value = mock_gateway_read
@@ -1007,7 +1020,7 @@ async def test_update_gateway_url_initialization_failure(self, gateway_service,
10071020
with patch("mcpgateway.services.gateway_service.GatewayRead.model_validate", return_value=mock_gateway_read):
10081021
await gateway_service.update_gateway(test_db, 1, gateway_update)
10091022

1010-
assert mock_gateway.url == "http://example.com/bad-url"
1023+
assert mock_gateway.url == url
10111024
test_db.commit.assert_called_once()
10121025

10131026
@pytest.mark.asyncio

0 commit comments

Comments
 (0)