Skip to content

Commit 86ebe47

Browse files
author
wenhaozhao
committed
feat: support AsyncMcpTool
1 parent 39a8bc6 commit 86ebe47

File tree

3 files changed

+249
-54
lines changed

3 files changed

+249
-54
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.tools.mcp;
18+
19+
import com.fasterxml.jackson.databind.ObjectMapper;
20+
import com.google.adk.JsonBaseModel;
21+
import com.google.adk.tools.BaseTool;
22+
import com.google.adk.tools.ToolContext;
23+
import com.google.common.collect.ImmutableMap;
24+
import com.google.genai.types.FunctionDeclaration;
25+
import com.google.genai.types.Schema;
26+
import io.modelcontextprotocol.client.McpAsyncClient;
27+
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
28+
import io.modelcontextprotocol.spec.McpSchema.JsonSchema;
29+
import io.modelcontextprotocol.spec.McpSchema.Tool;
30+
import io.reactivex.rxjava3.core.Maybe;
31+
import io.reactivex.rxjava3.core.Single;
32+
33+
import java.util.Map;
34+
import java.util.Optional;
35+
36+
import static java.util.concurrent.TimeUnit.MILLISECONDS;
37+
38+
// TODO(b/413489523): Add support for auth. This is a TODO for Python as well.
39+
40+
/**
41+
* Initializes a MCP tool.
42+
*
43+
* <p>This wraps a MCP Tool interface and an active MCP Session. It invokes the MCP Tool through
44+
* executing the tool from remote MCP Session.
45+
*/
46+
public final class McpAsyncTool extends BaseTool {
47+
48+
Tool mcpTool;
49+
Single<McpAsyncClient> mcpSession;
50+
McpSessionManager mcpSessionManager;
51+
ObjectMapper objectMapper;
52+
53+
/**
54+
* Creates a new McpTool with the default ObjectMapper.
55+
*
56+
* @param mcpTool The MCP tool to wrap.
57+
* @param mcpSession The MCP session to use to call the tool.
58+
* @param mcpSessionManager The MCP session manager to use to create new sessions.
59+
* @throws IllegalArgumentException If mcpTool or mcpSession are null.
60+
*/
61+
public McpAsyncTool(Tool mcpTool, Single<McpAsyncClient> mcpSession, McpSessionManager mcpSessionManager) {
62+
this(mcpTool, mcpSession, mcpSessionManager, JsonBaseModel.getMapper());
63+
}
64+
65+
/**
66+
* Creates a new McpTool with the default ObjectMapper.
67+
*
68+
* @param mcpTool The MCP tool to wrap.
69+
* @param mcpSession The MCP session to use to call the tool.
70+
* @param mcpSessionManager The MCP session manager to use to create new sessions.
71+
* @param objectMapper The ObjectMapper to use to convert JSON schemas.
72+
* @throws IllegalArgumentException If mcpTool or mcpSession are null.
73+
*/
74+
public McpAsyncTool(
75+
Tool mcpTool,
76+
Single<McpAsyncClient> mcpSession,
77+
McpSessionManager mcpSessionManager,
78+
ObjectMapper objectMapper) {
79+
super(
80+
mcpTool == null ? "" : mcpTool.name(),
81+
mcpTool == null ? "" : (mcpTool.description().isEmpty() ? "" : mcpTool.description()));
82+
83+
if (mcpTool == null) {
84+
throw new IllegalArgumentException("mcpTool cannot be null");
85+
}
86+
if (mcpSession == null) {
87+
throw new IllegalArgumentException("mcpSession cannot be null");
88+
}
89+
if (objectMapper == null) {
90+
throw new IllegalArgumentException("objectMapper cannot be null");
91+
}
92+
this.mcpTool = mcpTool;
93+
this.mcpSession = mcpSession;
94+
this.mcpSessionManager = mcpSessionManager;
95+
this.objectMapper = objectMapper;
96+
}
97+
98+
public Single<McpAsyncClient> getMcpSession() {
99+
return this.mcpSession;
100+
}
101+
102+
public Schema toGeminiSchema(JsonSchema openApiSchema) {
103+
return Schema.fromJson(objectMapper.valueToTree(openApiSchema).toString());
104+
}
105+
106+
private void reintializeSession() {
107+
this.mcpSession = this.mcpSessionManager.createAsyncSession();
108+
}
109+
110+
@Override
111+
public Optional<FunctionDeclaration> declaration() {
112+
return Optional.of(
113+
FunctionDeclaration.builder()
114+
.name(this.name())
115+
.description(this.description())
116+
.parameters(toGeminiSchema(this.mcpTool.inputSchema()))
117+
.build());
118+
}
119+
120+
@Override
121+
public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContext toolContext) {
122+
return Single.defer(() ->
123+
this.mcpSession.flatMapMaybe(client ->
124+
Maybe.fromCompletionStage(
125+
client.callTool(new CallToolRequest(this.name(), ImmutableMap.copyOf(args)))
126+
.toFuture()
127+
)
128+
).map(callResult -> McpTool.wrapCallResult(
129+
this.objectMapper, this.name(), callResult)
130+
).switchIfEmpty(
131+
Single.fromCallable(
132+
() -> McpTool.wrapCallResult(this.objectMapper, this.name(), null)
133+
)
134+
)
135+
)
136+
.retryWhen(
137+
errors ->
138+
errors
139+
.delay(100, MILLISECONDS)
140+
.take(3)
141+
.doOnNext(
142+
error -> {
143+
System.err.println("Retrying callTool due to: " + error);
144+
reintializeSession();
145+
}));
146+
}
147+
}

