Skip to content

Commit 89b57ab

Browse files
author
wenhaozhao
committed
feat:
1 add McpAsyncToolset 2 add NamedToolPredicate
1 parent f5b8fda commit 89b57ab

File tree

4 files changed

+316
-25
lines changed

4 files changed

+316
-25
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.google.adk.tools;
2+
3+
import com.google.adk.agents.ReadonlyContext;
4+
import java.util.List;
5+
import java.util.Optional;
6+
7+
public class NamedToolPredicate implements ToolPredicate {
8+
9+
private final List<String> toolNames;
10+
11+
public NamedToolPredicate(List<String> toolNames) {
12+
this.toolNames = List.copyOf(toolNames);
13+
}
14+
15+
public NamedToolPredicate(String... toolNames) {
16+
this.toolNames = List.of(toolNames);
17+
}
18+
19+
@Override
20+
public boolean test(BaseTool tool, Optional<ReadonlyContext> readonlyContext) {
21+
return toolNames.contains(tool.name());
22+
}
23+
}
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.tools.mcp;
18+
19+
import static com.google.common.collect.ImmutableList.toImmutableList;
20+
21+
import com.fasterxml.jackson.databind.ObjectMapper;
22+
import com.google.adk.JsonBaseModel;
23+
import com.google.adk.agents.ReadonlyContext;
24+
import com.google.adk.tools.BaseTool;
25+
import com.google.adk.tools.BaseToolset;
26+
import com.google.adk.tools.NamedToolPredicate;
27+
import com.google.adk.tools.ToolPredicate;
28+
import com.google.common.collect.ImmutableList;
29+
import io.modelcontextprotocol.client.McpAsyncClient;
30+
import io.modelcontextprotocol.client.transport.ServerParameters;
31+
import io.reactivex.rxjava3.core.Flowable;
32+
import io.reactivex.rxjava3.core.Maybe;
33+
import io.reactivex.rxjava3.core.Single;
34+
import java.time.Duration;
35+
import java.util.List;
36+
import java.util.Objects;
37+
import java.util.Optional;
38+
import java.util.concurrent.atomic.AtomicReference;
39+
import org.slf4j.Logger;
40+
import org.slf4j.LoggerFactory;
41+
import reactor.core.publisher.Mono;
42+
import reactor.util.retry.RetrySpec;
43+
44+
/**
45+
* Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
46+
*
47+
* <p>Attributes:
48+
*
49+
* <ul>
50+
* <li>{@code connectionParams}: The connection parameters to the MCP server. Can be either {@code
51+
* ServerParameters} or {@code SseServerParameters}.
52+
* <li>{@code session}: The MCP session being initialized with the connection.
53+
* </ul>
54+
*/
55+
public class McpAsyncToolset implements BaseToolset {
56+
57+
private static final Logger logger = LoggerFactory.getLogger(McpAsyncToolset.class);
58+
59+
private static final int MAX_RETRIES = 3;
60+
private static final long RETRY_DELAY_MILLIS = 100;
61+
62+
private final McpSessionManager mcpSessionManager;
63+
private final AtomicReference<McpAsyncClient> mcpSession = new AtomicReference<>();
64+
private final ObjectMapper objectMapper;
65+
private final ToolPredicate toolFilter;
66+
67+
/** Builder for McpAsyncToolset */
68+
public static class Builder {
69+
private Object connectionParams = null;
70+
private ObjectMapper objectMapper = null;
71+
private ToolPredicate toolFilter = null;
72+
73+
public Builder connectionParams(ServerParameters connectionParams) {
74+
this.connectionParams = connectionParams;
75+
return this;
76+
}
77+
78+
public Builder connectionParams(SseServerParameters connectionParams) {
79+
this.connectionParams = connectionParams;
80+
return this;
81+
}
82+
83+
public Builder objectMapper(ObjectMapper objectMapper) {
84+
this.objectMapper = objectMapper;
85+
return this;
86+
}
87+
88+
public Builder toolFilter(ToolPredicate toolFilter) {
89+
this.toolFilter = toolFilter;
90+
return this;
91+
}
92+
93+
public Builder toolFilter(List<String> toolNames) {
94+
this.toolFilter = new NamedToolPredicate(toolNames);
95+
return this;
96+
}
97+
98+
public McpAsyncToolset build() {
99+
if (objectMapper == null) {
100+
objectMapper = JsonBaseModel.getMapper();
101+
}
102+
if (toolFilter == null) {
103+
toolFilter = (tool, context) -> true;
104+
}
105+
if (connectionParams instanceof ServerParameters setSelectedParams) {
106+
return new McpAsyncToolset(setSelectedParams, objectMapper, toolFilter);
107+
} else if (connectionParams instanceof SseServerParameters sseServerParameters) {
108+
return new McpAsyncToolset(sseServerParameters, objectMapper, toolFilter);
109+
} else {
110+
throw new IllegalArgumentException(
111+
"connectionParams must be either ServerParameters or SseServerParameters");
112+
}
113+
}
114+
}
115+
116+
/**
117+
* Initializes the McpAsyncToolset with SSE server parameters.
118+
*
119+
* @param connectionParams The SSE connection parameters to the MCP server.
120+
* @param objectMapper An ObjectMapper instance for parsing schemas.
121+
* @param toolFilter null or an implement for {@link ToolPredicate}, {@link
122+
* com.google.adk.tools.NamedToolPredicate}
123+
*/
124+
public McpAsyncToolset(
125+
SseServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolFilter) {
126+
Objects.requireNonNull(connectionParams);
127+
Objects.requireNonNull(objectMapper);
128+
this.objectMapper = objectMapper;
129+
this.mcpSessionManager = new McpSessionManager(connectionParams);
130+
this.toolFilter = toolFilter;
131+
}
132+
133+
/**
134+
* Initializes the McpAsyncToolset with local server parameters.
135+
*
136+
* @param connectionParams The local server connection parameters to the MCP server.
137+
* @param objectMapper An ObjectMapper instance for parsing schemas.
138+
* @param toolFilter null or an implement for {@link ToolPredicate}, {@link
139+
* com.google.adk.tools.NamedToolPredicate}
140+
*/
141+
public McpAsyncToolset(
142+
ServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolFilter) {
143+
Objects.requireNonNull(connectionParams);
144+
Objects.requireNonNull(objectMapper);
145+
this.objectMapper = objectMapper;
146+
this.mcpSessionManager = new McpSessionManager(connectionParams);
147+
this.toolFilter = toolFilter;
148+
}
149+
150+
@Override
151+
public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
152+
Mono<ImmutableList<McpAsyncTool>> mono =
153+
Mono.fromCallable(
154+
() ->
155+
this.mcpSession.accumulateAndGet(
156+
null,
157+
(prev, _next) -> {
158+
// set if not present
159+
if (prev == null) {
160+
return this.mcpSessionManager.createAsyncSession();
161+
}
162+
return prev;
163+
}))
164+
.flatMap(
165+
mcpSession -> {
166+
if (mcpSession.isInitialized()) {
167+
return Mono.just(mcpSession);
168+
} else {
169+
return mcpSession
170+
.initialize()
171+
.doOnSuccess(
172+
initResult -> logger.debug("Initialize Client Result: {}", initResult))
173+
.thenReturn(mcpSession);
174+
}
175+
})
176+
.flatMap(
177+
mcpSession ->
178+
mcpSession
179+
.listTools()
180+
.map(
181+
toolsResponse ->
182+
toolsResponse.tools().stream()
183+
.map(
184+
tool ->
185+
new McpAsyncTool(
186+
tool,
187+
mcpSession,
188+
this.mcpSessionManager,
189+
this.objectMapper))
190+
.filter(
191+
tool ->
192+
isToolSelected(
193+
tool,
194+
Optional.of(toolFilter),
195+
Optional.ofNullable(readonlyContext)))
196+
.collect(toImmutableList())))
197+
.retryWhen(
198+
RetrySpec.from(
199+
retrySignal ->
200+
retrySignal.flatMap(
201+
signal -> {
202+
Throwable err = signal.failure();
203+
if (err instanceof IllegalArgumentException) {
204+
// This could happen if parameters for tool loading are somehow
205+
// invalid.
206+
// This is likely a fatal error and should not be retried.
207+
logger.error(
208+
"Invalid argument encountered during tool loading.", err);
209+
return Mono.error(
210+
new McpToolsetException.McpToolLoadingException(
211+
"Invalid argument encountered during tool loading.", err));
212+
}
213+
long totalRetries = signal.totalRetries();
214+
logger.error(
215+
"Unexpected error during tool loading, retry attempt "
216+
+ (totalRetries + 1),
217+
err);
218+
if (totalRetries < MAX_RETRIES) {
219+
logger.info(
220+
"Reinitializing MCP session before next retry for unexpected error.");
221+
this.close();
222+
return Mono.just(err)
223+
.delayElement(Duration.ofMillis(RETRY_DELAY_MILLIS));
224+
} else {
225+
logger.error(
226+
"Failed to load tools after multiple retries due to unexpected error.",
227+
err);
228+
return Mono.error(
229+
new McpToolsetException.McpToolLoadingException(
230+
"Failed to load tools after multiple retries due to unexpected error.",
231+
err));
232+
}
233+
})));
234+
return Maybe.defer(() -> Maybe.fromCompletionStage(mono.toFuture()))
235+
.defaultIfEmpty(ImmutableList.of())
236+
.onErrorResumeNext(
237+
err -> {
238+
if (err instanceof McpToolsetException.McpToolLoadingException) {
239+
return Single.error(err);
240+
} else {
241+
return Single.error(
242+
new McpToolsetException.McpInitializationException(
243+
"Failed to reinitialize session during tool loading retry (unexpected"
244+
+ " error).",
245+
err));
246+
}
247+
})
248+
.flattenAsFlowable(it -> it);
249+
}
250+
251+
@Override
252+
public void close() {
253+
McpAsyncClient mcpSession = this.mcpSession.getAndSet(null);
254+
if (mcpSession != null) {
255+
try {
256+
mcpSession.close();
257+
logger.debug("MCP session closed successfully.");
258+
} catch (RuntimeException e) {
259+
logger.error("Failed to close MCP session", e);
260+
// We don't throw an exception here, as closing is a cleanup operation and
261+
// failing to close shouldn't prevent the program from continuing (or exiting).
262+
// However, we log the error for debugging purposes.
263+
}
264+
}
265+
}
266+
}

