Skip to content

Commit 349185b

Browse files
committed
feat(sdk): add mcp extension
Signed-off-by: Tomas Pilar <thomas7pilar@gmail.com>
1 parent 1c8d04a commit 349185b

File tree

5 files changed

+166
-6
lines changed

5 files changed

+166
-6
lines changed

apps/agentstack-sdk-py/examples/mcp_agent.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from typing import Annotated
66

77
from a2a.types import Message
8-
from mcp import ClientSession
98

109
from agentstack_sdk.a2a.extensions.auth.oauth import OAuthExtensionServer, OAuthExtensionSpec
10+
from agentstack_sdk.a2a.extensions.mcp import MCPExtensionParams, MCPExtensionServer, MCPExtensionSpec
1111
from agentstack_sdk.a2a.extensions.services.mcp import MCPServiceExtensionServer, MCPServiceExtensionSpec
1212
from agentstack_sdk.a2a.types import RunYield
1313
from agentstack_sdk.server import Server
1414
from agentstack_sdk.server.context import RunContext
15+
from agentstack_sdk.server.mcp.session import MCPClientSession
1516

1617
server = Server()
1718

@@ -21,24 +22,34 @@ async def mcp_agent(
2122
message: Message,
2223
context: RunContext,
2324
oauth: Annotated[OAuthExtensionServer, OAuthExtensionSpec.single_demand()],
24-
mcp: Annotated[
25+
mcp: Annotated[MCPExtensionServer, MCPExtensionSpec(params=MCPExtensionParams())],
26+
mcp_service: Annotated[
2527
MCPServiceExtensionServer,
2628
MCPServiceExtensionSpec.single_demand(),
2729
],
2830
) -> AsyncGenerator[RunYield, Message]:
2931
"""Lists tools"""
3032

31-
if not mcp:
33+
if not mcp_service:
3234
yield "MCP extension hasn't been activated, no tools are available"
3335
return
3436

35-
async with mcp.create_client() as (read, write), ClientSession(read, write) as session:
37+
async with (
38+
mcp_service.create_client() as (read, write),
39+
MCPClientSession(read, write, context=context).apply(mcp) as session,
40+
):
3641
await session.initialize()
3742

38-
tools = await session.list_tools()
43+
result = await session.list_tools()
3944

4045
yield "Available tools: \n"
41-
yield "\n".join([t.name for t in tools.tools])
46+
yield "\n".join([t.name for t in result.tools])
47+
48+
if result.tools:
49+
tool = result.tools[0]
50+
yield f"Calling tool {tool.name}"
51+
await session.call_tool(tool.name, None)
52+
yield "Tool call finished"
4253

4354

4455
if __name__ == "__main__":
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from .mcp import *
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import uuid
7+
from types import NoneType
8+
from typing import TYPE_CHECKING, Any, Literal
9+
10+
import a2a.types
11+
import pydantic
12+
13+
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
14+
from agentstack_sdk.a2a.types import AgentMessage
15+
16+
if TYPE_CHECKING:
17+
from agentstack_sdk.server.context import RunContext
18+
19+
20+
class ToolCallApprovalRequest(pydantic.BaseModel):
21+
name: str
22+
arguments: dict[str, Any] | None = None
23+
24+
25+
class ToolCallApprovalResponse(pydantic.BaseModel):
26+
action: Literal["accept", "reject"]
27+
28+
29+
class MCPExtensionParams(pydantic.BaseModel):
30+
pass
31+
32+
33+
class MCPExtensionSpec(BaseExtensionSpec[MCPExtensionParams]):
34+
URI: str = "https://a2a-extensions.agentstack.beeai.dev/mcp/v1"
35+
36+
37+
class MCPExtensionMetadata(pydantic.BaseModel):
38+
pass
39+
40+
41+
class MCPExtensionServer(BaseExtensionServer[MCPExtensionSpec, MCPExtensionMetadata]):
42+
def create_message(self, *, request: ToolCallApprovalRequest):
43+
return AgentMessage(
44+
text="Tool call approval requested", metadata={self.spec.URI: request.model_dump(mode="json")}
45+
)
46+
47+
def parse_message(self, *, message: a2a.types.Message):
48+
if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):
49+
raise RuntimeError("Invalid mcp response")
50+
return ToolCallApprovalResponse.model_validate(data)
51+
52+
async def raise_tool_approval(self, request: ToolCallApprovalRequest, context: RunContext):
53+
message = self.create_message(request=request)
54+
message = await context.yield_async(message)
55+
if message:
56+
result = self.parse_message(message=message)
57+
if result.action != "accept":
58+
raise RuntimeError("User has rejected the tool call")
59+
else:
60+
raise RuntimeError("Tool call approval response is missing")
61+
62+
63+
class MCPExtensionClient(BaseExtensionClient[MCPExtensionSpec, NoneType]):
64+
def create_message(self, *, response: ToolCallApprovalResponse, task_id: str | None):
65+
return a2a.types.Message(
66+
message_id=str(uuid.uuid4()),
67+
role=a2a.types.Role.user,
68+
parts=[],
69+
task_id=task_id,
70+
metadata={self.spec.URI: response.model_dump(mode="json")},
71+
)
72+
73+
def parse_message(self, *, message: a2a.types.Message):
74+
if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):
75+
raise ValueError("Invalid mcp request")
76+
return ToolCallApprovalRequest.model_validate(data)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from datetime import timedelta
7+
from typing import Any, Self
8+
9+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10+
from mcp import ClientSession
11+
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
12+
from mcp.shared.message import SessionMessage
13+
from mcp.shared.session import ProgressFnT
14+
from mcp.types import CallToolResult, Implementation
15+
16+
from agentstack_sdk.a2a.extensions.mcp import MCPExtensionServer, ToolCallApprovalRequest
17+
from agentstack_sdk.server.context import RunContext
18+
19+
20+
class MCPClientSession(ClientSession):
21+
def __init__(
22+
self,
23+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
24+
write_stream: MemoryObjectSendStream[SessionMessage],
25+
read_timeout_seconds: timedelta | None = None,
26+
sampling_callback: SamplingFnT | None = None,
27+
elicitation_callback: ElicitationFnT | None = None,
28+
list_roots_callback: ListRootsFnT | None = None,
29+
logging_callback: LoggingFnT | None = None,
30+
message_handler: MessageHandlerFnT | None = None,
31+
client_info: Implementation | None = None,
32+
*,
33+
context: RunContext,
34+
) -> None:
35+
super().__init__(
36+
read_stream,
37+
write_stream,
38+
read_timeout_seconds,
39+
sampling_callback,
40+
elicitation_callback,
41+
list_roots_callback,
42+
logging_callback,
43+
message_handler,
44+
client_info,
45+
)
46+
self._context = context
47+
self._mcp_extension = None
48+
49+
def apply(self, extension: MCPExtensionServer) -> Self:
50+
self._mcp_extension = extension
51+
return self
52+
53+
async def call_tool(
54+
self,
55+
name: str,
56+
arguments: dict[str, Any] | None = None,
57+
read_timeout_seconds: timedelta | None = None,
58+
progress_callback: ProgressFnT | None = None,
59+
) -> CallToolResult:
60+
"""Send a tools/call request with optional progress callback support."""
61+
62+
if self._mcp_extension:
63+
await self._mcp_extension.raise_tool_approval(
64+
request=ToolCallApprovalRequest(name=name, arguments=arguments), context=self._context
65+
)
66+
67+
return await super().call_tool(name, arguments, read_timeout_seconds, progress_callback)

0 commit comments

Comments
 (0)