diff --git a/apps/agentstack-sdk-py/examples/mcp_agent.py b/apps/agentstack-sdk-py/examples/mcp_agent.py index be7e26878..d00c1ba56 100644 --- a/apps/agentstack-sdk-py/examples/mcp_agent.py +++ b/apps/agentstack-sdk-py/examples/mcp_agent.py @@ -9,6 +9,12 @@ from agentstack_sdk.a2a.extensions.auth.oauth import OAuthExtensionServer, OAuthExtensionSpec from agentstack_sdk.a2a.extensions.services.mcp import MCPServiceExtensionServer, MCPServiceExtensionSpec +from agentstack_sdk.a2a.extensions.tools.call import ( + ToolCallExtensionParams, + ToolCallExtensionServer, + ToolCallExtensionSpec, + ToolCallRequest, +) from agentstack_sdk.a2a.types import RunYield from agentstack_sdk.server import Server from agentstack_sdk.server.context import RunContext @@ -21,24 +27,43 @@ async def mcp_agent( message: Message, context: RunContext, oauth: Annotated[OAuthExtensionServer, OAuthExtensionSpec.single_demand()], - mcp: Annotated[ + mcp_tool_call: Annotated[ToolCallExtensionServer, ToolCallExtensionSpec(params=ToolCallExtensionParams())], + mcp_service: Annotated[ MCPServiceExtensionServer, MCPServiceExtensionSpec.single_demand(), ], ) -> AsyncGenerator[RunYield, Message]: """Lists tools""" - if not mcp: + if not mcp_service: yield "MCP extension hasn't been activated, no tools are available" return - async with mcp.create_client() as (read, write), ClientSession(read, write) as session: - await session.initialize() + if not mcp_tool_call: + yield "MCP Tool Call extension hasn't been activated, no approval requests will be issued" - tools = await session.list_tools() + async with ( + mcp_service.create_client() as (read, write), + ClientSession(read, write) as session, + ): + session_init_result = await session.initialize() + + result = await session.list_tools() yield "Available tools: \n" - yield "\n".join([t.name for t in tools.tools]) + yield "\n".join([t.name for t in result.tools]) + + if result.tools: + tool = result.tools[0] + input = {} + yield f"Requesting approval for tool {tool.name}" + if mcp_tool_call: + await mcp_tool_call.request_tool_call_approval( + ToolCallRequest.from_mcp_tool(tool, input, server=session_init_result.serverInfo), context=context + ) + yield f"Calling tool {tool.name}" + await session.call_tool(tool.name, input) + yield "Tool call finished" if __name__ == "__main__": diff --git a/apps/agentstack-sdk-py/examples/mcp_client.py b/apps/agentstack-sdk-py/examples/mcp_client.py index 7b80dca1d..0421826ac 100644 --- a/apps/agentstack-sdk-py/examples/mcp_client.py +++ b/apps/agentstack-sdk-py/examples/mcp_client.py @@ -12,6 +12,7 @@ from pydantic import AnyHttpUrl, AnyUrl import agentstack_sdk.a2a.extensions +from agentstack_sdk.a2a.extensions.tools.call import ToolCallResponse class OAuthHandler: @@ -67,30 +68,34 @@ async def handler(request: web.Request) -> web.Response: async def run(base_url: str = "http://127.0.0.1:10000"): async with httpx.AsyncClient(timeout=30) as httpx_client: card = await a2a.client.A2ACardResolver(httpx_client, base_url=base_url).get_agent_card() - mcp_spec = agentstack_sdk.a2a.extensions.MCPServiceExtensionSpec.from_agent_card(card) + mcp_service_spec = agentstack_sdk.a2a.extensions.MCPServiceExtensionSpec.from_agent_card(card) oauth_spec = agentstack_sdk.a2a.extensions.OAuthExtensionSpec.from_agent_card(card) + tool_call_spec = agentstack_sdk.a2a.extensions.ToolCallExtensionSpec.from_agent_card(card) - if not mcp_spec: + if not mcp_service_spec: raise ValueError(f"Agent at {base_url} does not support MCP service injection") if not oauth_spec: raise ValueError(f"Agent at {base_url} does not support oAuth") + if not tool_call_spec: + raise ValueError(f"Agent at {base_url} does not support MCP") - mcp_extension_client = agentstack_sdk.a2a.extensions.MCPServiceExtensionClient(mcp_spec) + mcp_service_extension_client = agentstack_sdk.a2a.extensions.MCPServiceExtensionClient(mcp_service_spec) oauth_extension_client = agentstack_sdk.a2a.extensions.OAuthExtensionClient(oauth_spec) + tool_call_extension_client = agentstack_sdk.a2a.extensions.ToolCallExtensionClient(tool_call_spec) oauth = OAuthHandler() message = a2a.types.Message( message_id=str(uuid.uuid4()), role=a2a.types.Role.user, parts=[a2a.types.Part(root=a2a.types.TextPart(text="Howdy!"))], - metadata=mcp_extension_client.fulfillment_metadata( + metadata=mcp_service_extension_client.fulfillment_metadata( mcp_fulfillments={ key: agentstack_sdk.a2a.extensions.services.mcp.MCPFulfillment( transport=agentstack_sdk.a2a.extensions.services.mcp.StreamableHTTPTransport( url=AnyHttpUrl("https://mcp.stripe.com") ), ) - for key in mcp_spec.params.mcp_demands + for key in mcp_service_spec.params.mcp_demands } ) | oauth_extension_client.fulfillment_metadata( @@ -98,7 +103,8 @@ async def run(base_url: str = "http://127.0.0.1:10000"): key: agentstack_sdk.a2a.extensions.OAuthFulfillment(redirect_uri=AnyUrl(oauth.redirect_uri)) for key in oauth_spec.params.oauth_demands } - ), + ) + | tool_call_extension_client.metadata(), ) client = a2a.client.ClientFactory(a2a.client.ClientConfig(httpx_client=httpx_client, polling=True)).create( @@ -106,28 +112,39 @@ async def run(base_url: str = "http://127.0.0.1:10000"): ) task = None - async for event in client.send_message(message): - if isinstance(event, a2a.types.Message): - print(event) - return - task, _update = event + while True: + async for event in client.send_message(message): + if isinstance(event, a2a.types.Message): + print(event) + return + task, _update = event - if task and task.status.state == a2a.types.TaskState.auth_required: - if not task.status.message: - raise RuntimeError("Missing message") + if task and task.status.state == a2a.types.TaskState.auth_required: + if not task.status.message: + raise RuntimeError("Missing message") - auth_request = oauth_extension_client.parse_auth_request(message=task.status.message) + auth_request = oauth_extension_client.parse_auth_request(message=task.status.message) - print("Agent has requested authorization") - oauth.open_browser(str(auth_request.authorization_endpoint_url)) - request = await oauth.handle_redirect() + print("Agent has requested authorization") + oauth.open_browser(str(auth_request.authorization_endpoint_url)) + request = await oauth.handle_redirect() - async for event in client.send_message( - oauth_extension_client.create_auth_response(task_id=task.id, redirect_uri=AnyUrl(str(request.url))) - ): - if isinstance(event, a2a.types.Message): - raise RuntimeError("Agent responded with message to a task") - task, _update = event + message = oauth_extension_client.create_auth_response( + task_id=task.id, redirect_uri=AnyUrl(str(request.url)) + ) + elif task and task.status.state == a2a.types.TaskState.input_required: + if not task.status.message: + raise RuntimeError("Missing message") + + approval_request = tool_call_extension_client.parse_request(message=task.status.message) + + print("Agent has requested a tool call") + print(approval_request) + choice = input("Approve (Y/n): ") + response = ToolCallResponse(action="accept" if choice.lower() == "y" else "reject") + message = tool_call_extension_client.create_response_message(task_id=task.id, response=response) + else: + break print(task) diff --git a/apps/agentstack-sdk-py/examples/tool_call_approval_agent.py b/apps/agentstack-sdk-py/examples/tool_call_approval_agent.py new file mode 100644 index 000000000..accebf869 --- /dev/null +++ b/apps/agentstack-sdk-py/examples/tool_call_approval_agent.py @@ -0,0 +1,60 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from typing import Annotated + +from a2a.types import Message +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client +from mcp.types import TextContent + +from agentstack_sdk.a2a.extensions.tools.call import ( + ToolCallExtensionParams, + ToolCallExtensionServer, + ToolCallExtensionSpec, + ToolCallRequest, +) +from agentstack_sdk.a2a.extensions.tools.exceptions import ToolCallRejectionError +from agentstack_sdk.server import Server +from agentstack_sdk.server.context import RunContext + +server = Server() + + +@server.agent() +async def tool_call_approval_agent( + message: Message, + context: RunContext, + mcp_tool_call: Annotated[ToolCallExtensionServer, ToolCallExtensionSpec(params=ToolCallExtensionParams())], +): + async with ( + streamablehttp_client(url="https://hf.co/mcp") as (read, write, _), + ClientSession(read, write) as session, + ): + session_init_result = await session.initialize() + + list_tools_result = await session.list_tools() + tools = {tool.name: tool for tool in list_tools_result.tools} + + whoami_tool = tools.get("hf_whoami") + if not whoami_tool: + raise RuntimeError("Could not find whoami_tool on the server") + + arguments = {} + try: + await mcp_tool_call.request_tool_call_approval( + ToolCallRequest.from_mcp_tool(whoami_tool, arguments, server=session_init_result.serverInfo), + context=context, + ) + result = await session.call_tool("hf_whoami", arguments) + content = result.content[0] + if isinstance(content, TextContent): + yield content.text + else: + yield "Tool call succeeded" + except ToolCallRejectionError: + yield "Tool call has been rejected by the client" + + +if __name__ == "__main__": + server.run() diff --git a/apps/agentstack-sdk-py/examples/tool_call_approval_client.py b/apps/agentstack-sdk-py/examples/tool_call_approval_client.py new file mode 100644 index 000000000..64d6f9926 --- /dev/null +++ b/apps/agentstack-sdk-py/examples/tool_call_approval_client.py @@ -0,0 +1,62 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import uuid + +import a2a.client +import a2a.types +import httpx + +import agentstack_sdk.a2a.extensions +from agentstack_sdk.a2a.extensions.tools.call import ToolCallResponse + + +async def run(base_url: str = "http://127.0.0.1:10000"): + async with httpx.AsyncClient(timeout=30) as httpx_client: + card = await a2a.client.A2ACardResolver(httpx_client, base_url=base_url).get_agent_card() + tool_call_spec = agentstack_sdk.a2a.extensions.ToolCallExtensionSpec.from_agent_card(card) + + if not tool_call_spec: + raise ValueError(f"Agent at {base_url} does not support MCP Tool Call extension") + + tool_call_extension_client = agentstack_sdk.a2a.extensions.ToolCallExtensionClient(tool_call_spec) + + message = a2a.types.Message( + message_id=str(uuid.uuid4()), + role=a2a.types.Role.user, + parts=[a2a.types.Part(root=a2a.types.TextPart(text="Howdy!"))], + metadata=tool_call_extension_client.metadata(), + ) + + client = a2a.client.ClientFactory(a2a.client.ClientConfig(httpx_client=httpx_client, polling=True)).create( + card=card + ) + + task = None + while True: + async for event in client.send_message(message): + if isinstance(event, a2a.types.Message): + print(event) + return + task, _update = event + + if task and task.status.state == a2a.types.TaskState.input_required: + if not task.status.message: + raise RuntimeError("Missing message") + + approval_request = tool_call_extension_client.parse_request(message=task.status.message) + + print("Agent has requested a tool call") + print(approval_request) + choice = input("Approve (Y/n): ") + response = ToolCallResponse(action="accept" if choice.lower() == "y" else "reject") + message = tool_call_extension_client.create_response_message(task_id=task.id, response=response) + else: + break + + print(task) + + +if __name__ == "__main__": + asyncio.run(run()) diff --git a/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/__init__.py b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/__init__.py index 5a5d9cd6d..6fcc7e3d5 100644 --- a/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/__init__.py +++ b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/__init__.py @@ -3,4 +3,5 @@ from .auth import * from .services import * +from .tools import * from .ui import * diff --git a/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/__init__.py b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/__init__.py new file mode 100644 index 000000000..650dc8d01 --- /dev/null +++ b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/__init__.py @@ -0,0 +1,5 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from .call import * +from .exceptions import * diff --git a/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/call.py b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/call.py new file mode 100644 index 000000000..6a0f8ea85 --- /dev/null +++ b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/call.py @@ -0,0 +1,114 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import uuid +from types import NoneType +from typing import TYPE_CHECKING, Any, Literal + +import a2a.types +from mcp import Tool +from mcp.types import Implementation +from pydantic import BaseModel, Field + +from agentstack_sdk.a2a.extensions.base import BaseExtensionClient, BaseExtensionServer, BaseExtensionSpec +from agentstack_sdk.a2a.extensions.tools.exceptions import ToolCallRejectionError +from agentstack_sdk.a2a.types import AgentMessage, InputRequired + +if TYPE_CHECKING: + from agentstack_sdk.server.context import RunContext + + +class ToolCallServer(BaseModel): + name: str = Field(description="The programmatic name of the server.") + title: str | None = Field(description="A human-readable title for the server.") + version: str = Field(description="The version of the server.") + + +class ToolCallRequest(BaseModel): + name: str = Field(description="The programmatic name of the tool.") + title: str | None = Field(None, description="A human-readable title for the tool.") + description: str | None = Field(None, description="A human-readable description of the tool.") + + input: dict[str, Any] | None = Field(description="The input for the tool.") + + server: ToolCallServer | None = Field(None, description="The server executing the tool.") + + @staticmethod + def from_mcp_tool( + tool: Tool, input: dict[str, Any] | None, server: Implementation | None = None + ) -> ToolCallRequest: + return ToolCallRequest( + name=tool.name, + title=tool.annotations.title if tool.annotations else None, + description=tool.description, + input=input, + server=ToolCallServer(name=server.name, title=server.title, version=server.version) if server else None, + ) + + +class ToolCallResponse(BaseModel): + action: Literal["accept", "reject"] + + +class ToolCallExtensionParams(BaseModel): + pass + + +class ToolCallExtensionSpec(BaseExtensionSpec[ToolCallExtensionParams]): + URI: str = "https://a2a-extensions.agentstack.beeai.dev/tools/call/v1" + + +class ToolCallExtensionMetadata(BaseModel): + pass + + +class ToolCallExtensionServer(BaseExtensionServer[ToolCallExtensionSpec, ToolCallExtensionMetadata]): + def create_request_message(self, *, request: ToolCallRequest): + return AgentMessage( + text="Tool call approval requested", metadata={self.spec.URI: request.model_dump(mode="json")} + ) + + def parse_response(self, *, message: a2a.types.Message): + if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)): + raise RuntimeError("Invalid mcp response") + return ToolCallResponse.model_validate(data) + + async def request_tool_call_approval( + self, + request: ToolCallRequest, + *, + context: RunContext, + ) -> ToolCallResponse: + message = self.create_request_message(request=request) + message = await context.yield_async(InputRequired(message=message)) + if message: + result = self.parse_response(message=message) + match result.action: + case "accept": + return result + case "reject": + raise ToolCallRejectionError("User has rejected the tool call") + + else: + raise RuntimeError("Yield did not return a message") + + +class ToolCallExtensionClient(BaseExtensionClient[ToolCallExtensionSpec, NoneType]): + def create_response_message(self, *, response: ToolCallResponse, task_id: str | None): + return a2a.types.Message( + message_id=str(uuid.uuid4()), + role=a2a.types.Role.user, + parts=[], + task_id=task_id, + metadata={self.spec.URI: response.model_dump(mode="json")}, + ) + + def parse_request(self, *, message: a2a.types.Message): + if not message or not message.metadata or not (data := message.metadata.get(self.spec.URI)): + raise ValueError("Invalid tool call request") + return ToolCallRequest.model_validate(data) + + def metadata(self) -> dict[str, Any]: + return {self.spec.URI: ToolCallExtensionMetadata().model_dump(mode="json")} diff --git a/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/exceptions.py b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/exceptions.py new file mode 100644 index 000000000..51748acac --- /dev/null +++ b/apps/agentstack-sdk-py/src/agentstack_sdk/a2a/extensions/tools/exceptions.py @@ -0,0 +1,6 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + + +class ToolCallRejectionError(RuntimeError): + pass diff --git a/docs/agent-development/tool-calls.mdx b/docs/agent-development/tool-calls.mdx new file mode 100644 index 000000000..bd185474f --- /dev/null +++ b/docs/agent-development/tool-calls.mdx @@ -0,0 +1,82 @@ +--- +title: "Approve Tool Calls" +description: "Have tool calls approved by the user before execution" +--- + +Many agent frameworks support the ability to request user approval before executing certain actions. This is especially useful when an agent is calling external tools that may have significant effects or costs associated with their usage. + +The Tool Call extension provides a mechanism for implementing this functionality over A2A connection. + +## Usage + + + + Inject the `ToolCallExtension` into your agent function using the `Annotated` + type hint. + + + + Use `request_tool_call_approval()` method to request tool call approval from the A2A client side. + + + +## Basic Example + +Here's how to use this extension with the [BeeAI Framework](https://framework.beeai.dev/modules/agents/requirement-agent#ask-permission-requirement) to request user approval before executing a tool call: + +```python +from typing import Annotated, Any + +from a2a.types import ( + Message, +) +from agentstack_sdk.server import Server +from agentstack_sdk.server.context import RunContext +from agentstack_sdk.a2a.extensions.tools.call import ( + ToolCallExtensionParams, + ToolCallExtensionServer, + ToolCallExtensionSpec, + ToolCallRequest, + ToolCallRejectionError, +) +from beeai_framework.agents.requirement import RequirementAgent +from beeai_framework.backend import ChatModel +from beeai_framework.agents.requirement.requirements.ask_permission import AskPermissionRequirement +from beeai_framework.tools import Tool +from beeai_framework.tools.think import ThinkTool +from beeai_framework.adapters.mcp.serve.server import _tool_factory + +server = Server() + + +@server.agent() +async def tool_call_agent( + input: Message, + context: RunContext, + mcp_tool_call: Annotated[ToolCallExtensionServer, ToolCallExtensionSpec(params=ToolCallExtensionParams())], +): + async def handler(tool: Tool, input: dict[str, Any]) -> bool: + try: + await mcp_tool_call.request_tool_call_approval( + # using MCP Tool data model as intermediary to simplify conversion + ToolCallRequest.from_mcp_tool(_tool_factory(tool), input=input), + context=context, + ) + return True + except ToolCallRejectionError: + return False + + think_tool = ThinkTool() + agent = RequirementAgent( + llm=ChatModel.from_name("ollama:gpt-oss:20b"), + tools=[think_tool], + requirements=[AskPermissionRequirement([think_tool], handler=handler)], + ) + + result = await agent.run("".join(part.root.text for part in input.parts if part.root.kind == "text")) + yield result.output[0].text + + +if __name__ == "__main__": + server.run() +``` diff --git a/docs/docs.json b/docs/docs.json index 213979b7d..8afba5383 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -12,10 +12,7 @@ "groups": [ { "group": "Introduction", - "pages": [ - "introduction/welcome", - "introduction/quickstart" - ] + "pages": ["introduction/welcome", "introduction/quickstart"] }, { "group": "Deploy Agents", @@ -43,14 +40,13 @@ "agent-development/mcp-oauth", "agent-development/mcp", "agent-development/env-variables", - "agent-development/error" + "agent-development/error", + "agent-development/tool-calls" ] }, { "group": "Reference", - "pages": [ - "reference/cli-reference" - ] + "pages": ["reference/cli-reference"] }, { "group": "Deploy Agent Stack", @@ -62,17 +58,11 @@ }, { "group": "Custom UI Integration", - "pages": [ - "custom-ui/client-sdk", - "custom-ui/permissions-and-tokens" - ] + "pages": ["custom-ui/client-sdk", "custom-ui/permissions-and-tokens"] }, { "group": "Experimental", - "pages": [ - "experimental/connectors", - "experimental/a2a-proxy" - ] + "pages": ["experimental/connectors", "experimental/a2a-proxy"] }, { "group": "Community",