core/src/main/java/com/google/adk/tools/mcp/McpToolset.java

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
175175
// This could happen if parameters for tool loading are somehow invalid.
176176
// This is likely a fatal error and should not be retried.
177177
logger.error("Invalid argument encountered during tool loading.", e);
178-
throw new McpToolLoadingException(
178+
throw new McpToolsetException.McpToolLoadingException(
179179
"Invalid argument encountered during tool loading.", e);
180180
} catch (RuntimeException e) { // Catch any other unexpected runtime exceptions
181181
logger.error("Unexpected error during tool loading, retry attempt " + (i + 1), e);
@@ -194,21 +194,21 @@ public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
194194
Thread.currentThread().interrupt();
195195
logger.error(
196196
"Interrupted during retry delay for loadTools (unexpected error).", ie);
197-
throw new McpToolLoadingException(
197+
throw new McpToolsetException.McpToolLoadingException(
198198
"Interrupted during retry delay (unexpected error)", ie);
199199
} catch (RuntimeException reinitE) {
200200
logger.error(
201201
"Failed to reinitialize session during retry (unexpected error).",
202202
reinitE);
203-
throw new McpInitializationException(
203+
throw new McpToolsetException.McpInitializationException(
204204
"Failed to reinitialize session during tool loading retry (unexpected"
205205
+ " error).",
206206
reinitE);
207207
}
208208
} else {
209209
logger.error(
210210
"Failed to load tools after multiple retries due to unexpected error.", e);
211-
throw new McpToolLoadingException(
211+
throw new McpToolsetException.McpToolLoadingException(
212212
"Failed to load tools after multiple retries due to unexpected error.", e);
213213
}
214214
}
@@ -236,25 +236,4 @@ public void close() {
236236
}
237237
}
238238
}
239-
240-
/** Base exception for all errors originating from {@code McpToolset}. */
241-
public static class McpToolsetException extends RuntimeException {
242-
public McpToolsetException(String message, Throwable cause) {
243-
super(message, cause);
244-
}
245-
}
246-
247-
/** Exception thrown when there's an error during MCP session initialization. */
248-
public static class McpInitializationException extends McpToolsetException {
249-
public McpInitializationException(String message, Throwable cause) {
250-
super(message, cause);
251-
}
252-
}
253-
254-
/** Exception thrown when there's an error during loading tools from the MCP server. */
255-
public static class McpToolLoadingException extends McpToolsetException {
256-
public McpToolLoadingException(String message, Throwable cause) {
257-
super(message, cause);
258-
}
259-
}
260239
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.google.adk.tools.mcp;
2+
3+
/** Base exception for all errors originating from {@code McpToolset}. */
4+
public class McpToolsetException extends RuntimeException {
5+
6+
public McpToolsetException(String message, Throwable cause) {
7+
super(message, cause);
8+
}
9+
10+
/** Exception thrown when there's an error during MCP session initialization. */
11+
public static class McpInitializationException extends McpToolsetException {
12+
public McpInitializationException(String message, Throwable cause) {
13+
super(message, cause);
14+
}
15+
}
16+
17+
/** Exception thrown when there's an error during loading tools from the MCP server. */
18+
public static class McpToolLoadingException extends McpToolsetException {
19+
public McpToolLoadingException(String message, Throwable cause) {
20+
super(message, cause);
21+
}
22+
}
23+
}

0 commit comments

Comments
 (0)