|
12 | 12 | from pydantic import AnyHttpUrl, AnyUrl |
13 | 13 |
|
14 | 14 | import agentstack_sdk.a2a.extensions |
| 15 | +from agentstack_sdk.a2a.extensions.tools.call import ToolCallResponse |
15 | 16 |
|
16 | 17 |
|
17 | 18 | class OAuthHandler: |
@@ -67,67 +68,83 @@ async def handler(request: web.Request) -> web.Response: |
67 | 68 | async def run(base_url: str = "http://127.0.0.1:10000"): |
68 | 69 | async with httpx.AsyncClient(timeout=30) as httpx_client: |
69 | 70 | card = await a2a.client.A2ACardResolver(httpx_client, base_url=base_url).get_agent_card() |
70 | | - mcp_spec = agentstack_sdk.a2a.extensions.MCPServiceExtensionSpec.from_agent_card(card) |
| 71 | + mcp_service_spec = agentstack_sdk.a2a.extensions.MCPServiceExtensionSpec.from_agent_card(card) |
71 | 72 | oauth_spec = agentstack_sdk.a2a.extensions.OAuthExtensionSpec.from_agent_card(card) |
| 73 | + tool_call_spec = agentstack_sdk.a2a.extensions.ToolCallExtensionSpec.from_agent_card(card) |
72 | 74 |
|
73 | | - if not mcp_spec: |
| 75 | + if not mcp_service_spec: |
74 | 76 | raise ValueError(f"Agent at {base_url} does not support MCP service injection") |
75 | 77 | if not oauth_spec: |
76 | 78 | raise ValueError(f"Agent at {base_url} does not support oAuth") |
| 79 | + if not tool_call_spec: |
| 80 | + raise ValueError(f"Agent at {base_url} does not support MCP") |
77 | 81 |
|
78 | | - mcp_extension_client = agentstack_sdk.a2a.extensions.MCPServiceExtensionClient(mcp_spec) |
| 82 | + mcp_service_extension_client = agentstack_sdk.a2a.extensions.MCPServiceExtensionClient(mcp_service_spec) |
79 | 83 | oauth_extension_client = agentstack_sdk.a2a.extensions.OAuthExtensionClient(oauth_spec) |
| 84 | + tool_call_extension_client = agentstack_sdk.a2a.extensions.ToolCallExtensionClient(tool_call_spec) |
80 | 85 |
|
81 | 86 | oauth = OAuthHandler() |
82 | 87 | message = a2a.types.Message( |
83 | 88 | message_id=str(uuid.uuid4()), |
84 | 89 | role=a2a.types.Role.user, |
85 | 90 | parts=[a2a.types.Part(root=a2a.types.TextPart(text="Howdy!"))], |
86 | | - metadata=mcp_extension_client.fulfillment_metadata( |
| 91 | + metadata=mcp_service_extension_client.fulfillment_metadata( |
87 | 92 | mcp_fulfillments={ |
88 | 93 | key: agentstack_sdk.a2a.extensions.services.mcp.MCPFulfillment( |
89 | 94 | transport=agentstack_sdk.a2a.extensions.services.mcp.StreamableHTTPTransport( |
90 | 95 | url=AnyHttpUrl("https://mcp.stripe.com") |
91 | 96 | ), |
92 | 97 | ) |
93 | | - for key in mcp_spec.params.mcp_demands |
| 98 | + for key in mcp_service_spec.params.mcp_demands |
94 | 99 | } |
95 | 100 | ) |
96 | 101 | | oauth_extension_client.fulfillment_metadata( |
97 | 102 | oauth_fulfillments={ |
98 | 103 | key: agentstack_sdk.a2a.extensions.OAuthFulfillment(redirect_uri=AnyUrl(oauth.redirect_uri)) |
99 | 104 | for key in oauth_spec.params.oauth_demands |
100 | 105 | } |
101 | | - ), |
| 106 | + ) |
| 107 | + | tool_call_extension_client.metadata(), |
102 | 108 | ) |
103 | 109 |
|
104 | 110 | client = a2a.client.ClientFactory(a2a.client.ClientConfig(httpx_client=httpx_client, polling=True)).create( |
105 | 111 | card=card |
106 | 112 | ) |
107 | 113 |
|
108 | 114 | task = None |
109 | | - async for event in client.send_message(message): |
110 | | - if isinstance(event, a2a.types.Message): |
111 | | - print(event) |
112 | | - return |
113 | | - task, _update = event |
| 115 | + while True: |
| 116 | + async for event in client.send_message(message): |
| 117 | + if isinstance(event, a2a.types.Message): |
| 118 | + print(event) |
| 119 | + return |
| 120 | + task, _update = event |
114 | 121 |
|
115 | | - if task and task.status.state == a2a.types.TaskState.auth_required: |
116 | | - if not task.status.message: |
117 | | - raise RuntimeError("Missing message") |
| 122 | + if task and task.status.state == a2a.types.TaskState.auth_required: |
| 123 | + if not task.status.message: |
| 124 | + raise RuntimeError("Missing message") |
118 | 125 |
|
119 | | - auth_request = oauth_extension_client.parse_auth_request(message=task.status.message) |
| 126 | + auth_request = oauth_extension_client.parse_auth_request(message=task.status.message) |
120 | 127 |
|
121 | | - print("Agent has requested authorization") |
122 | | - oauth.open_browser(str(auth_request.authorization_endpoint_url)) |
123 | | - request = await oauth.handle_redirect() |
| 128 | + print("Agent has requested authorization") |
| 129 | + oauth.open_browser(str(auth_request.authorization_endpoint_url)) |
| 130 | + request = await oauth.handle_redirect() |
124 | 131 |
|
125 | | - async for event in client.send_message( |
126 | | - oauth_extension_client.create_auth_response(task_id=task.id, redirect_uri=AnyUrl(str(request.url))) |
127 | | - ): |
128 | | - if isinstance(event, a2a.types.Message): |
129 | | - raise RuntimeError("Agent responded with message to a task") |
130 | | - task, _update = event |
| 132 | + message = oauth_extension_client.create_auth_response( |
| 133 | + task_id=task.id, redirect_uri=AnyUrl(str(request.url)) |
| 134 | + ) |
| 135 | + elif task and task.status.state == a2a.types.TaskState.input_required: |
| 136 | + if not task.status.message: |
| 137 | + raise RuntimeError("Missing message") |
| 138 | + |
| 139 | + approval_request = tool_call_extension_client.parse_request(message=task.status.message) |
| 140 | + |
| 141 | + print("Agent has requested a tool call") |
| 142 | + print(approval_request) |
| 143 | + choice = input("Approve (Y/n): ") |
| 144 | + response = ToolCallResponse(action="accept" if choice.lower() == "y" else "reject") |
| 145 | + message = tool_call_extension_client.create_response_message(task_id=task.id, response=response) |
| 146 | + else: |
| 147 | + break |
131 | 148 |
|
132 | 149 | print(task) |
133 | 150 |
|
|
0 commit comments