Skip to content

Commit f7fd470

Browse files
Merge pull request #306 from wenhaozhao:feat-mcp_async_toolset
PiperOrigin-RevId: 795274417
2 parents 8c107d2 + b867ea2 commit f7fd470

File tree

4 files changed

+324
-25
lines changed

4 files changed

+324
-25
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.google.adk.tools;
2+
3+
import com.google.adk.agents.ReadonlyContext;
4+
import java.util.List;
5+
import java.util.Optional;
6+
7+
public class NamedToolPredicate implements ToolPredicate {
8+
9+
private final List<String> toolNames;
10+
11+
public NamedToolPredicate(List<String> toolNames) {
12+
this.toolNames = List.copyOf(toolNames);
13+
}
14+
15+
public NamedToolPredicate(String... toolNames) {
16+
this.toolNames = List.of(toolNames);
17+
}
18+
19+
@Override
20+
public boolean test(BaseTool tool, Optional<ReadonlyContext> readonlyContext) {
21+
return toolNames.contains(tool.name());
22+
}
23+
}
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.tools.mcp;
18+
19+
import com.fasterxml.jackson.databind.ObjectMapper;
20+
import com.google.adk.JsonBaseModel;
21+
import com.google.adk.agents.ReadonlyContext;
22+
import com.google.adk.tools.BaseTool;
23+
import com.google.adk.tools.BaseToolset;
24+
import com.google.adk.tools.NamedToolPredicate;
25+
import com.google.adk.tools.ToolPredicate;
26+
import io.modelcontextprotocol.client.McpAsyncClient;
27+
import io.modelcontextprotocol.client.transport.ServerParameters;
28+
import io.reactivex.rxjava3.core.Flowable;
29+
import io.reactivex.rxjava3.core.Maybe;
30+
import io.reactivex.rxjava3.core.Single;
31+
import java.time.Duration;
32+
import java.util.List;
33+
import java.util.Objects;
34+
import java.util.Optional;
35+
import java.util.concurrent.atomic.AtomicReference;
36+
import org.slf4j.Logger;
37+
import org.slf4j.LoggerFactory;
38+
import reactor.core.publisher.Mono;
39+
import reactor.util.retry.RetrySpec;
40+
41+
/**
42+
* Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
43+
*
44+
* <p>Attributes:
45+
*
46+
* <ul>
47+
* <li>{@code connectionParams}: The connection parameters to the MCP server. Can be either {@code
48+
* ServerParameters} or {@code SseServerParameters}.
49+
* <li>{@code session}: The MCP session being initialized with the connection.
50+
* </ul>
51+
*/
52+
public class McpAsyncToolset implements BaseToolset {
53+
54+
private static final Logger logger = LoggerFactory.getLogger(McpAsyncToolset.class);
55+
56+
private static final int MAX_RETRIES = 3;
57+
private static final long RETRY_DELAY_MILLIS = 100;
58+
59+
private final McpSessionManager mcpSessionManager;
60+
private final ObjectMapper objectMapper;
61+
private final ToolPredicate toolFilter;
62+
private final AtomicReference<Mono<List<McpAsyncTool>>> mcpTools = new AtomicReference<>();
63+
64+
/** Builder for McpAsyncToolset */
65+
public static class Builder {
66+
private Object connectionParams = null;
67+
private ObjectMapper objectMapper = null;
68+
private ToolPredicate toolFilter = null;
69+
70+
public Builder connectionParams(ServerParameters connectionParams) {
71+
this.connectionParams = connectionParams;
72+
return this;
73+
}
74+
75+
public Builder connectionParams(SseServerParameters connectionParams) {
76+
this.connectionParams = connectionParams;
77+
return this;
78+
}
79+
80+
public Builder objectMapper(ObjectMapper objectMapper) {
81+
this.objectMapper = objectMapper;
82+
return this;
83+
}
84+
85+
public Builder toolFilter(ToolPredicate toolFilter) {
86+
this.toolFilter = toolFilter;
87+
return this;
88+
}
89+
90+
public Builder toolFilter(List<String> toolNames) {
91+
this.toolFilter = new NamedToolPredicate(toolNames);
92+
return this;
93+
}
94+
95+
public McpAsyncToolset build() {
96+
if (objectMapper == null) {
97+
objectMapper = JsonBaseModel.getMapper();
98+
}
99+
if (toolFilter == null) {
100+
toolFilter = (tool, context) -> true;
101+
}
102+
if (connectionParams instanceof ServerParameters setSelectedParams) {
103+
return new McpAsyncToolset(setSelectedParams, objectMapper, toolFilter);
104+
} else if (connectionParams instanceof SseServerParameters sseServerParameters) {
105+
return new McpAsyncToolset(sseServerParameters, objectMapper, toolFilter);
106+
} else {
107+
throw new IllegalArgumentException(
108+
"connectionParams must be either ServerParameters or SseServerParameters");
109+
}
110+
}
111+
}
112+
113+
/**
114+
* Initializes the McpAsyncToolset with SSE server parameters.
115+
*
116+
* @param connectionParams The SSE connection parameters to the MCP server.
117+
* @param objectMapper An ObjectMapper instance for parsing schemas.
118+
* @param toolFilter null or an implement for {@link ToolPredicate}, {@link
119+
* com.google.adk.tools.NamedToolPredicate}
120+
*/
121+
public McpAsyncToolset(
122+
SseServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolFilter) {
123+
Objects.requireNonNull(connectionParams);
124+
Objects.requireNonNull(objectMapper);
125+
this.objectMapper = objectMapper;
126+
this.mcpSessionManager = new McpSessionManager(connectionParams);
127+
this.toolFilter = toolFilter;
128+
}
129+
130+
/**
131+
* Initializes the McpAsyncToolset with local server parameters.
132+
*
133+
* @param connectionParams The local server connection parameters to the MCP server.
134+
* @param objectMapper An ObjectMapper instance for parsing schemas.
135+
* @param toolFilter null or an implement for {@link ToolPredicate}, {@link
136+
* com.google.adk.tools.NamedToolPredicate}
137+
*/
138+
public McpAsyncToolset(
139+
ServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolFilter) {
140+
Objects.requireNonNull(connectionParams);
141+
Objects.requireNonNull(objectMapper);
142+
this.objectMapper = objectMapper;
143+
this.mcpSessionManager = new McpSessionManager(connectionParams);
144+
this.toolFilter = toolFilter;
145+
}
146+
147+
@Override
148+
public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
149+
return Maybe.defer(() -> Maybe.fromCompletionStage(this.initAndGetTools().toFuture()))
150+
.defaultIfEmpty(List.of())
151+
.map(
152+
tools ->
153+
tools.stream()
154+
.filter(
155+
tool ->
156+
isToolSelected(
157+
tool,
158+
Optional.of(toolFilter),
159+
Optional.ofNullable(readonlyContext)))
160+
.toList())
161+
.onErrorResumeNext(
162+
err -> {
163+
if (err instanceof McpToolsetException) {
164+
return Single.error(err);
165+
} else {
166+
return Single.error(
167+
new McpToolsetException.McpInitializationException(
168+
"Failed to reinitialize session during tool loading retry (unexpected"
169+
+ " error).",
170+
err));
171+
}
172+
})
173+
.flattenAsFlowable(it -> it);
174+
}
175+
176+
private Mono<List<McpAsyncTool>> initAndGetTools() {
177+
return this.mcpTools.accumulateAndGet(
178+
null,
179+
(prev, _ignore) -> {
180+
if (prev == null) {
181+
// lazy init and cache tools
182+
return this.initTools().cache();
183+
}
184+
return prev;
185+
});
186+
}
187+
188+
private Mono<List<McpAsyncTool>> initTools() {
189+
return Mono.defer(
190+
() -> {
191+
McpAsyncClient mcpSession = this.mcpSessionManager.createAsyncSession();
192+
return mcpSession
193+
.initialize()
194+
.doOnSuccess(
195+
initResult -> logger.debug("Initialize Client Result: {}", initResult))
196+
.thenReturn(mcpSession);
197+
})
198+
.flatMap(
199+
mcpSession ->
200+
mcpSession
201+
.listTools()
202+
.map(
203+
toolsResponse ->
204+
toolsResponse.tools().stream()
205+
.map(
206+
tool ->
207+
new McpAsyncTool(
208+
tool,
209+
mcpSession, // move mcpSession to McpAsyncTool
210+
this.mcpSessionManager,
211+
this.objectMapper))
212+
.toList()))
213+
.retryWhen(
214+
RetrySpec.from(
215+
retrySignal ->
216+
retrySignal.flatMap(
217+
signal -> {
218+
Throwable err = signal.failure();
219+
if (err instanceof IllegalArgumentException) {
220+
// This could happen if parameters for tool loading are somehow
221+
// invalid.
222+
// This is likely a fatal error and should not be retried.
223+
logger.error("Invalid argument encountered during tool loading.", err);
224+
return Mono.error(
225+
new McpToolsetException.McpToolLoadingException(
226+
"Invalid argument encountered during tool loading.", err));
227+
}
228+
long totalRetries = signal.totalRetries();
229+
logger.error(
230+
"Unexpected error during tool loading, retry attempt "
231+
+ (totalRetries + 1),
232+
err);
233+
if (totalRetries < MAX_RETRIES) {
234+
logger.info(
235+
"Reinitializing MCP session before next retry for unexpected error.");
236+
return Mono.just(err)
237+
.delayElement(Duration.ofMillis(RETRY_DELAY_MILLIS));
238+
} else {
239+
logger.error(
240+
"Failed to load tools after multiple retries due to unexpected error.",
241+
err);
242+
return Mono.error(
243+
new McpToolsetException.McpToolLoadingException(
244+
"Failed to load tools after multiple retries due to unexpected error.",
245+
err));
246+
}
247+
})));
248+
}
249+
250+
@Override
251+
public void close() {
252+
Mono<List<McpAsyncTool>> tools = this.mcpTools.getAndSet(null);
253+
if (tools != null) {
254+
tools
255+
.flatMapIterable(it -> it)
256+
.flatMap(
257+
it ->
258+
it.mcpSession
259+
.closeGracefully()
260+
.onErrorResume(
261+
e -> {
262+
logger.error("Failed to close MCP session", e);
263+
// We don't throw an exception here, as closing is a cleanup operation
264+
// and
265+
// failing to close shouldn't prevent the program from continuing (or
266+
// exiting).
267+
// However, we log the error for debugging purposes.
268+
return Mono.empty();
269+
}))
270+
.doOnComplete(() -> logger.debug("MCP session closed successfully."))
271+
.subscribe();
272+
}
273+
}
274+
}

