Skip to content

Commit 9a8799b

Browse files
committed
feat(sdk): add mcp extension
Signed-off-by: Tomas Pilar <[email protected]>
1 parent 790d061 commit 9a8799b

File tree

7 files changed

+245
-30
lines changed

7 files changed

+245
-30
lines changed

apps/agentstack-sdk-py/examples/mcp_agent.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,18 @@
55
from typing import Annotated
66

77
from a2a.types import Message
8-
from mcp import ClientSession
98

109
from agentstack_sdk.a2a.extensions.auth.oauth import OAuthExtensionServer, OAuthExtensionSpec
10+
from agentstack_sdk.a2a.extensions.mcp.tool_call import (
11+
ToolCallExtensionParams,
12+
ToolCallExtensionServer,
13+
ToolCallExtensionSpec,
14+
)
1115
from agentstack_sdk.a2a.extensions.services.mcp import MCPServiceExtensionServer, MCPServiceExtensionSpec
1216
from agentstack_sdk.a2a.types import RunYield
1317
from agentstack_sdk.server import Server
1418
from agentstack_sdk.server.context import RunContext
19+
from agentstack_sdk.server.mcp.session import MCPClientSession
1520

1621
server = Server()
1722

@@ -21,24 +26,34 @@ async def mcp_agent(
2126
message: Message,
2227
context: RunContext,
2328
oauth: Annotated[OAuthExtensionServer, OAuthExtensionSpec.single_demand()],
24-
mcp: Annotated[
29+
mcp_tool_call: Annotated[ToolCallExtensionServer, ToolCallExtensionSpec(params=ToolCallExtensionParams())],
30+
mcp_service: Annotated[
2531
MCPServiceExtensionServer,
2632
MCPServiceExtensionSpec.single_demand(),
2733
],
2834
) -> AsyncGenerator[RunYield, Message]:
2935
"""Lists tools"""
3036

31-
if not mcp:
37+
if not mcp_service:
3238
yield "MCP extension hasn't been activated, no tools are available"
3339
return
3440

35-
async with mcp.create_client() as (read, write), ClientSession(read, write) as session:
41+
async with (
42+
mcp_service.create_client() as (read, write),
43+
MCPClientSession(read, write, context=context).apply(mcp_tool_call) as session,
44+
):
3645
await session.initialize()
3746

38-
tools = await session.list_tools()
47+
result = await session.list_tools()
3948

4049
yield "Available tools: \n"
41-
yield "\n".join([t.name for t in tools.tools])
50+
yield "\n".join([t.name for t in result.tools])
51+
52+
if result.tools:
53+
tool = result.tools[0]
54+
yield f"Calling tool {tool.name}"
55+
await session.call_tool(tool.name, None)
56+
yield "Tool call finished"
4257

4358

4459
if __name__ == "__main__":

apps/agentstack-sdk-py/examples/mcp_client.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pydantic import AnyHttpUrl, AnyUrl
1313

1414
import agentstack_sdk.a2a.extensions
15+
from agentstack_sdk.a2a.extensions.mcp.tool_call import ToolCallResponse
1516

1617

1718
class OAuthHandler:
@@ -67,67 +68,83 @@ async def handler(request: web.Request) -> web.Response:
6768
async def run(base_url: str = "http://127.0.0.1:10000"):
6869
async with httpx.AsyncClient(timeout=30) as httpx_client:
6970
card = await a2a.client.A2ACardResolver(httpx_client, base_url=base_url).get_agent_card()
70-
mcp_spec = agentstack_sdk.a2a.extensions.MCPServiceExtensionSpec.from_agent_card(card)
71+
mcp_service_spec = agentstack_sdk.a2a.extensions.MCPServiceExtensionSpec.from_agent_card(card)
7172
oauth_spec = agentstack_sdk.a2a.extensions.OAuthExtensionSpec.from_agent_card(card)
73+
tool_call_spec = agentstack_sdk.a2a.extensions.ToolCallExtensionSpec.from_agent_card(card)
7274

73-
if not mcp_spec:
75+
if not mcp_service_spec:
7476
raise ValueError(f"Agent at {base_url} does not support MCP service injection")
7577
if not oauth_spec:
7678
raise ValueError(f"Agent at {base_url} does not support oAuth")
79+
if not tool_call_spec:
80+
raise ValueError(f"Agent at {base_url} does not support MCP")
7781

78-
mcp_extension_client = agentstack_sdk.a2a.extensions.MCPServiceExtensionClient(mcp_spec)
82+
mcp_service_extension_client = agentstack_sdk.a2a.extensions.MCPServiceExtensionClient(mcp_service_spec)
7983
oauth_extension_client = agentstack_sdk.a2a.extensions.OAuthExtensionClient(oauth_spec)
84+
tool_call_extension_client = agentstack_sdk.a2a.extensions.ToolCallExtensionClient(tool_call_spec)
8085

8186
oauth = OAuthHandler()
8287
message = a2a.types.Message(
8388
message_id=str(uuid.uuid4()),
8489
role=a2a.types.Role.user,
8590
parts=[a2a.types.Part(root=a2a.types.TextPart(text="Howdy!"))],
86-
metadata=mcp_extension_client.fulfillment_metadata(
91+
metadata=mcp_service_extension_client.fulfillment_metadata(
8792
mcp_fulfillments={
8893
key: agentstack_sdk.a2a.extensions.services.mcp.MCPFulfillment(
8994
transport=agentstack_sdk.a2a.extensions.services.mcp.StreamableHTTPTransport(
9095
url=AnyHttpUrl("https://mcp.stripe.com")
9196
),
9297
)
93-
for key in mcp_spec.params.mcp_demands
98+
for key in mcp_service_spec.params.mcp_demands
9499
}
95100
)
96101
| oauth_extension_client.fulfillment_metadata(
97102
oauth_fulfillments={
98103
key: agentstack_sdk.a2a.extensions.OAuthFulfillment(redirect_uri=AnyUrl(oauth.redirect_uri))
99104
for key in oauth_spec.params.oauth_demands
100105
}
101-
),
106+
)
107+
| tool_call_extension_client.metadata(),
102108
)
103109

104110
client = a2a.client.ClientFactory(a2a.client.ClientConfig(httpx_client=httpx_client, polling=True)).create(
105111
card=card
106112
)
107113

108114
task = None
109-
async for event in client.send_message(message):
110-
if isinstance(event, a2a.types.Message):
111-
print(event)
112-
return
113-
task, _update = event
115+
while True:
116+
async for event in client.send_message(message):
117+
if isinstance(event, a2a.types.Message):
118+
print(event)
119+
return
120+
task, _update = event
114121

115-
if task and task.status.state == a2a.types.TaskState.auth_required:
116-
if not task.status.message:
117-
raise RuntimeError("Missing message")
122+
if task and task.status.state == a2a.types.TaskState.auth_required:
123+
if not task.status.message:
124+
raise RuntimeError("Missing message")
118125

119-
auth_request = oauth_extension_client.parse_auth_request(message=task.status.message)
126+
auth_request = oauth_extension_client.parse_auth_request(message=task.status.message)
120127

121-
print("Agent has requested authorization")
122-
oauth.open_browser(str(auth_request.authorization_endpoint_url))
123-
request = await oauth.handle_redirect()
128+
print("Agent has requested authorization")
129+
oauth.open_browser(str(auth_request.authorization_endpoint_url))
130+
request = await oauth.handle_redirect()
124131

125-
async for event in client.send_message(
126-
oauth_extension_client.create_auth_response(task_id=task.id, redirect_uri=AnyUrl(str(request.url)))
127-
):
128-
if isinstance(event, a2a.types.Message):
129-
raise RuntimeError("Agent responded with message to a task")
130-
task, _update = event
132+
message = oauth_extension_client.create_auth_response(
133+
task_id=task.id, redirect_uri=AnyUrl(str(request.url))
134+
)
135+
elif task and task.status.state == a2a.types.TaskState.input_required:
136+
if not task.status.message:
137+
raise RuntimeError("Missing message")
138+
139+
approval_request = tool_call_extension_client.parse_message(message=task.status.message)
140+
141+
print("Agent has requested a tool call")
142+
print(approval_request)
143+
choice = input("Approve (Y/n): ")
144+
response = ToolCallResponse(action="accept" if choice.lower() == "y" else "reject")
145+
message = tool_call_extension_client.create_message(task_id=task.id, response=response)
146+
else:
147+
break
131148

132149
print(task)
133150

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
from .auth import *
5+
from .mcp import *
56
from .services import *
67
from .ui import *
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from .tool_call import *
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import uuid
7+
from types import NoneType
8+
from typing import TYPE_CHECKING, Any, Literal
9+
10+
import a2a.types
11+
import pydantic
12+
from mcp.types import Implementation, ToolAnnotations
13+
14+
from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec
15+
from agentstack_sdk.a2a.types import AgentMessage, InputRequired
16+
17+
if TYPE_CHECKING:
18+
from agentstack_sdk.server.context import RunContext
19+
20+
21+
class ToolCallRequest(pydantic.BaseModel):
22+
server: Implementation | None
23+
name: str
24+
arguments: dict[str, Any] | None
25+
annotations: ToolAnnotations | None
26+
27+
28+
class ToolCallResponse(pydantic.BaseModel):
29+
action: Literal["accept", "reject"]
30+
31+
32+
class ToolCallExtensionParams(pydantic.BaseModel):
33+
pass
34+
35+
36+
class ToolCallExtensionSpec(BaseExtensionSpec[ToolCallExtensionParams]):
37+
URI: str = "https://a2a-extensions.agentstack.beeai.dev/mcp/tool-call/v1"
38+
39+
40+
class ToolCallExtensionMetadata(pydantic.BaseModel):
41+
pass
42+
43+
44+
class ToolCallExtensionServer(BaseExtensionServer[ToolCallExtensionSpec, ToolCallExtensionMetadata]):
45+
def create_message(self, *, request: ToolCallRequest):
46+
return AgentMessage(
47+
text="Tool call approval requested", metadata={self.spec.URI: request.model_dump(mode="json")}
48+
)
49+
50+
def parse_message(self, *, message: a2a.types.Message):
51+
if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):
52+
raise RuntimeError("Invalid mcp response")
53+
return ToolCallResponse.model_validate(data)
54+
55+
async def request_tool_call(self, request: ToolCallRequest, context: RunContext) -> ToolCallResponse:
56+
message = self.create_message(request=request)
57+
message = await context.yield_async(InputRequired(message=message))
58+
if message:
59+
result = self.parse_message(message=message)
60+
if result.action != "accept":
61+
raise RuntimeError("User has rejected the tool call")
62+
return result
63+
else:
64+
raise RuntimeError("Tool call approval response is missing")
65+
66+
67+
class ToolCallExtensionClient(BaseExtensionClient[ToolCallExtensionSpec, NoneType]):
68+
def create_message(self, *, response: ToolCallResponse, task_id: str | None):
69+
return a2a.types.Message(
70+
message_id=str(uuid.uuid4()),
71+
role=a2a.types.Role.user,
72+
parts=[],
73+
task_id=task_id,
74+
metadata={self.spec.URI: response.model_dump(mode="json")},
75+
)
76+
77+
def parse_message(self, *, message: a2a.types.Message):
78+
if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)):
79+
raise ValueError("Invalid tool call request")
80+
return ToolCallRequest.model_validate(data)
81+
82+
def metadata(self) -> dict[str, Any]:
83+
return {self.spec.URI: ToolCallExtensionMetadata().model_dump(mode="json")}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
from datetime import timedelta
7+
from typing import Any, Self
8+
9+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10+
from mcp import ClientSession
11+
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
12+
from mcp.shared.message import SessionMessage
13+
from mcp.shared.session import ProgressFnT
14+
from mcp.types import CallToolResult, Implementation, InitializeResult, ListToolsResult, Tool
15+
16+
from agentstack_sdk.a2a.extensions.mcp.tool_call import ToolCallExtensionServer, ToolCallRequest
17+
from agentstack_sdk.server.context import RunContext
18+
19+
20+
class MCPClientSession(ClientSession):
21+
def __init__(
22+
self,
23+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
24+
write_stream: MemoryObjectSendStream[SessionMessage],
25+
read_timeout_seconds: timedelta | None = None,
26+
sampling_callback: SamplingFnT | None = None,
27+
elicitation_callback: ElicitationFnT | None = None,
28+
list_roots_callback: ListRootsFnT | None = None,
29+
logging_callback: LoggingFnT | None = None,
30+
message_handler: MessageHandlerFnT | None = None,
31+
client_info: Implementation | None = None,
32+
*,
33+
context: RunContext,
34+
) -> None:
35+
super().__init__(
36+
read_stream,
37+
write_stream,
38+
read_timeout_seconds,
39+
sampling_callback,
40+
elicitation_callback,
41+
list_roots_callback,
42+
logging_callback,
43+
message_handler,
44+
client_info,
45+
)
46+
self._context = context
47+
self._mcp_extension = None
48+
self._server_info: Implementation | None = None
49+
self._observed_tools: dict[str, Tool] = {}
50+
51+
def apply(self, extension: ToolCallExtensionServer) -> Self:
52+
self._mcp_extension = extension
53+
return self
54+
55+
async def initialize(self) -> InitializeResult:
56+
result = await super().initialize()
57+
self._server_info = result.serverInfo
58+
return result
59+
60+
async def list_tools(self, cursor: str | None = None) -> ListToolsResult:
61+
result = await super().list_tools(cursor=cursor)
62+
for tool in result.tools:
63+
self._observed_tools[tool.name] = tool
64+
return result
65+
66+
async def call_tool(
67+
self,
68+
name: str,
69+
arguments: dict[str, Any] | None = None,
70+
read_timeout_seconds: timedelta | None = None,
71+
progress_callback: ProgressFnT | None = None,
72+
) -> CallToolResult:
73+
"""Send a tools/call request with optional progress callback support."""
74+
75+
if self._mcp_extension:
76+
tool = self._observed_tools.get(name)
77+
await self._mcp_extension.request_tool_call(
78+
request=ToolCallRequest(
79+
server=self._server_info,
80+
name=name,
81+
arguments=arguments,
82+
annotations=tool.annotations if tool else None,
83+
),
84+
context=self._context,
85+
)
86+
return await super().call_tool(
87+
name,
88+
arguments,
89+
read_timeout_seconds,
90+
progress_callback,
91+
)
92+
else:
93+
return await super().call_tool(name, arguments, read_timeout_seconds, progress_callback)

0 commit comments

Comments
 (0)