Skip to content

Commit 41f76c2

Browse files
authored
feat(sdk): add tool call extension (#1572)
Signed-off-by: Tomas Pilar <[email protected]>
1 parent 128c63b commit 41f76c2

File tree

10 files changed

+408
-46
lines changed

10 files changed

+408
-46
lines changed

apps/agentstack-sdk-py/examples/mcp_agent.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99

1010
from agentstack_sdk.a2a.extensions.auth.oauth import OAuthExtensionServer, OAuthExtensionSpec
1111
from agentstack_sdk.a2a.extensions.services.mcp import MCPServiceExtensionServer, MCPServiceExtensionSpec
12+
from agentstack_sdk.a2a.extensions.tools.call import (
13+
ToolCallExtensionParams,
14+
ToolCallExtensionServer,
15+
ToolCallExtensionSpec,
16+
ToolCallRequest,
17+
)
1218
from agentstack_sdk.a2a.types import RunYield
1319
from agentstack_sdk.server import Server
1420
from agentstack_sdk.server.context import RunContext
@@ -21,24 +27,43 @@ async def mcp_agent(
2127
message: Message,
2228
context: RunContext,
2329
oauth: Annotated[OAuthExtensionServer, OAuthExtensionSpec.single_demand()],
24-
mcp: Annotated[
30+
mcp_tool_call: Annotated[ToolCallExtensionServer, ToolCallExtensionSpec(params=ToolCallExtensionParams())],
31+
mcp_service: Annotated[
2532
MCPServiceExtensionServer,
2633
MCPServiceExtensionSpec.single_demand(),
2734
],
2835
) -> AsyncGenerator[RunYield, Message]:
2936
"""Lists tools"""
3037

31-
if not mcp:
38+
if not mcp_service:
3239
yield "MCP extension hasn't been activated, no tools are available"
3340
return
3441

35-
async with mcp.create_client() as (read, write), ClientSession(read, write) as session:
36-
await session.initialize()
42+
if not mcp_tool_call:
43+
yield "MCP Tool Call extension hasn't been activated, no approval requests will be issued"
3744

38-
tools = await session.list_tools()
45+
async with (
46+
mcp_service.create_client() as (read, write),
47+
ClientSession(read, write) as session,
48+
):
49+
session_init_result = await session.initialize()
50+
51+
result = await session.list_tools()
3952

4053
yield "Available tools: \n"
41-
yield "\n".join([t.name for t in tools.tools])
54+
yield "\n".join([t.name for t in result.tools])
55+
56+
if result.tools:
57+
tool = result.tools[0]
58+
input = {}
59+
yield f"Requesting approval for tool {tool.name}"
60+
if mcp_tool_call:
61+
await mcp_tool_call.request_tool_call_approval(
62+
ToolCallRequest.from_mcp_tool(tool, input, server=session_init_result.serverInfo), context=context
63+
)
64+
yield f"Calling tool {tool.name}"
65+
await session.call_tool(tool.name, input)
66+
yield "Tool call finished"
4267

4368

4469
if __name__ == "__main__":

apps/agentstack-sdk-py/examples/mcp_client.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pydantic import AnyHttpUrl, AnyUrl
1313

1414
import agentstack_sdk.a2a.extensions
15+
from agentstack_sdk.a2a.extensions.tools.call import ToolCallResponse
1516

1617

1718
class OAuthHandler:
@@ -67,67 +68,83 @@ async def handler(request: web.Request) -> web.Response:
6768
async def run(base_url: str = "http://127.0.0.1:10000"):
6869
async with httpx.AsyncClient(timeout=30) as httpx_client:
6970
card = await a2a.client.A2ACardResolver(httpx_client, base_url=base_url).get_agent_card()
70-
mcp_spec = agentstack_sdk.a2a.extensions.MCPServiceExtensionSpec.from_agent_card(card)
71+
mcp_service_spec = agentstack_sdk.a2a.extensions.MCPServiceExtensionSpec.from_agent_card(card)
7172
oauth_spec = agentstack_sdk.a2a.extensions.OAuthExtensionSpec.from_agent_card(card)
73+
tool_call_spec = agentstack_sdk.a2a.extensions.ToolCallExtensionSpec.from_agent_card(card)
7274

73-
if not mcp_spec:
75+
if not mcp_service_spec:
7476
raise ValueError(f"Agent at {base_url} does not support MCP service injection")
7577
if not oauth_spec:
7678
raise ValueError(f"Agent at {base_url} does not support oAuth")
79+
if not tool_call_spec:
80+
raise ValueError(f"Agent at {base_url} does not support MCP")
7781

78-
mcp_extension_client = agentstack_sdk.a2a.extensions.MCPServiceExtensionClient(mcp_spec)
82+
mcp_service_extension_client = agentstack_sdk.a2a.extensions.MCPServiceExtensionClient(mcp_service_spec)
7983
oauth_extension_client = agentstack_sdk.a2a.extensions.OAuthExtensionClient(oauth_spec)
84+
tool_call_extension_client = agentstack_sdk.a2a.extensions.ToolCallExtensionClient(tool_call_spec)
8085

8186
oauth = OAuthHandler()
8287
message = a2a.types.Message(
8388
message_id=str(uuid.uuid4()),
8489
role=a2a.types.Role.user,
8590
parts=[a2a.types.Part(root=a2a.types.TextPart(text="Howdy!"))],
86-
metadata=mcp_extension_client.fulfillment_metadata(
91+
metadata=mcp_service_extension_client.fulfillment_metadata(
8792
mcp_fulfillments={
8893
key: agentstack_sdk.a2a.extensions.services.mcp.MCPFulfillment(
8994
transport=agentstack_sdk.a2a.extensions.services.mcp.StreamableHTTPTransport(
9095
url=AnyHttpUrl("https://mcp.stripe.com")
9196
),
9297
)
93-
for key in mcp_spec.params.mcp_demands
98+
for key in mcp_service_spec.params.mcp_demands
9499
}
95100
)
96101
| oauth_extension_client.fulfillment_metadata(
97102
oauth_fulfillments={
98103
key: agentstack_sdk.a2a.extensions.OAuthFulfillment(redirect_uri=AnyUrl(oauth.redirect_uri))
99104
for key in oauth_spec.params.oauth_demands
100105
}
101-
),
106+
)
107+
| tool_call_extension_client.metadata(),
102108
)
103109

104110
client = a2a.client.ClientFactory(a2a.client.ClientConfig(httpx_client=httpx_client, polling=True)).create(
105111
card=card
106112
)
107113

108114
task = None
109-
async for event in client.send_message(message):
110-
if isinstance(event, a2a.types.Message):
111-
print(event)
112-
return
113-
task, _update = event
115+
while True:
116+
async for event in client.send_message(message):
117+
if isinstance(event, a2a.types.Message):
118+
print(event)
119+
return
120+
task, _update = event
114121

115-
if task and task.status.state == a2a.types.TaskState.auth_required:
116-
if not task.status.message:
117-
raise RuntimeError("Missing message")
122+
if task and task.status.state == a2a.types.TaskState.auth_required:
123+
if not task.status.message:
124+
raise RuntimeError("Missing message")
118125

119-
auth_request = oauth_extension_client.parse_auth_request(message=task.status.message)
126+
auth_request = oauth_extension_client.parse_auth_request(message=task.status.message)
120127

121-
print("Agent has requested authorization")
122-
oauth.open_browser(str(auth_request.authorization_endpoint_url))
123-
request = await oauth.handle_redirect()
128+
print("Agent has requested authorization")
129+
oauth.open_browser(str(auth_request.authorization_endpoint_url))
130+
request = await oauth.handle_redirect()
124131

125-
async for event in client.send_message(
126-
oauth_extension_client.create_auth_response(task_id=task.id, redirect_uri=AnyUrl(str(request.url)))
127-
):
128-
if isinstance(event, a2a.types.Message):
129-
raise RuntimeError("Agent responded with message to a task")
130-
task, _update = event
132+
message = oauth_extension_client.create_auth_response(
133+
task_id=task.id, redirect_uri=AnyUrl(str(request.url))
134+
)
135+
elif task and task.status.state == a2a.types.TaskState.input_required:
136+
if not task.status.message:
137+
raise RuntimeError("Missing message")
138+
139+
approval_request = tool_call_extension_client.parse_request(message=task.status.message)
140+
141+
print("Agent has requested a tool call")
142+
print(approval_request)
143+
choice = input("Approve (Y/n): ")
144+
response = ToolCallResponse(action="accept" if choice.lower() == "y" else "reject")
145+
message = tool_call_extension_client.create_response_message(task_id=task.id, response=response)
146+
else:
147+
break
131148

132149
print(task)
133150

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from typing import Annotated
5+
6+
from a2a.types import Message
7+
from mcp import ClientSession
8+
from mcp.client.streamable_http import streamablehttp_client
9+
from mcp.types import TextContent
10+
11+
from agentstack_sdk.a2a.extensions.tools.call import (
12+
ToolCallExtensionParams,
13+
ToolCallExtensionServer,
14+
ToolCallExtensionSpec,
15+
ToolCallRequest,
16+
)
17+
from agentstack_sdk.a2a.extensions.tools.exceptions import ToolCallRejectionError
18+
from agentstack_sdk.server import Server
19+
from agentstack_sdk.server.context import RunContext
20+
21+
server = Server()
22+
23+
24+
@server.agent()
25+
async def tool_call_approval_agent(
26+
message: Message,
27+
context: RunContext,
28+
mcp_tool_call: Annotated[ToolCallExtensionServer, ToolCallExtensionSpec(params=ToolCallExtensionParams())],
29+
):
30+
async with (
31+
streamablehttp_client(url="https://hf.co/mcp") as (read, write, _),
32+
ClientSession(read, write) as session,
33+
):
34+
session_init_result = await session.initialize()
35+
36+
list_tools_result = await session.list_tools()
37+
tools = {tool.name: tool for tool in list_tools_result.tools}
38+
39+
whoami_tool = tools.get("hf_whoami")
40+
if not whoami_tool:
41+
raise RuntimeError("Could not find whoami_tool on the server")
42+
43+
arguments = {}
44+
try:
45+
await mcp_tool_call.request_tool_call_approval(
46+
ToolCallRequest.from_mcp_tool(whoami_tool, arguments, server=session_init_result.serverInfo),
47+
context=context,
48+
)
49+
result = await session.call_tool("hf_whoami", arguments)
50+
content = result.content[0]
51+
if isinstance(content, TextContent):
52+
yield content.text
53+
else:
54+
yield "Tool call succeeded"
55+
except ToolCallRejectionError:
56+
yield "Tool call has been rejected by the client"
57+
58+
59+
if __name__ == "__main__":
60+
server.run()
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import asyncio
5+
import uuid
6+
7+
import a2a.client
8+
import a2a.types
9+
import httpx
10+
11+
import agentstack_sdk.a2a.extensions
12+
from agentstack_sdk.a2a.extensions.tools.call import ToolCallResponse
13+
14+
15+
async def run(base_url: str = "http://127.0.0.1:10000"):
16+
async with httpx.AsyncClient(timeout=30) as httpx_client:
17+
card = await a2a.client.A2ACardResolver(httpx_client, base_url=base_url).get_agent_card()
18+
tool_call_spec = agentstack_sdk.a2a.extensions.ToolCallExtensionSpec.from_agent_card(card)
19+
20+
if not tool_call_spec:
21+
raise ValueError(f"Agent at {base_url} does not support MCP Tool Call extension")
22+
23+
tool_call_extension_client = agentstack_sdk.a2a.extensions.ToolCallExtensionClient(tool_call_spec)
24+
25+
message = a2a.types.Message(
26+
message_id=str(uuid.uuid4()),
27+
role=a2a.types.Role.user,
28+
parts=[a2a.types.Part(root=a2a.types.TextPart(text="Howdy!"))],
29+
metadata=tool_call_extension_client.metadata(),
30+
)
31+
32+
client = a2a.client.ClientFactory(a2a.client.ClientConfig(httpx_client=httpx_client, polling=True)).create(
33+
card=card
34+
)
35+
36+
task = None
37+
while True:
38+
async for event in client.send_message(message):
39+
if isinstance(event, a2a.types.Message):
40+
print(event)
41+
return
42+
task, _update = event
43+
44+
if task and task.status.state == a2a.types.TaskState.input_required:
45+
if not task.status.message:
46+
raise RuntimeError("Missing message")
47+
48+
approval_request = tool_call_extension_client.parse_request(message=task.status.message)
49+
50+
print("Agent has requested a tool call")
51+
print(approval_request)
52+
choice = input("Approve (Y/n): ")
53+
response = ToolCallResponse(action="accept" if choice.lower() == "y" else "reject")
54+
message = tool_call_extension_client.create_response_message(task_id=task.id, response=response)
55+
else:
56+
break
57+
58+
print(task)
59+
60+
61+
if __name__ == "__main__":
62+
asyncio.run(run())

apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33

44
from .auth import *
55
from .services import *
6+
from .tools import *
67
from .ui import *
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from .call import *
5+
from .exceptions import *

0 commit comments

Comments
 (0)