Skip to content

Commit d5ca1ec

Browse files
Merge pull request #273 from wenhaozhao:feat-async_mcp_tool
PiperOrigin-RevId: 788076672
2 parents cb95b56 + 0c50970 commit d5ca1ec

File tree

3 files changed

+251
-53
lines changed

3 files changed

+251
-53
lines changed
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.tools.mcp;
18+
19+
import static java.util.concurrent.TimeUnit.MILLISECONDS;
20+
21+
import com.fasterxml.jackson.databind.ObjectMapper;
22+
import com.google.adk.JsonBaseModel;
23+
import com.google.adk.tools.BaseTool;
24+
import com.google.adk.tools.ToolContext;
25+
import com.google.common.collect.ImmutableMap;
26+
import com.google.genai.types.FunctionDeclaration;
27+
import com.google.genai.types.Schema;
28+
import io.modelcontextprotocol.client.McpAsyncClient;
29+
import io.modelcontextprotocol.spec.McpSchema;
30+
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
31+
import io.modelcontextprotocol.spec.McpSchema.JsonSchema;
32+
import io.modelcontextprotocol.spec.McpSchema.Tool;
33+
import io.reactivex.rxjava3.core.Maybe;
34+
import io.reactivex.rxjava3.core.Single;
35+
import java.util.Map;
36+
import java.util.Optional;
37+
import org.slf4j.Logger;
38+
import org.slf4j.LoggerFactory;
39+
40+
// TODO(b/413489523): Add support for auth. This is a TODO for Python as well.
41+
42+
/**
43+
* Initializes a MCP tool.
44+
*
45+
* <p>This wraps a MCP Tool interface and an active MCP Session. It invokes the MCP Tool through
46+
* executing the tool from remote MCP Session.
47+
*/
48+
public final class McpAsyncTool extends BaseTool {
49+
50+
private static final Logger logger = LoggerFactory.getLogger(McpAsyncTool.class);
51+
52+
Tool mcpTool;
53+
// Volatile ensures write visibility in the asynchronous chain.
54+
volatile McpAsyncClient mcpSession;
55+
McpSessionManager mcpSessionManager;
56+
ObjectMapper objectMapper;
57+
58+
/**
59+
* Creates a new McpAsyncTool with the default ObjectMapper.
60+
*
61+
* @param mcpTool The MCP tool to wrap.
62+
* @param mcpSession The MCP session to use to call the tool.
63+
* @param mcpSessionManager The MCP session manager to use to create new sessions.
64+
* @throws IllegalArgumentException If mcpTool or mcpSession are null.
65+
*/
66+
public McpAsyncTool(
67+
Tool mcpTool, McpAsyncClient mcpSession, McpSessionManager mcpSessionManager) {
68+
this(mcpTool, mcpSession, mcpSessionManager, JsonBaseModel.getMapper());
69+
}
70+
71+
/**
72+
* Creates a new McpAsyncTool
73+
*
74+
* @param mcpTool The MCP tool to wrap.
75+
* @param mcpSession The MCP session to use to call the tool.
76+
* @param mcpSessionManager The MCP session manager to use to create new sessions.
77+
* @param objectMapper The ObjectMapper to use to convert JSON schemas.
78+
* @throws IllegalArgumentException If mcpTool or mcpSession are null.
79+
*/
80+
public McpAsyncTool(
81+
Tool mcpTool,
82+
McpAsyncClient mcpSession,
83+
McpSessionManager mcpSessionManager,
84+
ObjectMapper objectMapper) {
85+
super(
86+
mcpTool == null ? "" : mcpTool.name(),
87+
mcpTool == null ? "" : (mcpTool.description().isEmpty() ? "" : mcpTool.description()));
88+
89+
if (mcpTool == null) {
90+
throw new IllegalArgumentException("mcpTool cannot be null");
91+
}
92+
if (mcpSession == null) {
93+
throw new IllegalArgumentException("mcpSession cannot be null");
94+
}
95+
if (objectMapper == null) {
96+
throw new IllegalArgumentException("objectMapper cannot be null");
97+
}
98+
this.mcpTool = mcpTool;
99+
this.mcpSession = mcpSession;
100+
this.mcpSessionManager = mcpSessionManager;
101+
this.objectMapper = objectMapper;
102+
}
103+
104+
public McpAsyncClient getMcpSession() {
105+
return this.mcpSession;
106+
}
107+
108+
public Schema toGeminiSchema(JsonSchema openApiSchema) {
109+
return Schema.fromJson(objectMapper.valueToTree(openApiSchema).toString());
110+
}
111+
112+
private Single<McpSchema.InitializeResult> reintializeSession() {
113+
McpAsyncClient client = this.mcpSessionManager.createAsyncSession();
114+
return Single.fromCompletionStage(
115+
client
116+
.initialize()
117+
.doOnSuccess(
118+
initResult -> {
119+
logger.debug("Initialize McpAsyncClient Result: {}", initResult);
120+
})
121+
.doOnError(
122+
e -> {
123+
logger.error("Initialize McpAsyncClient Failed: {}", e.getMessage(), e);
124+
})
125+
.doOnNext(
126+
_initResult -> {
127+
this.mcpSession = client;
128+
})
129+
.toFuture());
130+
}
131+
132+
@Override
133+
public Optional<FunctionDeclaration> declaration() {
134+
return Optional.of(
135+
FunctionDeclaration.builder()
136+
.name(this.name())
137+
.description(this.description())
138+
.parameters(toGeminiSchema(this.mcpTool.inputSchema()))
139+
.build());
140+
}
141+
142+
@Override
143+
public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContext toolContext) {
144+
return Single.defer(
145+
() ->
146+
Maybe.fromCompletionStage(
147+
this.mcpSession
148+
.callTool(new CallToolRequest(this.name(), ImmutableMap.copyOf(args)))
149+
.toFuture())
150+
.map(
151+
callResult ->
152+
McpTool.wrapCallResult(this.objectMapper, this.name(), callResult))
153+
.switchIfEmpty(
154+
Single.fromCallable(
155+
() -> McpTool.wrapCallResult(this.objectMapper, this.name(), null))))
156+
.retryWhen(
157+
errors ->
158+
errors
159+
.delay(100, MILLISECONDS)
160+
.take(3)
161+
.doOnNext(
162+
error ->
163+
logger.error("Retrying callTool due to: {}", error.getMessage(), error))
164+
.flatMapSingle(_ignore -> this.reintializeSession()));
165+
}
166+
}