core/src/main/java/com/google/adk/tools/mcp/McpSessionManager.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616

1717
package com.google.adk.tools.mcp;
1818

19+
import io.modelcontextprotocol.client.McpAsyncClient;
1920
import io.modelcontextprotocol.client.McpClient;
2021
import io.modelcontextprotocol.client.McpSyncClient;
2122
import io.modelcontextprotocol.spec.McpClientTransport;
2223
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
2324
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
2425
import java.time.Duration;
26+
27+
import io.reactivex.rxjava3.core.Single;
2528
import org.slf4j.Logger;
2629
import org.slf4j.LoggerFactory;
2730

@@ -68,4 +71,40 @@ public static McpSyncClient initializeSession(
6871
logger.debug("Initialize Client Result: {}", initResult);
6972
return client;
7073
}
74+
75+
public Single<McpAsyncClient> createAsyncSession() {
76+
return initializeAsyncSession(this.connectionParams);
77+
}
78+
79+
public static Single<McpAsyncClient> initializeAsyncSession(Object connectionParams) {
80+
return initializeAsyncSession(connectionParams, new DefaultMcpTransportBuilder());
81+
}
82+
83+
public static Single<McpAsyncClient> initializeAsyncSession(
84+
Object connectionParams, McpTransportBuilder transportBuilder) {
85+
Duration initializationTimeout = null;
86+
Duration requestTimeout = null;
87+
McpClientTransport transport = transportBuilder.build(connectionParams);
88+
if (connectionParams instanceof SseServerParameters sseServerParams) {
89+
initializationTimeout = sseServerParams.timeout();
90+
requestTimeout = sseServerParams.sseReadTimeout();
91+
}
92+
McpAsyncClient client =
93+
McpClient.async(transport)
94+
.initializationTimeout(initializationTimeout == null ? Duration.ofSeconds(10) : initializationTimeout)
95+
.requestTimeout(requestTimeout == null ? Duration.ofSeconds(10) : requestTimeout)
96+
.capabilities(ClientCapabilities.builder().build())
97+
.build();
98+
return Single.fromCompletionStage(
99+
client.initialize()
100+
.doOnSuccess(initResult -> {
101+
logger.debug("Initialize McpAsyncClient Result: {}", initResult);
102+
})
103+
.doOnError(e -> {
104+
logger.error("Initialize McpAsyncClient Failed: {}", e.getMessage(), e);
105+
})
106+
.map(_initResult -> client)
107+
.toFuture()
108+
);
109+
}
71110
}

core/src/main/java/com/google/adk/tools/mcp/McpTool.java

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,11 @@ public McpTool(
9999
this.objectMapper = objectMapper;
100100
}
101101

102-
public Schema toGeminiSchema(JsonSchema openApiSchema) {
102+
public McpSyncClient getMcpSession() {
103+
return this.mcpSession;
104+
}
105+
106+
public Schema toGeminiSchema(JsonSchema openApiSchema) {
103107
return Schema.fromJson(objectMapper.valueToTree(openApiSchema).toString());
104108
}
105109

@@ -117,65 +121,15 @@ public Optional<FunctionDeclaration> declaration() {
117121
.build());
118122
}
119123

