Skip to content

Commit 9ecd07d

Browse files
authored
Enhanced Validation Missing in GatewayCreate closes #694 (#695)
Signed-off-by: Mihai Criveti <[email protected]>
1 parent 3be2628 commit 9ecd07d

File tree

2 files changed

+127
-31
lines changed

2 files changed

+127
-31
lines changed

mcpgateway/schemas.py

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,20 +1818,7 @@ def create_auth_value(cls, v, info):
18181818
if (auth_type is None) or (auth_type == ""):
18191819
return v # If no auth_type is provided, no need to create auth_value
18201820

1821-
# If custom headers, use all headers
1822-
if auth_type == "authheaders":
1823-
auth_headers = data.get("auth_headers")
1824-
if auth_headers and isinstance(auth_headers, list):
1825-
# Convert list of {key, value} to dict
1826-
header_dict = {h["key"]: h["value"] for h in auth_headers if h.get("key")}
1827-
return encode_auth(header_dict)
1828-
# Fallback to old single key/value
1829-
header_key = data.get("auth_header_key")
1830-
header_value = data.get("auth_header_value")
1831-
if header_key and header_value:
1832-
return encode_auth({header_key: header_value})
1833-
1834-
# Otherwise, use the default logic
1821+
# Process the auth fields and generate auth_value based on auth_type
18351822
auth_value = cls._process_auth_fields(info)
18361823
return auth_value
18371824

@@ -1902,10 +1889,44 @@ def _process_auth_fields(info: ValidationInfo) -> Optional[Dict[str, Any]]:
19021889
# Support both new multi-headers format and legacy single header format
19031890
auth_headers = data.get("auth_headers")
19041891
if auth_headers and isinstance(auth_headers, list):
1905-
# New multi-headers format
1906-
header_dict = {h["key"]: h["value"] for h in auth_headers if h.get("key")}
1892+
# New multi-headers format with enhanced validation
1893+
header_dict = {}
1894+
duplicate_keys = set()
1895+
1896+
for header in auth_headers:
1897+
if not isinstance(header, dict):
1898+
continue
1899+
1900+
key = header.get("key")
1901+
value = header.get("value", "")
1902+
1903+
# Skip headers without keys
1904+
if not key:
1905+
continue
1906+
1907+
# Track duplicate keys (last value wins)
1908+
if key in header_dict:
1909+
duplicate_keys.add(key)
1910+
1911+
# Validate header key format (basic HTTP header validation)
1912+
if not all(c.isalnum() or c in "-_" for c in key.replace(" ", "")):
1913+
raise ValueError(f"Invalid header key format: '{key}'. Header keys should contain only alphanumeric characters, hyphens, and underscores.")
1914+
1915+
# Store header (empty values are allowed)
1916+
header_dict[key] = value
1917+
1918+
# Ensure at least one valid header
19071919
if not header_dict:
1908-
raise ValueError("For 'headers' auth, at least one header must be provided.")
1920+
raise ValueError("For 'headers' auth, at least one valid header with a key must be provided.")
1921+
1922+
# Warn about duplicate keys (optional - could log this instead)
1923+
if duplicate_keys:
1924+
logging.warning(f"Duplicate header keys detected (last value used): {', '.join(duplicate_keys)}")
1925+
1926+
# Check for excessive headers (prevent abuse)
1927+
if len(header_dict) > 100:
1928+
raise ValueError("Maximum of 100 headers allowed per gateway.")
1929+
19091930
return encode_auth(header_dict)
19101931

19111932
# Legacy single header format (backward compatibility)
@@ -2027,17 +2048,7 @@ def create_auth_value(cls, v, info):
20272048
if (auth_type is None) or (auth_type == ""):
20282049
return v # If no auth_type is provided, no need to create auth_value
20292050

2030-
# If custom headers, use all headers
2031-
if auth_type == "authheaders":
2032-
auth_headers = data.get("auth_headers")
2033-
if auth_headers and isinstance(auth_headers, list):
2034-
header_dict = {h["key"]: h["value"] for h in auth_headers if h.get("key")}
2035-
return encode_auth(header_dict)
2036-
header_key = data.get("auth_header_key")
2037-
header_value = data.get("auth_header_value")
2038-
if header_key and header_value:
2039-
return encode_auth({header_key: header_value})
2040-
2051+
# Process the auth fields and generate auth_value based on auth_type
20412052
auth_value = cls._process_auth_fields(info)
20422053
return auth_value
20432054

tests/unit/mcpgateway/test_multi_auth_headers.py

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,25 @@ async def test_gateway_update_add_multi_headers(self):
111111
assert decoded["X-New-Header"] == "new_value"
112112

113113
@pytest.mark.asyncio
114-
async def test_special_characters_in_headers(self):
115-
"""Test headers with special characters."""
114+
async def test_special_characters_in_headers_rejected(self):
115+
"""Test headers with invalid special characters are rejected."""
116116
auth_headers = [{"key": "X-Special-!@#", "value": "value-with-特殊字符"}, {"key": "Content-Type", "value": "application/json; charset=utf-8"}]
117117