core/src/main/java/com/google/adk/tools/mcp/McpSessionManager.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.adk.tools.mcp;
1818

19+
import io.modelcontextprotocol.client.McpAsyncClient;
1920
import io.modelcontextprotocol.client.McpClient;
2021
import io.modelcontextprotocol.client.McpSyncClient;
2122
import io.modelcontextprotocol.spec.McpClientTransport;
@@ -70,4 +71,29 @@ public static McpSyncClient initializeSession(
7071
logger.debug("Initialize Client Result: {}", initResult);
7172
return client;
7273
}
74+
75+
public McpAsyncClient createAsyncSession() {
76+
return initializeAsyncSession(this.connectionParams);
77+
}
78+
79+
public static McpAsyncClient initializeAsyncSession(Object connectionParams) {
80+
return initializeAsyncSession(connectionParams, new DefaultMcpTransportBuilder());
81+
}
82+
83+
public static McpAsyncClient initializeAsyncSession(
84+
Object connectionParams, McpTransportBuilder transportBuilder) {
85+
Duration initializationTimeout = null;
86+
Duration requestTimeout = null;
87+
McpClientTransport transport = transportBuilder.build(connectionParams);
88+
if (connectionParams instanceof SseServerParameters sseServerParams) {
89+
initializationTimeout = sseServerParams.timeout();
90+
requestTimeout = sseServerParams.sseReadTimeout();
91+
}
92+
return McpClient.async(transport)
93+
.initializationTimeout(
94+
initializationTimeout == null ? Duration.ofSeconds(10) : initializationTimeout)
95+
.requestTimeout(requestTimeout == null ? Duration.ofSeconds(10) : requestTimeout)
96+
.capabilities(ClientCapabilities.builder().build())
97+
.build();
98+
}
7399
}

core/src/main/java/com/google/adk/tools/mcp/McpTool.java

Lines changed: 59 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ public McpTool(
9999
this.objectMapper = objectMapper;
100100
}
101101

