Skip to content

Commit 6c377e0

Browse files
author
wenhaozhao
committed
feat: support AsyncMcpTool
1 parent ffe5de8 commit 6c377e0

File tree

2 files changed

+44
-42
lines changed

2 files changed

+44
-42
lines changed

core/src/main/java/com/google/adk/tools/mcp/McpAsyncTool.java

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,16 @@
2626
import com.google.genai.types.FunctionDeclaration;
2727
import com.google.genai.types.Schema;
2828
import io.modelcontextprotocol.client.McpAsyncClient;
29+
import io.modelcontextprotocol.spec.McpSchema;
2930
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
3031
import io.modelcontextprotocol.spec.McpSchema.JsonSchema;
3132
import io.modelcontextprotocol.spec.McpSchema.Tool;
3233
import io.reactivex.rxjava3.core.Maybe;
3334
import io.reactivex.rxjava3.core.Single;
3435
import java.util.Map;
3536
import java.util.Optional;
37+
import org.slf4j.Logger;
38+
import org.slf4j.LoggerFactory;
3639

3740
// TODO(b/413489523): Add support for auth. This is a TODO for Python as well.
3841

@@ -44,8 +47,11 @@
4447
*/
4548
public final class McpAsyncTool extends BaseTool {
4649

50+
private static final Logger logger = LoggerFactory.getLogger(McpAsyncTool.class);
51+
4752
Tool mcpTool;
48-
Single<McpAsyncClient> mcpSession;
53+
// Volatile ensures write visibility in the asynchronous chain.
54+
volatile McpAsyncClient mcpSession;
4955
McpSessionManager mcpSessionManager;
5056
ObjectMapper objectMapper;
5157

@@ -58,7 +64,7 @@ public final class McpAsyncTool extends BaseTool {
5864
* @throws IllegalArgumentException If mcpTool or mcpSession are null.
5965
*/
6066
public McpAsyncTool(
61-
Tool mcpTool, Single<McpAsyncClient> mcpSession, McpSessionManager mcpSessionManager) {
67+
Tool mcpTool, McpAsyncClient mcpSession, McpSessionManager mcpSessionManager) {
6268
this(mcpTool, mcpSession, mcpSessionManager, JsonBaseModel.getMapper());
6369
}
6470

@@ -73,7 +79,7 @@ public McpAsyncTool(
7379
*/
7480
public McpAsyncTool(
7581
Tool mcpTool,
76-
Single<McpAsyncClient> mcpSession,
82+
McpAsyncClient mcpSession,
7783
McpSessionManager mcpSessionManager,
7884
ObjectMapper objectMapper) {
7985
super(
@@ -95,16 +101,32 @@ public McpAsyncTool(
95101
this.objectMapper = objectMapper;
96102
}
97103

98-
public Single<McpAsyncClient> getMcpSession() {
104+
public McpAsyncClient getMcpSession() {
99105
return this.mcpSession;
100106
}
101107

102108
public Schema toGeminiSchema(JsonSchema openApiSchema) {
103109
return Schema.fromJson(objectMapper.valueToTree(openApiSchema).toString());
104110
}
105111

106-
private void reintializeSession() {
107-
this.mcpSession = this.mcpSessionManager.createAsyncSession();
112+
private Single<McpSchema.InitializeResult> reintializeSession() {
113+
McpAsyncClient client = this.mcpSessionManager.createAsyncSession();
114+
return Single.fromCompletionStage(
115+
client
116+
.initialize()
117+
.doOnSuccess(
118+
initResult -> {
119+
logger.debug("Initialize McpAsyncClient Result: {}", initResult);
120+
})
121+
.doOnError(
122+
e -> {
123+
logger.error("Initialize McpAsyncClient Failed: {}", e.getMessage(), e);
124+
})
125+
.doOnNext(
126+
_initResult -> {
127+
this.mcpSession = client;
128+
})
129+
.toFuture());
108130
}
109131

110132
@Override
@@ -121,14 +143,10 @@ public Optional<FunctionDeclaration> declaration() {
121143
public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContext toolContext) {
122144
return Single.defer(
123145
() ->
124-
this.mcpSession
125-
.flatMapMaybe(
126-
client ->
127-
Maybe.fromCompletionStage(
128-
client
129-
.callTool(
130-
new CallToolRequest(this.name(), ImmutableMap.copyOf(args)))
131-
.toFuture()))
146+
Maybe.fromCompletionStage(
147+
this.mcpSession
148+
.callTool(new CallToolRequest(this.name(), ImmutableMap.copyOf(args)))
149+
.toFuture())
132150
.map(
133151
callResult ->
134152
McpTool.wrapCallResult(this.objectMapper, this.name(), callResult))
@@ -141,9 +159,8 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
141159
.delay(100, MILLISECONDS)
142160
.take(3)
143161
.doOnNext(
144-
error -> {
145-
System.err.println("Retrying callTool due to: " + error);
146-
reintializeSession();
147-
}));
162+
error ->
163+
logger.error("Retrying callTool due to: {}", error.getMessage(), error))
164+
.flatMapSingle(_ignore -> this.reintializeSession()));
148165
}
149166
}

core/src/main/java/com/google/adk/tools/mcp/McpSessionManager.java

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import io.modelcontextprotocol.spec.McpClientTransport;
2323
import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
2424
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
25-
import io.reactivex.rxjava3.core.Single;
2625
import java.time.Duration;
2726
import java.util.Optional;
2827
import org.slf4j.Logger;
@@ -73,15 +72,15 @@ public static McpSyncClient initializeSession(
7372
return client;
7473
}
7574

76-
public Single<McpAsyncClient> createAsyncSession() {
75+
public McpAsyncClient createAsyncSession() {
7776
return initializeAsyncSession(this.connectionParams);
7877
}
7978

80-
public static Single<McpAsyncClient> initializeAsyncSession(Object connectionParams) {
79+
public static McpAsyncClient initializeAsyncSession(Object connectionParams) {
8180
return initializeAsyncSession(connectionParams, new DefaultMcpTransportBuilder());
8281
}
8382

84-
public static Single<McpAsyncClient> initializeAsyncSession(
83+
public static McpAsyncClient initializeAsyncSession(
8584
Object connectionParams, McpTransportBuilder transportBuilder) {
8685
Duration initializationTimeout = null;
8786
Duration requestTimeout = null;
@@ -90,25 +89,11 @@ public static Single<McpAsyncClient> initializeAsyncSession(
9089
initializationTimeout = sseServerParams.timeout();
9190
requestTimeout = sseServerParams.sseReadTimeout();
9291
}
93-
McpAsyncClient client =
94-
McpClient.async(transport)
95-
.initializationTimeout(
96-
initializationTimeout == null ? Duration.ofSeconds(10) : initializationTimeout)
97-
.requestTimeout(requestTimeout == null ? Duration.ofSeconds(10) : requestTimeout)
98-
.capabilities(ClientCapabilities.builder().build())
99-
.build();
100-
return Single.fromCompletionStage(
101-
client
102-
.initialize()
103-
.doOnSuccess(
104-
initResult -> {
105-
logger.debug("Initialize McpAsyncClient Result: {}", initResult);
106-
})
107-
.doOnError(
108-
e -> {
109-
logger.error("Initialize McpAsyncClient Failed: {}", e.getMessage(), e);
110-
})
111-
.map(_initResult -> client)
112-
.toFuture());
92+
return McpClient.async(transport)
93+
.initializationTimeout(
94+
initializationTimeout == null ? Duration.ofSeconds(10) : initializationTimeout)
95+
.requestTimeout(requestTimeout == null ? Duration.ofSeconds(10) : requestTimeout)
96+
.capabilities(ClientCapabilities.builder().build())
97+
.build();
11398
}
11499
}

0 commit comments

Comments
 (0)