Skip to content

Commit e7dcd88

Browse files
authored
Adds RPC endpoints and updates RPC response and error handling (#746)
* Fix rpc endpoints Signed-off-by: Madhav Kandukuri <[email protected]> * Remove commented code Signed-off-by: Madhav Kandukuri <[email protected]> * remove duplicate code in session registry Signed-off-by: Madhav Kandukuri <[email protected]> * Linting fixes Signed-off-by: Madhav Kandukuri <[email protected]> * Fix tests Signed-off-by: Madhav Kandukuri <[email protected]> --------- Signed-off-by: Madhav Kandukuri <[email protected]>
1 parent 397890b commit e7dcd88

File tree

6 files changed

+409
-166
lines changed

6 files changed

+409
-166
lines changed

mcpgateway/cache/session_registry.py

Lines changed: 34 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from mcpgateway.services.logging_service import LoggingService
6868
from mcpgateway.transports import SSETransport
6969
from mcpgateway.utils.retry_manager import ResilientHttpClient
70+
from mcpgateway.validation.jsonrpc import JSONRPCError
7071

7172
# Initialize logging service first
7273
logging_service: LoggingService = LoggingService()
@@ -1276,19 +1277,43 @@ async def generate_response(self, message: Dict[str, Any], transport: SSETranspo
12761277
>>> # Response: {}
12771278
"""
12781279
result = {}
1280+
12791281
if "method" in message and "id" in message:
1280-
method = message["method"]
1281-
params = message.get("params", {})
1282-
req_id = message["id"]
1283-
db = next(get_db())
1284-
if method == "initialize":
1285-
init_result = await self.handle_initialize_logic(params)
1286-
response = {
1282+
try:
1283+
method = message["method"]
1284+
params = message.get("params", {})
1285+
params["server_id"] = server_id
1286+
req_id = message["id"]
1287+
1288+
rpc_input = {
12871289
"jsonrpc": "2.0",
1288-
"result": init_result.model_dump(by_alias=True, exclude_none=True),
1290+
"method": method,
1291+
"params": params,
12891292
"id": req_id,
12901293
}
1291-
await transport.send_message(response)
1294+
headers = {"Authorization": f"Bearer {user['token']}", "Content-Type": "application/json"}
1295+
rpc_url = base_url + "/rpc"
1296+
async with ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) as client:
1297+
rpc_response = await client.post(
1298+
url=rpc_url,
1299+
json=rpc_input,
1300+
headers=headers,
1301+
)
1302+
result = rpc_response.json()
1303+
result = result.get("result", {})
1304+
1305+
response = {"jsonrpc": "2.0", "result": result, "id": req_id}
1306+
except JSONRPCError as e:
1307+
result = e.to_dict()
1308+
response = {"jsonrpc": "2.0", "error": result["error"], "id": req_id}
1309+
except Exception as e:
1310+
result = {"code": -32000, "message": "Internal error", "data": str(e)}
1311+
response = {"jsonrpc": "2.0", "error": result, "id": req_id}
1312+
1313+
logging.debug(f"Sending sse message:{response}")
1314+
await transport.send_message(response)
1315+
1316+
if message["method"] == "initialize":
12921317
await transport.send_message(
12931318
{
12941319
"jsonrpc": "2.0",
@@ -1309,48 +1334,3 @@ async def generate_response(self, message: Dict[str, Any], transport: SSETranspo
13091334
"params": {},
13101335
}
13111336
)
1312-
elif method == "tools/list":
1313-
if server_id:
1314-
tools = await tool_service.list_server_tools(db, server_id=server_id)
1315-
else:
1316-
tools = await tool_service.list_tools(db)
1317-
result = {"tools": [t.model_dump(by_alias=True, exclude_none=True) for t in tools]}
1318-
elif method == "resources/list":
1319-
if server_id:
1320-
resources = await resource_service.list_server_resources(db, server_id=server_id)
1321-
else:
1322-
resources = await resource_service.list_resources(db)
1323-
result = {"resources": [r.model_dump(by_alias=True, exclude_none=True) for r in resources]}
1324-
elif method == "prompts/list":
1325-
if server_id:
1326-
prompts = await prompt_service.list_server_prompts(db, server_id=server_id)
1327-
else:
1328-
prompts = await prompt_service.list_prompts(db)
1329-
result = {"prompts": [p.model_dump(by_alias=True, exclude_none=True) for p in prompts]}
1330-
elif method == "prompts/get":
1331-
prompts = await prompt_service.get_prompt(db, name=params.get("name"), arguments=params.get("arguments", {}))
1332-
result = prompts.model_dump(by_alias=True, exclude_none=True)
1333-
elif method == "ping":
1334-
result = {}
1335-
elif method == "tools/call":
1336-
rpc_input = {
1337-
"jsonrpc": "2.0",
1338-
"method": message["params"]["name"],
1339-
"params": message["params"]["arguments"],
1340-
"id": 1,
1341-
}
1342-
headers = {"Authorization": f"Bearer {user['token']}", "Content-Type": "application/json"}
1343-
rpc_url = base_url + "/rpc"
1344-
async with ResilientHttpClient(client_args={"timeout": settings.federation_timeout, "verify": not settings.skip_ssl_verify}) as client:
1345-
rpc_response = await client.post(
1346-
url=rpc_url,
1347-
json=rpc_input,
1348-
headers=headers,
1349-
)
1350-
result = rpc_response.json()
1351-
else:
1352-
result = {}
1353-
1354-
response = {"jsonrpc": "2.0", "result": result, "id": req_id}
1355-
logging.info(f"Sending sse message:{response}")
1356-
await transport.send_message(response)

mcpgateway/main.py

Lines changed: 67 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
from mcpgateway.db import Prompt as DbPrompt
6060
from mcpgateway.db import PromptMetric, refresh_slugs_on_startup, SessionLocal
6161
from mcpgateway.handlers.sampling import SamplingHandler
62-
from mcpgateway.models import InitializeRequest, InitializeResult, ListResourceTemplatesResult, LogLevel, ResourceContent, Root
62+
from mcpgateway.models import InitializeResult, ListResourceTemplatesResult, LogLevel, ResourceContent, Root
6363
from mcpgateway.observability import init_telemetry
6464
from mcpgateway.plugins import PluginManager, PluginViolationError
6565
from mcpgateway.schemas import (
@@ -2225,39 +2225,57 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user: str
22252225
logger.debug(f"User {user} made an RPC request")
22262226
body = await request.json()
22272227
method = body["method"]
2228-
# rpc_id = body.get("id")
2228+
req_id = body.get("id") if "body" in locals() else None
22292229
params = body.get("params", {})
2230+
server_id = params.get("server_id", None)
22302231
cursor = params.get("cursor") # Extract cursor parameter
22312232

22322233
RPCRequest(jsonrpc="2.0", method=method, params=params) # Validate the request body against the RPCRequest model
22332234

2234-
if method == "tools/list":
2235-
tools = await tool_service.list_tools(db, cursor=cursor)
2236-
result = [t.model_dump(by_alias=True, exclude_none=True) for t in tools]
2235+
if method == "initialize":
2236+
result = await session_registry.handle_initialize_logic(body.get("params", {}))
2237+
if hasattr(result, "model_dump"):
2238+
result = result.model_dump(by_alias=True, exclude_none=True)
2239+
elif method == "tools/list":
2240+
if server_id:
2241+
tools = await tool_service.list_server_tools(db, server_id, cursor=cursor)
2242+
else:
2243+
tools = await tool_service.list_tools(db, cursor=cursor)
2244+
result = {"tools": [t.model_dump(by_alias=True, exclude_none=True) for t in tools]}
22372245
elif method == "list_tools": # Legacy endpoint
2238-
tools = await tool_service.list_tools(db, cursor=cursor)
2239-
result = [t.model_dump(by_alias=True, exclude_none=True) for t in tools]
2240-
elif method == "initialize":
2241-
result = initialize(
2242-
InitializeRequest(
2243-
protocol_version=params.get("protocolVersion") or params.get("protocol_version", ""),
2244-
capabilities=params.get("capabilities", {}),
2245-
client_info=params.get("clientInfo") or params.get("client_info", {}),
2246-
),
2247-
user,
2248-
).model_dump(by_alias=True, exclude_none=True)
2246+
if server_id:
2247+
tools = await tool_service.list_server_tools(db, server_id, cursor=cursor)
2248+
else:
2249+
tools = await tool_service.list_tools(db, cursor=cursor)
2250+
result = {"tools": [t.model_dump(by_alias=True, exclude_none=True) for t in tools]}
22492251
elif method == "list_gateways":
22502252
gateways = await gateway_service.list_gateways(db, include_inactive=False)
2251-
result = [g.model_dump(by_alias=True, exclude_none=True) for g in gateways]
2253+
result = {"gateways": [g.model_dump(by_alias=True, exclude_none=True) for g in gateways]}
22522254
elif method == "list_roots":
22532255
roots = await root_service.list_roots()
2254-
result = [r.model_dump(by_alias=True, exclude_none=True) for r in roots]
2256+
result = {"roots": [r.model_dump(by_alias=True, exclude_none=True) for r in roots]}
22552257
elif method == "resources/list":
2256-
resources = await resource_service.list_resources(db)
2257-
result = [r.model_dump(by_alias=True, exclude_none=True) for r in resources]
2258+
if server_id:
2259+
resources = await resource_service.list_server_resources(db, server_id)
2260+
else:
2261+
resources = await resource_service.list_resources(db)
2262+
result = {"resources": [r.model_dump(by_alias=True, exclude_none=True) for r in resources]}
2263+
elif method == "resources/read":
2264+
uri = params.get("uri")
2265+
request_id = params.get("requestId", None)
2266+
if not uri:
2267+
raise JSONRPCError(-32602, "Missing resource URI in parameters", params)
2268+
result = await resource_service.read_resource(db, uri, request_id=request_id, user=user)
2269+
if hasattr(result, "model_dump"):
2270+
result = {"contents": [result.model_dump(by_alias=True, exclude_none=True)]}
2271+
else:
2272+
result = {"contents": [result]}
22582273
elif method == "prompts/list":
2259-
prompts = await prompt_service.list_prompts(db, cursor=cursor)
2260-
result = [p.model_dump(by_alias=True, exclude_none=True) for p in prompts]
2274+
if server_id:
2275+
prompts = await prompt_service.list_server_prompts(db, server_id, cursor=cursor)
2276+
else:
2277+
prompts = await prompt_service.list_prompts(db, cursor=cursor)
2278+
result = {"prompts": [p.model_dump(by_alias=True, exclude_none=True) for p in prompts]}
22612279
elif method == "prompts/get":
22622280
name = params.get("name")
22632281
arguments = params.get("arguments", {})
@@ -2269,31 +2287,52 @@ async def handle_rpc(request: Request, db: Session = Depends(get_db), user: str
22692287
elif method == "ping":
22702288
# Per the MCP spec, a ping returns an empty result.
22712289
result = {}
2272-
else:
2290+
elif method == "tools/call":
22732291
# Get request headers
22742292
headers = {k.lower(): v for k, v in request.headers.items()}
2293+
name = params.get("name")
2294+
arguments = params.get("arguments", {})
2295+
if not name:
2296+
raise JSONRPCError(-32602, "Missing tool name in parameters", params)
22752297
try:
2276-
result = await tool_service.invoke_tool(db=db, name=method, arguments=params, request_headers=headers)
2298+
result = await tool_service.invoke_tool(db=db, name=name, arguments=arguments, request_headers=headers)
22772299
if hasattr(result, "model_dump"):
22782300
result = result.model_dump(by_alias=True, exclude_none=True)
22792301
except ValueError:
22802302
result = await gateway_service.forward_request(db, method, params)
22812303
if hasattr(result, "model_dump"):
22822304
result = result.model_dump(by_alias=True, exclude_none=True)
2305+
# TODO: Implement methods
2306+
elif method == "resources/templates/list":
2307+
result = {}
2308+
elif method.startswith("roots/"):
2309+
result = {}
2310+
elif method.startswith("notifications/"):
2311+
result = {}
2312+
elif method.startswith("sampling/"):
2313+
result = {}
2314+
elif method.startswith("elicitation/"):
2315+
result = {}
2316+
elif method.startswith("completion/"):
2317+
result = {}
2318+
elif method.startswith("logging/"):
2319+
result = {}
2320+
else:
2321+
raise JSONRPCError(-32000, "Invalid method", params)
22832322

2284-
response = result
2285-
return response
2323+
return {"jsonrpc": "2.0", "result": result, "id": req_id}
22862324

22872325
except JSONRPCError as e:
2288-
return e.to_dict()
2326+
error = e.to_dict()
2327+
return {"jsonrpc": "2.0", "error": error["error"], "id": req_id}
22892328
except Exception as e:
22902329
if isinstance(e, ValueError):
22912330
return JSONResponse(content={"message": "Method invalid"}, status_code=422)
22922331
logger.error(f"RPC error: {str(e)}")
22932332
return {
22942333
"jsonrpc": "2.0",
22952334
"error": {"code": -32000, "message": "Internal error", "data": str(e)},
2296-
"id": body.get("id") if "body" in locals() else None,
2335+
"id": req_id,
22972336
}
22982337

22992338

tests/e2e/test_main_apis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,7 +1324,7 @@ async def test_rpc_ping(self, client: AsyncClient, mock_auth):
13241324

13251325
assert response.status_code == 200
13261326
result = response.json()
1327-
assert result == {} # ping returns empty result
1327+
assert result == {"jsonrpc": "2.0", "result": {}, "id": "test-123"} # ping returns empty result
13281328

13291329
async def test_rpc_list_tools(self, client: AsyncClient, mock_auth):
13301330
"""Test POST /rpc - tools/list method."""
@@ -1334,7 +1334,7 @@ async def test_rpc_list_tools(self, client: AsyncClient, mock_auth):
13341334

13351335
assert response.status_code == 200
13361336
result = response.json()
1337-
assert isinstance(result, list)
1337+
assert isinstance(result.get("result", {}).get("tools"), list)
13381338

13391339
async def test_rpc_invalid_method(self, client: AsyncClient, mock_auth):
13401340
"""Test POST /rpc with invalid method."""

tests/integration/test_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,10 @@ def test_rpc_tool_invocation_flow(
244244
"is_error": False,
245245
}
246246

247-
rpc_body = {"jsonrpc": "2.0", "id": 7, "method": "test_tool", "params": {"foo": "bar"}}
247+
rpc_body = {"jsonrpc": "2.0", "id": 7, "method": "tools/call", "params": {"name": "test_tool", "arguments": {"foo": "bar"}}}
248248
resp = test_client.post("/rpc/", json=rpc_body, headers=auth_headers)
249249
assert resp.status_code == 200
250-
assert resp.json()["content"][0]["text"] == "ok"
250+
assert resp.json()["result"]["content"][0]["text"] == "ok"
251251
mock_invoke.assert_awaited_once_with(db=ANY, name="test_tool", arguments={"foo": "bar"}, request_headers=ANY)
252252

253253
# --------------------------------------------------------------------- #

0 commit comments

Comments
 (0)