102+
public McpSyncClient getMcpSession() {
103+
return this.mcpSession;
104+
}
105+
102106
public Schema toGeminiSchema(JsonSchema openApiSchema) {
103107
return Schema.fromJson(objectMapper.valueToTree(openApiSchema).toString());
104108
}
@@ -123,59 +127,7 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
123127
() -> {
124128
CallToolResult callResult =
125129
mcpSession.callTool(new CallToolRequest(this.name(), ImmutableMap.copyOf(args)));
126-
127-
if (callResult == null) {
128-
return ImmutableMap.of("error", "MCP framework error: CallToolResult was null");
129-
}
130-
131-
List<Content> contents = callResult.content();
132-
Boolean isToolError = callResult.isError();
133-
134-
if (isToolError != null && isToolError) {
135-
String errorMessage = "Tool execution failed.";
136-
if (contents != null
137-
&& !contents.isEmpty()
138-
&& contents.get(0) instanceof TextContent) {
139-
TextContent textContent = (TextContent) contents.get(0);
140-
if (textContent.text() != null && !textContent.text().isEmpty()) {
141-
errorMessage += " Details: " + textContent.text();
142-
}
143-
}
144-
return ImmutableMap.of("error", errorMessage);
145-
}
146-
147-
if (contents == null || contents.isEmpty()) {
148-
return ImmutableMap.of();
149-
}
150-
151-
List<String> textOutputs = new ArrayList<>();
152-
for (Content content : contents) {
153-
if (content instanceof TextContent textContent) {
154-
if (textContent.text() != null) {
155-
textOutputs.add(textContent.text());
156-
}
157-
}
158-
}
159-
160-
if (textOutputs.isEmpty()) {
161-
return ImmutableMap.of(
162-
"error",
163-
"Tool '" + this.name() + "' returned content that is not TextContent.",
164-
"content_details",
165-
contents.toString());
166-
}
167-
168-
List<Map<String, Object>> resultMaps = new ArrayList<>();
169-
for (String textOutput : textOutputs) {
170-
try {
171-
resultMaps.add(
172-
objectMapper.readValue(
173-
textOutput, new TypeReference<Map<String, Object>>() {}));
174-
} catch (JsonProcessingException e) {
175-
resultMaps.add(ImmutableMap.of("text", textOutput));
176-
}
177-
}
178-
return ImmutableMap.of("text_output", resultMaps);
130+
return wrapCallResult(this.objectMapper, this.name(), callResult);
179131
})
180132
.retryWhen(
181133
errors ->
@@ -188,4 +140,58 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
188140
reintializeSession();
189141
}));
190142
}
143+
144+
static Map<String, Object> wrapCallResult(
145+
ObjectMapper objectMapper, String mcpToolName, CallToolResult callResult) {
146+
if (callResult == null) {
147+
return ImmutableMap.of("error", "MCP framework error: CallToolResult was null");
148+
}
149+
150+
List<Content> contents = callResult.content();
151+
Boolean isToolError = callResult.isError();
152+
153+
if (isToolError != null && isToolError) {
154+
String errorMessage = "Tool execution failed.";
155+
if (contents != null
156+
&& !contents.isEmpty()
157+
&& contents.get(0) instanceof TextContent textContent) {
158+
if (textContent.text() != null && !textContent.text().isEmpty()) {
159+
errorMessage += " Details: " + textContent.text();
160+
}
161+
}
162+
return ImmutableMap.of("error", errorMessage);
163+
}
164+
165+
if (contents == null || contents.isEmpty()) {
166+
return ImmutableMap.of();
167+
}
168+
169+
List<String> textOutputs = new ArrayList<>();
170+
for (Content content : contents) {
171+
if (content instanceof TextContent textContent) {
172+
if (textContent.text() != null) {
173+
textOutputs.add(textContent.text());
174+
}
175+
}
176+
}
177+
178+
if (textOutputs.isEmpty()) {
179+
return ImmutableMap.of(
180+
"error",
181+
"Tool '" + mcpToolName + "' returned content that is not TextContent.",
182+
"content_details",
183+
contents.toString());
184+
}
185+
186+
List<Map<String, Object>> resultMaps = new ArrayList<>();
187+
for (String textOutput : textOutputs) {
188+
try {
189+
resultMaps.add(
190+
objectMapper.readValue(textOutput, new TypeReference<Map<String, Object>>() {}));
191+
} catch (JsonProcessingException e) {
192+
resultMaps.add(ImmutableMap.of("text", textOutput));
193+
}
194+
}
195+
return ImmutableMap.of("text_output", resultMaps);
196+
}
191197
}

0 commit comments

Comments
 (0)