124+
125+
120126
@Override
121127
public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContext toolContext) {
122128
return Single.<Map<String, Object>>fromCallable(
123129
() -> {
124130
CallToolResult callResult =
125131
mcpSession.callTool(new CallToolRequest(this.name(), ImmutableMap.copyOf(args)));
126-
127-
if (callResult == null) {
128-
return ImmutableMap.of("error", "MCP framework error: CallToolResult was null");
129-
}
130-
131-
List<Content> contents = callResult.content();
132-
Boolean isToolError = callResult.isError();
133-
134-
if (isToolError != null && isToolError) {
135-
String errorMessage = "Tool execution failed.";
136-
if (contents != null
137-
&& !contents.isEmpty()
138-
&& contents.get(0) instanceof TextContent) {
139-
TextContent textContent = (TextContent) contents.get(0);
140-
if (textContent.text() != null && !textContent.text().isEmpty()) {
141-
errorMessage += " Details: " + textContent.text();
142-
}
143-
}
144-
return ImmutableMap.of("error", errorMessage);
145-
}
146-
147-
if (contents == null || contents.isEmpty()) {
148-
return ImmutableMap.of();
149-
}
150-
151-
List<String> textOutputs = new ArrayList<>();
152-
for (Content content : contents) {
153-
if (content instanceof TextContent textContent) {
154-
if (textContent.text() != null) {
155-
textOutputs.add(textContent.text());
156-
}
157-
}
158-
}
159-
160-
if (textOutputs.isEmpty()) {
161-
return ImmutableMap.of(
162-
"error",
163-
"Tool '" + this.name() + "' returned content that is not TextContent.",
164-
"content_details",
165-
contents.toString());
166-
}
167-
168-
List<Map<String, Object>> resultMaps = new ArrayList<>();
169-
for (String textOutput : textOutputs) {
170-
try {
171-
resultMaps.add(
172-
objectMapper.readValue(
173-
textOutput, new TypeReference<Map<String, Object>>() {}));
174-
} catch (JsonProcessingException e) {
175-
resultMaps.add(ImmutableMap.of("text", textOutput));
176-
}
177-
}
178-
return ImmutableMap.of("text_output", resultMaps);
132+
return wrapCallResult(this.objectMapper, this.name(), callResult);
179133
})
180134
.retryWhen(
181135
errors ->
@@ -188,4 +142,59 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
188142
reintializeSession();
189143
}));
190144
}
145+
146+
static Map<String, Object> wrapCallResult(ObjectMapper objectMapper, String mcpToolName, CallToolResult callResult) {
147+
if (callResult == null) {
148+
return ImmutableMap.of("error", "MCP framework error: CallToolResult was null");
149+
}
150+
151+
List<Content> contents = callResult.content();
152+
Boolean isToolError = callResult.isError();
153+
154+
if (isToolError != null && isToolError) {
155+
String errorMessage = "Tool execution failed.";
156+
if (contents != null
157+
&& !contents.isEmpty()
158+
&& contents.get(0) instanceof TextContent textContent) {
159+
if (textContent.text() != null && !textContent.text().isEmpty()) {
160+
errorMessage += " Details: " + textContent.text();
161+
}
162+
}
163+
return ImmutableMap.of("error", errorMessage);
164+
}
165+
166+
if (contents == null || contents.isEmpty()) {
167+
return ImmutableMap.of();
168+
}
169+
170+
List<String> textOutputs = new ArrayList<>();
171+
for (Content content : contents) {
172+
if (content instanceof TextContent textContent) {
173+
if (textContent.text() != null) {
174+
textOutputs.add(textContent.text());
175+
}
176+
}
177+
}
178+
179+
if (textOutputs.isEmpty()) {
180+
return ImmutableMap.of(
181+
"error",
182+
"Tool '" + mcpToolName + "' returned content that is not TextContent.",
183+
"content_details",
184+
contents.toString());
185+
}
186+
187+
List<Map<String, Object>> resultMaps = new ArrayList<>();
188+
for (String textOutput : textOutputs) {
189+
try {
190+
resultMaps.add(
191+
objectMapper.readValue(
192+
textOutput, new TypeReference<Map<String, Object>>() {
193+
}));
194+
} catch (JsonProcessingException e) {
195+
resultMaps.add(ImmutableMap.of("text", textOutput));
196+
}
197+
}
198+
return ImmutableMap.of("text_output", resultMaps);
199+
}
191200
}

0 commit comments

Comments
 (0)