118+
with pytest.raises(ValidationError) as exc_info:
119+
GatewayCreate(name="Test Gateway", url="http://example.com", auth_type="authheaders", auth_headers=auth_headers)
120+
121+
assert "Invalid header key format" in str(exc_info.value)
122+
assert "X-Special-!@#" in str(exc_info.value)
123+
124+
@pytest.mark.asyncio
125+
async def test_valid_special_characters_in_values(self):
126+
"""Test headers with special characters in values (allowed) but valid keys."""
127+
auth_headers = [{"key": "X-Special-Header", "value": "value-with-特殊字符"}, {"key": "Content-Type", "value": "application/json; charset=utf-8"}]
128+
118129
gateway = GatewayCreate(name="Test Gateway", url="http://example.com", auth_type="authheaders", auth_headers=auth_headers)
119130

120131
decoded = decode_auth(gateway.auth_value)
121-
assert decoded["X-Special-!@#"] == "value-with-特殊字符"
132+
assert decoded["X-Special-Header"] == "value-with-特殊字符"
122133
assert decoded["Content-Type"] == "application/json; charset=utf-8"
123134

124135
@pytest.mark.asyncio
@@ -169,3 +180,77 @@ async def test_authorization_header_in_multi_headers(self):
169180
decoded = decode_auth(gateway.auth_value)
170181
assert decoded["Authorization"] == "Bearer token123"
171182
assert decoded["X-API-Key"] == "secret"
183+
184+
@pytest.mark.asyncio
185+
async def test_gateway_create_invalid_header_key_format(self):
186+
"""Test creating gateway with invalid header key format."""
187+
auth_headers = [{"key": "Invalid@Key!", "value": "secret123"}]
188+
189+
with pytest.raises(ValidationError) as exc_info:
190+
GatewayCreate(name="Test Gateway", url="http://example.com", auth_type="authheaders", auth_headers=auth_headers)
191+
192+
assert "Invalid header key format" in str(exc_info.value)
193+
194+
@pytest.mark.asyncio
195+
async def test_gateway_create_excessive_headers(self):
196+
"""Test creating gateway with more than 100 headers."""
197+
auth_headers = [{"key": f"X-Header-{i}", "value": f"value-{i}"} for i in range(101)]
198+
199+
with pytest.raises(ValidationError) as exc_info:
200+
GatewayCreate(name="Test Gateway", url="http://example.com", auth_type="authheaders", auth_headers=auth_headers)
201+
202+
assert "Maximum of 100 headers allowed" in str(exc_info.value)
203+
204+
@pytest.mark.asyncio
205+
async def test_gateway_create_duplicate_keys_with_warning(self, caplog):
206+
"""Test creating gateway with duplicate header keys logs warning."""
207+
auth_headers = [
208+
{"key": "X-API-Key", "value": "first_value"},
209+
{"key": "X-API-Key", "value": "second_value"}, # Duplicate
210+
{"key": "X-Client-ID", "value": "client123"}
211+
]
212+
213+
gateway = GatewayCreate(name="Test Gateway", url="http://example.com", auth_type="authheaders", auth_headers=auth_headers)
214+
215+
# Check that duplicate warning was logged
216+
assert "Duplicate header keys detected" in caplog.text
217+
assert "X-API-Key" in caplog.text
218+
219+
# Check that last value wins
220+
decoded = decode_auth(gateway.auth_value)
221+
assert decoded["X-API-Key"] == "second_value"
222+
assert decoded["X-Client-ID"] == "client123"
223+
224+
@pytest.mark.asyncio
225+
async def test_gateway_create_mixed_valid_invalid_keys(self):
226+
"""Test creating gateway with mixed valid and invalid header keys."""
227+
auth_headers = [
228+
{"key": "Valid-Header", "value": "test123"},
229+
{"key": "Invalid@Key!", "value": "should_fail"} # This should fail validation
230+
]
231+
232+
with pytest.raises(ValidationError) as exc_info:
233+
GatewayCreate(name="Test Gateway", url="http://example.com", auth_type="authheaders", auth_headers=auth_headers)
234+
235+
assert "Invalid header key format" in str(exc_info.value)
236+
assert "Invalid@Key!" in str(exc_info.value)
237+
238+
@pytest.mark.asyncio
239+
async def test_gateway_create_edge_case_header_keys(self):
240+
"""Test creating gateway with edge case header keys."""
241+
# Test valid edge cases
242+
auth_headers = [
243+
{"key": "X-API-Key", "value": "test1"}, # Standard format
244+
{"key": "X_API_KEY", "value": "test2"}, # Underscores allowed
245+
{"key": "API-Key-123", "value": "test3"}, # Numbers and hyphens
246+
{"key": "UPPERCASE", "value": "test4"}, # Uppercase
247+
{"key": "lowercase", "value": "test5"} # Lowercase
248+
]
249+
250+
gateway = GatewayCreate(name="Test Gateway", url="http://example.com", auth_type="authheaders", auth_headers=auth_headers)
251+
252+
decoded = decode_auth(gateway.auth_value)
253+
assert len(decoded) == 5
254+
assert decoded["X-API-Key"] == "test1"
255+
assert decoded["X_API_KEY"] == "test2"
256+
assert decoded["API-Key-123"] == "test3"

0 commit comments

Comments
 (0)