core/src/main/java/com/google/adk/tools/mcp/McpToolset.java

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
191191
// This could happen if parameters for tool loading are somehow invalid.
192192
// This is likely a fatal error and should not be retried.
193193
logger.error("Invalid argument encountered during tool loading.", e);
194-
throw new McpToolLoadingException(
194+
throw new McpToolsetException.McpToolLoadingException(
195195
"Invalid argument encountered during tool loading.", e);
196196
} catch (RuntimeException e) { // Catch any other unexpected runtime exceptions
197197
logger.error("Unexpected error during tool loading, retry attempt " + (i + 1), e);
@@ -210,21 +210,21 @@ public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
210210
Thread.currentThread().interrupt();
211211
logger.error(
212212
"Interrupted during retry delay for loadTools (unexpected error).", ie);
213-
throw new McpToolLoadingException(
213+
throw new McpToolsetException.McpToolLoadingException(
214214
"Interrupted during retry delay (unexpected error)", ie);
215215
} catch (RuntimeException reinitE) {
216216
logger.error(
217217
"Failed to reinitialize session during retry (unexpected error).",
218218
reinitE);
219-
throw new McpInitializationException(
219+
throw new McpToolsetException.McpInitializationException(
220220
"Failed to reinitialize session during tool loading retry (unexpected"
221221
+ " error).",
222222
reinitE);
223223
}
224224
} else {
225225
logger.error(
226226
"Failed to load tools after multiple retries due to unexpected error.", e);
227-
throw new McpToolLoadingException(
227+
throw new McpToolsetException.McpToolLoadingException(
228228
"Failed to load tools after multiple retries due to unexpected error.", e);
229229
}
230230
}
@@ -252,25 +252,4 @@ public void close() {
252252
}
253253
}
254254
}
255-
256-
/** Base exception for all errors originating from {@code McpToolset}. */
257-
public static class McpToolsetException extends RuntimeException {
258-
public McpToolsetException(String message, Throwable cause) {
259-
super(message, cause);
260-
}
261-
}
262-
263-
/** Exception thrown when there's an error during MCP session initialization. */
264-
public static class McpInitializationException extends McpToolsetException {
265-
public McpInitializationException(String message, Throwable cause) {
266-
super(message, cause);
267-
}
268-
}
269-
270-
/** Exception thrown when there's an error during loading tools from the MCP server. */
271-
public static class McpToolLoadingException extends McpToolsetException {
272-
public McpToolLoadingException(String message, Throwable cause) {
273-
super(message, cause);
274-
}
275-
}
276255
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.google.adk.tools.mcp;
2+
3+
/** Base exception for all errors originating from {@code McpToolset}. */
4+
public class McpToolsetException extends RuntimeException {
5+
6+
public McpToolsetException(String message, Throwable cause) {
7+
super(message, cause);
8+
}
9+
10+
/** Exception thrown when there's an error during MCP session initialization. */
11+
public static class McpInitializationException extends McpToolsetException {
12+
public McpInitializationException(String message, Throwable cause) {
13+
super(message, cause);
14+
}
15+
}
16+
17+
/** Exception thrown when there's an error during loading tools from the MCP server. */
18+
public static class McpToolLoadingException extends McpToolsetException {
19+
public McpToolLoadingException(String message, Throwable cause) {
20+
super(message, cause);
21+
}
22+
}
23+
}

0 commit comments

Comments
 (0)