Skip to content

Commit ea4e3aa

Browse files
authored
Validate RPC method for XSS (#576)
* Validate prompt arguments for XSS Signed-off-by: Madhav Kandukuri <[email protected]> * copied from check-prompt-args Signed-off-by: Madhav Kandukuri <[email protected]> * Fix tests Signed-off-by: Madhav Kandukuri <[email protected]> * flake8 fix Signed-off-by: Madhav Kandukuri <[email protected]> * Remove commented code Signed-off-by: Madhav Kandukuri <[email protected]> * Minor changes that fix smoketest Signed-off-by: Madhav Kandukuri <[email protected]> --------- Signed-off-by: Madhav Kandukuri <[email protected]>
1 parent f25ec11 commit ea4e3aa

File tree

5 files changed

+26
-20
lines changed

5 files changed

+26
-20
lines changed

mcpgateway/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ def validate_database(self) -> None:
506506
validation_safe_uri_pattern: str = r"^[a-zA-Z0-9_\-.:/?=&%]+$"
507507
validation_unsafe_uri_pattern: str = r'[<>"\'\\]'
508508
validation_tool_name_pattern: str = r"^[a-zA-Z][a-zA-Z0-9._-]*$" # MCP tool naming
509+
validation_tool_method_pattern: str = r"^[a-zA-Z][a-zA-Z0-9_\./-]*$"
509510

510511
# MCP-compliant size limits (configurable via env)
511512
validation_max_name_length: int = 255
@@ -516,6 +517,8 @@ def validate_database(self) -> None:
516517
validation_max_url_length: int = 2048
517518
validation_max_rpc_param_size: int = 262144 # 256KB
518519

520+
validation_max_method_length: int = 128
521+
519522
# Allowed MIME types
520523
validation_allowed_mime_types: List[str] = [
521524
"text/plain",

mcpgateway/main.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
ResourceCreate,
8686
ResourceRead,
8787
ResourceUpdate,
88+
RPCRequest,
8889
ServerCreate,
8990
ServerRead,
9091
ServerUpdate,
@@ -131,7 +132,6 @@
131132
from mcpgateway.utils.verify_credentials import require_auth, require_auth_override
132133
from mcpgateway.validation.jsonrpc import (
133134
JSONRPCError,
134-
validate_request,
135135
)
136136

137137
# Import the admin routes from the new module
@@ -1970,12 +1970,13 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user: str
19701970
try:
19711971
logger.debug(f"User {user} made an RPC request")
19721972
body = await request.json()
1973-
validate_request(body)
19741973
method = body["method"]
19751974
# rpc_id = body.get("id")
19761975
params = body.get("params", {})
19771976
cursor = params.get("cursor") # Extract cursor parameter
19781977

1978+
RPCRequest(jsonrpc="2.0", method=method, params=params) # Validate the request body against the RPCRequest model
1979+
19791980
if method == "tools/list":
19801981
tools = await tool_service.list_tools(db, cursor=cursor)
19811982
result = [t.model_dump(by_alias=True, exclude_none=True) for t in tools]
@@ -2030,6 +2031,8 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user: str
20302031
except JSONRPCError as e:
20312032
return e.to_dict()
20322033
except Exception as e:
2034+
if isinstance(e, ValueError):
2035+
return JSONResponse(content={"message": "Method invalid"}, status_code=422)
20332036
logger.error(f"RPC error: {str(e)}")
20342037
return {
20352038
"jsonrpc": "2.0",

mcpgateway/schemas.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2122,9 +2122,10 @@ def validate_method(cls, v: str) -> str:
21222122
Raises:
21232123
ValueError: When value is not safe
21242124
"""
2125-
if not re.match(r"^[a-zA-Z][a-zA-Z0-9_\.]*$", v):
2125+
SecurityValidator.validate_no_xss(v, "RPC method name")
2126+
if not re.match(settings.validation_tool_method_pattern, v):
21262127
raise ValueError("Invalid method name format")
2127-
if len(v) > 128: # MCP method name limit
2128+
if len(v) > settings.validation_max_method_length:
21282129
raise ValueError("Method name too long")
21292130
return v
21302131

tests/security/test_input_validation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -802,7 +802,6 @@ def test_rpc_request_validation(self):
802802
# Invalid method names
803803
invalid_methods = [
804804
"method with spaces",
805-
"method-with-dash",
806805
"method@special",
807806
"<script>alert('XSS')</script>",
808807
"9method", # Starts with number

tests/unit/mcpgateway/test_main.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,9 @@ def test_static_files(self, test_client):
239239
class TestProtocolEndpoints:
240240
"""Tests for MCP protocol operations: initialize, ping, notifications, etc."""
241241

242-
@patch("mcpgateway.main.validate_request")
242+
# @patch("mcpgateway.main.validate_request")
243243
@patch("mcpgateway.main.session_registry.handle_initialize_logic")
244-
def test_initialize_endpoint(self, mock_handle_initialize, _mock_validate, test_client, auth_headers):
244+
def test_initialize_endpoint(self, mock_handle_initialize, test_client, auth_headers):
245245
"""Test MCP protocol initialization."""
246246
mock_capabilities = ServerCapabilities(
247247
prompts={"listChanged": True},
@@ -271,8 +271,8 @@ def test_initialize_endpoint(self, mock_handle_initialize, _mock_validate, test_
271271
assert body["protocolVersion"] == PROTOCOL_VERSION
272272
mock_handle_initialize.assert_called_once()
273273

274-
@patch("mcpgateway.main.validate_request")
275-
def test_ping_endpoint(self, _mock_validate, test_client, auth_headers):
274+
# @patch("mcpgateway.main.validate_request")
275+
def test_ping_endpoint(self, test_client, auth_headers):
276276
"""Test MCP ping endpoint."""
277277
req = {"jsonrpc": "2.0", "method": "ping", "id": "test-id"}
278278
response = test_client.post("/protocol/ping", json=req, headers=auth_headers)
@@ -807,8 +807,8 @@ def test_rpc_tool_invocation(self, mock_invoke_tool, test_client, auth_headers):
807807
mock_invoke_tool.assert_called_once_with(db=ANY, name="test_tool", arguments={"param": "value"})
808808

809809
@patch("mcpgateway.main.prompt_service.get_prompt")
810-
@patch("mcpgateway.main.validate_request")
811-
def test_rpc_prompt_get(self, _mock_validate, mock_get_prompt, test_client, auth_headers):
810+
# @patch("mcpgateway.main.validate_request")
811+
def test_rpc_prompt_get(self, mock_get_prompt, test_client, auth_headers):
812812
"""Test prompt retrieval via JSON-RPC."""
813813
mock_get_prompt.return_value = {
814814
"messages": [{"role": "user", "content": {"type": "text", "text": "Rendered prompt"}}],
@@ -829,8 +829,8 @@ def test_rpc_prompt_get(self, _mock_validate, mock_get_prompt, test_client, auth
829829
mock_get_prompt.assert_called_once_with(ANY, "test_prompt", {"param": "value"})
830830

831831
@patch("mcpgateway.main.tool_service.list_tools")
832-
@patch("mcpgateway.main.validate_request")
833-
def test_rpc_list_tools(self, _mock_validate, mock_list_tools, test_client, auth_headers):
832+
# @patch("mcpgateway.main.validate_request")
833+
def test_rpc_list_tools(self, mock_list_tools, test_client, auth_headers):
834834
"""Test listing tools via JSON-RPC."""
835835
mock_tool = MagicMock()
836836
mock_tool.model_dump.return_value = MOCK_TOOL_READ
@@ -849,26 +849,26 @@ def test_rpc_list_tools(self, _mock_validate, mock_list_tools, test_client, auth
849849
assert isinstance(body, list)
850850
mock_list_tools.assert_called_once()
851851

852-
@patch("mcpgateway.main.validate_request")
853-
def test_rpc_invalid_request(self, mock_validate, test_client, auth_headers):
852+
@patch("mcpgateway.main.RPCRequest")
853+
def test_rpc_invalid_request(self, mock_rpc_request, test_client, auth_headers):
854854
"""Test RPC error handling for invalid requests."""
855-
mock_validate.side_effect = Exception("Invalid request")
855+
mock_rpc_request.side_effect = ValueError("Invalid method")
856856

857857
req = {"jsonrpc": "1.0", "id": "test-id", "method": "invalid_method"}
858858
response = test_client.post("/rpc/", json=req, headers=auth_headers)
859859

860-
assert response.status_code == 200
860+
assert response.status_code == 422
861861
body = response.json()
862-
assert "error" in body and "Invalid request" in body["error"]["data"]
862+
assert "Method invalid" in body.get("message")
863863

864864
def test_rpc_invalid_json(self, test_client, auth_headers):
865865
"""Test RPC error handling for malformed JSON."""
866866
headers = auth_headers
867867
headers["content-type"] = "application/json"
868868
response = test_client.post("/rpc/", content="invalid json", headers=headers)
869-
assert response.status_code == 200 # Returns error response, not HTTP error
869+
assert response.status_code == 422 # Returns error response, not HTTP error
870870
body = response.json()
871-
assert "error" in body
871+
assert "Method invalid" in body.get("message")
872872

873873
@patch("mcpgateway.main.logging_service.set_level")
874874
def test_set_log_level_endpoint(self, mock_set_level, test_client, auth_headers):

0 commit comments

Comments
 (0)