Skip to content

Commit 10dfec9

Browse files
authored
753 fix tool invocation invalid method (#754)
* Fix tool invocation 'Invalid method' error with backward compatibility (#753) - Add backward compatibility for direct tool invocation (pre-PR #746 format) - Support both old format (method=tool_name) and new format (method=tools/call) - Add comprehensive test coverage for RPC tool invocation scenarios - Ensure graceful fallback to gateway forwarding when method is not a tool The RPC endpoint now handles tool invocations in both formats: 1. New format: method='tools/call' with name and arguments in params 2. Old format: method='tool_name' with params as arguments (backward compat) This maintains compatibility with existing clients while supporting the new standardized RPC method structure introduced in PR #746. Signed-off-by: Mihai Criveti <[email protected]> * Fix flake8 E722: Replace bare except with Exception Signed-off-by: Mihai Criveti <[email protected]> * lint Signed-off-by: Mihai Criveti <[email protected]> --------- Signed-off-by: Mihai Criveti <[email protected]>
1 parent e7dcd88 commit 10dfec9

File tree

4 files changed

+365
-9
lines changed

4 files changed

+365
-9
lines changed

mcpgateway/main.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2318,7 +2318,22 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user: str
23182318
elif method.startswith("logging/"):
23192319
result = {}
23202320
else:
2321-
raise JSONRPCError(-32000, "Invalid method", params)
2321+
# Backward compatibility: Try to invoke as a tool directly
2322+
# This allows both old format (method=tool_name) and new format (method=tools/call)
2323+
headers = {k.lower(): v for k, v in request.headers.items()}
2324+
try:
2325+
result = await tool_service.invoke_tool(db=db, name=method, arguments=params, request_headers=headers)
2326+
if hasattr(result, "model_dump"):
2327+
result = result.model_dump(by_alias=True, exclude_none=True)
2328+
except (ValueError, Exception):
2329+
# If not a tool, try forwarding to gateway
2330+
try:
2331+
result = await gateway_service.forward_request(db, method, params)
2332+
if hasattr(result, "model_dump"):
2333+
result = result.model_dump(by_alias=True, exclude_none=True)
2334+
except Exception:
2335+
# If all else fails, return invalid method error
2336+
raise JSONRPCError(-32000, "Invalid method", params)
23222337

23232338
return {"jsonrpc": "2.0", "result": result, "id": req_id}
23242339

tests/unit/mcpgateway/cache/test_session_registry.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ async def __aenter__(self):
478478

479479
async def __aexit__(self, exc_type, exc_val, exc_tb):
480480
return None
481-
481+
482482
with patch(
483483
"mcpgateway.cache.session_registry.ResilientHttpClient",
484484
MockAsyncClient
@@ -523,7 +523,7 @@ async def __aenter__(self):
523523

524524
async def __aexit__(self, exc_type, exc_val, exc_tb):
525525
return None
526-
526+
527527
with patch(
528528
"mcpgateway.cache.session_registry.ResilientHttpClient",
529529
MockAsyncClient
@@ -562,7 +562,7 @@ async def __aenter__(self):
562562

563563
async def __aexit__(self, exc_type, exc_val, exc_tb):
564564
return None
565-
565+
566566
with patch(
567567
"mcpgateway.cache.session_registry.ResilientHttpClient",
568568
MockAsyncClient
@@ -604,7 +604,7 @@ async def __aenter__(self):
604604

605605
async def __aexit__(self, exc_type, exc_val, exc_tb):
606606
return None
607-
607+
608608
with patch(
609609
"mcpgateway.cache.session_registry.ResilientHttpClient",
610610
MockAsyncClient
@@ -645,7 +645,7 @@ async def __aenter__(self):
645645

646646
async def __aexit__(self, exc_type, exc_val, exc_tb):
647647
return None
648-
648+
649649
with patch(
650650
"mcpgateway.cache.session_registry.ResilientHttpClient",
651651
MockAsyncClient
@@ -725,7 +725,7 @@ async def __aenter__(self):
725725

726726
async def __aexit__(self, exc_type, exc_val, exc_tb):
727727
return None
728-
728+
729729
with patch(
730730
"mcpgateway.cache.session_registry.ResilientHttpClient",
731731
MockAsyncClient
@@ -766,7 +766,7 @@ async def __aenter__(self):
766766

767767
async def __aexit__(self, exc_type, exc_val, exc_tb):
768768
return None
769-
769+
770770
with patch(
771771
"mcpgateway.cache.session_registry.ResilientHttpClient",
772772
MockAsyncClient
@@ -807,7 +807,7 @@ async def __aenter__(self):
807807

808808
async def __aexit__(self, exc_type, exc_val, exc_tb):
809809
return None
810-
810+
811811
with patch(
812812
"mcpgateway.cache.session_registry.ResilientHttpClient",
813813
MockAsyncClient
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# -*- coding: utf-8 -*-
2+
"""Test backward compatibility for tool invocation after PR #746."""
3+
4+
from unittest.mock import AsyncMock, MagicMock, patch
5+
6+
import pytest
7+
from fastapi.testclient import TestClient
8+
from sqlalchemy.orm import Session
9+
10+
from mcpgateway.main import app
11+
12+
13+
@pytest.fixture
14+
def client():
15+
"""Create a test client."""
16+
return TestClient(app)
17+
18+
19+
@pytest.fixture
20+
def mock_db():
21+
"""Create a mock database session."""
22+
return MagicMock(spec=Session)
23+
24+
25+
class TestRPCBackwardCompatibility:
26+
"""Test backward compatibility for RPC tool invocation."""
27+
28+
def test_old_format_tool_invocation_with_backward_compatibility(self, client, mock_db):
29+
"""Test that old format (direct tool name as method) still works with backward compatibility."""
30+
with patch("mcpgateway.config.settings.auth_required", False):
31+
with patch("mcpgateway.main.get_db", return_value=mock_db):
32+
with patch("mcpgateway.main.tool_service.invoke_tool", new_callable=AsyncMock) as mock_invoke:
33+
mock_invoke.return_value = {"result": "success", "data": "test data from old format"}
34+
35+
# Old format: tool name directly as method
36+
request_body = {"jsonrpc": "2.0", "method": "my_custom_tool", "params": {"query": "test query", "limit": 10}, "id": 123}
37+
38+
response = client.post("/rpc", json=request_body)
39+
40+
assert response.status_code == 200
41+
result = response.json()
42+
assert result["jsonrpc"] == "2.0"
43+
assert "result" in result
44+
assert result["result"]["result"] == "success"
45+
assert result["result"]["data"] == "test data from old format"
46+
assert result["id"] == 123
47+
48+
# Verify the tool was invoked with correct parameters
49+
mock_invoke.assert_called_once()
50+
call_args = mock_invoke.call_args
51+
assert call_args.kwargs["name"] == "my_custom_tool"
52+
assert call_args.kwargs["arguments"] == {"query": "test query", "limit": 10}
53+
54+
def test_new_format_tool_invocation_still_works(self, client, mock_db):
55+
"""Test that new format (tools/call method) continues to work."""
56+
with patch("mcpgateway.config.settings.auth_required", False):
57+
with patch("mcpgateway.main.get_db", return_value=mock_db):
58+
with patch("mcpgateway.main.tool_service.invoke_tool", new_callable=AsyncMock) as mock_invoke:
59+
mock_invoke.return_value = {"result": "success", "data": "test data from new format"}
60+
61+
# New format: tools/call method
62+
request_body = {"jsonrpc": "2.0", "method": "tools/call", "params": {"name": "my_custom_tool", "arguments": {"query": "test query", "limit": 10}}, "id": 456}
63+
64+
response = client.post("/rpc", json=request_body)
65+
66+
assert response.status_code == 200
67+
result = response.json()
68+
assert result["jsonrpc"] == "2.0"
69+
assert "result" in result
70+
assert result["result"]["result"] == "success"
71+
assert result["result"]["data"] == "test data from new format"
72+
assert result["id"] == 456
73+
74+
# Verify the tool was invoked with correct parameters
75+
mock_invoke.assert_called_once()
76+
call_args = mock_invoke.call_args
77+
assert call_args.kwargs["name"] == "my_custom_tool"
78+
assert call_args.kwargs["arguments"] == {"query": "test query", "limit": 10}
79+
80+
def test_both_formats_invoke_same_tool(self, client, mock_db):
81+
"""Test that both old and new formats can invoke the same tool successfully."""
82+
with patch("mcpgateway.config.settings.auth_required", False):
83+
with patch("mcpgateway.main.get_db", return_value=mock_db):
84+
with patch("mcpgateway.main.tool_service.invoke_tool", new_callable=AsyncMock) as mock_invoke:
85+
mock_invoke.return_value = {"result": "success"}
86+
87+
# Test old format
88+
old_format_request = {"jsonrpc": "2.0", "method": "search_tool", "params": {"query": "old format"}, "id": 1}
89+
90+
response_old = client.post("/rpc", json=old_format_request)
91+
assert response_old.status_code == 200
92+
93+
# Reset mock
94+
mock_invoke.reset_mock()
95+
96+
# Test new format
97+
new_format_request = {"jsonrpc": "2.0", "method": "tools/call", "params": {"name": "search_tool", "arguments": {"query": "new format"}}, "id": 2}
98+
99+
response_new = client.post("/rpc", json=new_format_request)
100+
assert response_new.status_code == 200
101+
102+
# Both should have invoked the tool
103+
assert mock_invoke.call_count == 1
104+
call_args = mock_invoke.call_args
105+
assert call_args.kwargs["name"] == "search_tool"
106+
assert call_args.kwargs["arguments"]["query"] == "new format"
107+
108+
def test_invalid_method_still_returns_error(self, client, mock_db):
109+
"""Test that truly invalid methods still return an error."""
110+
with patch("mcpgateway.config.settings.auth_required", False):
111+
with patch("mcpgateway.main.get_db", return_value=mock_db):
112+
with patch("mcpgateway.main.tool_service.invoke_tool", new_callable=AsyncMock) as mock_invoke:
113+
# Simulate tool not found
114+
mock_invoke.side_effect = ValueError("Tool not found")
115+
116+
with patch("mcpgateway.main.gateway_service.forward_request", new_callable=AsyncMock) as mock_forward:
117+
# Simulate gateway forward also failing
118+
mock_forward.side_effect = ValueError("Not a gateway method")
119+
120+
request_body = {"jsonrpc": "2.0", "method": "completely_invalid_method", "params": {}, "id": 999}
121+
122+
response = client.post("/rpc", json=request_body)
123+
124+
assert response.status_code == 200
125+
result = response.json()
126+
assert result["jsonrpc"] == "2.0"
127+
assert "error" in result
128+
assert result["error"]["code"] == -32000
129+
assert result["error"]["message"] == "Invalid method"
130+
assert result["id"] == 999

0 commit comments

Comments
 (0)