Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions core/src/main/java/com/google/adk/tools/NamedToolPredicate.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.google.adk.tools;

import com.google.adk.agents.ReadonlyContext;
import java.util.List;
import java.util.Optional;

public class NamedToolPredicate implements ToolPredicate {

private final List<String> toolNames;

public NamedToolPredicate(List<String> toolNames) {
this.toolNames = List.copyOf(toolNames);
}

public NamedToolPredicate(String... toolNames) {
this.toolNames = List.of(toolNames);
}

@Override
public boolean test(BaseTool tool, Optional<ReadonlyContext> readonlyContext) {
return toolNames.contains(tool.name());
}
}
274 changes: 274 additions & 0 deletions core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.adk.tools.mcp;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.JsonBaseModel;
import com.google.adk.agents.ReadonlyContext;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.BaseToolset;
import com.google.adk.tools.NamedToolPredicate;
import com.google.adk.tools.ToolPredicate;
import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.client.transport.ServerParameters;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
import reactor.util.retry.RetrySpec;

/**
* Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
*
* <p>Attributes:
*
* <ul>
* <li>{@code connectionParams}: The connection parameters to the MCP server. Can be either {@code
* ServerParameters} or {@code SseServerParameters}.
* <li>{@code session}: The MCP session being initialized with the connection.
* </ul>
*/
public class McpAsyncToolset implements BaseToolset {

private static final Logger logger = LoggerFactory.getLogger(McpAsyncToolset.class);

private static final int MAX_RETRIES = 3;
private static final long RETRY_DELAY_MILLIS = 100;

private final McpSessionManager mcpSessionManager;
private final ObjectMapper objectMapper;
private final ToolPredicate toolFilter;
private final AtomicReference<Mono<List<McpAsyncTool>>> mcpTools = new AtomicReference<>();

/** Builder for McpAsyncToolset */
public static class Builder {
private Object connectionParams = null;
private ObjectMapper objectMapper = null;
private ToolPredicate toolFilter = null;

public Builder connectionParams(ServerParameters connectionParams) {
this.connectionParams = connectionParams;
return this;
}

public Builder connectionParams(SseServerParameters connectionParams) {
this.connectionParams = connectionParams;
return this;
}

public Builder objectMapper(ObjectMapper objectMapper) {
this.objectMapper = objectMapper;
return this;
}

public Builder toolFilter(ToolPredicate toolFilter) {
this.toolFilter = toolFilter;
return this;
}

public Builder toolFilter(List<String> toolNames) {
this.toolFilter = new NamedToolPredicate(toolNames);
return this;
}

public McpAsyncToolset build() {
if (objectMapper == null) {
objectMapper = JsonBaseModel.getMapper();
}
if (toolFilter == null) {
toolFilter = (tool, context) -> true;
}
if (connectionParams instanceof ServerParameters setSelectedParams) {
return new McpAsyncToolset(setSelectedParams, objectMapper, toolFilter);
} else if (connectionParams instanceof SseServerParameters sseServerParameters) {
return new McpAsyncToolset(sseServerParameters, objectMapper, toolFilter);
} else {
throw new IllegalArgumentException(
"connectionParams must be either ServerParameters or SseServerParameters");
}
}
}

/**
* Initializes the McpAsyncToolset with SSE server parameters.
*
* @param connectionParams The SSE connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolFilter null or an implement for {@link ToolPredicate}, {@link
* com.google.adk.tools.NamedToolPredicate}
*/
public McpAsyncToolset(
SseServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolFilter) {
Objects.requireNonNull(connectionParams);
Objects.requireNonNull(objectMapper);
this.objectMapper = objectMapper;
this.mcpSessionManager = new McpSessionManager(connectionParams);
this.toolFilter = toolFilter;
}

/**
* Initializes the McpAsyncToolset with local server parameters.
*
* @param connectionParams The local server connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolFilter null or an implement for {@link ToolPredicate}, {@link
* com.google.adk.tools.NamedToolPredicate}
*/
public McpAsyncToolset(
ServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolFilter) {
Objects.requireNonNull(connectionParams);
Objects.requireNonNull(objectMapper);
this.objectMapper = objectMapper;
this.mcpSessionManager = new McpSessionManager(connectionParams);
this.toolFilter = toolFilter;
}

@Override
public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
return Maybe.defer(() -> Maybe.fromCompletionStage(this.initAndGetTools().toFuture()))
.defaultIfEmpty(List.of())
.map(
tools ->
tools.stream()
.filter(
tool ->
isToolSelected(
tool,
Optional.of(toolFilter),
Optional.ofNullable(readonlyContext)))
.toList())
.onErrorResumeNext(
err -> {
if (err instanceof McpToolsetException) {
return Single.error(err);
} else {
return Single.error(
new McpToolsetException.McpInitializationException(
"Failed to reinitialize session during tool loading retry (unexpected"
+ " error).",
err));
}
})
.flattenAsFlowable(it -> it);
}

private Mono<List<McpAsyncTool>> initAndGetTools() {
return this.mcpTools.accumulateAndGet(
null,
(prev, _ignore) -> {
if (prev == null) {
// lazy init and cache tools
return this.initTools().cache();
}
return prev;
});
}

private Mono<List<McpAsyncTool>> initTools() {
return Mono.defer(
() -> {
McpAsyncClient mcpSession = this.mcpSessionManager.createAsyncSession();
return mcpSession
.initialize()
.doOnSuccess(
initResult -> logger.debug("Initialize Client Result: {}", initResult))
.thenReturn(mcpSession);
})
.flatMap(
mcpSession ->
mcpSession
.listTools()
.map(
toolsResponse ->
toolsResponse.tools().stream()
.map(
tool ->
new McpAsyncTool(
tool,
mcpSession, // move mcpSession to McpAsyncTool
this.mcpSessionManager,
this.objectMapper))
.toList()))
.retryWhen(
RetrySpec.from(
retrySignal ->
retrySignal.flatMap(
signal -> {
Throwable err = signal.failure();
if (err instanceof IllegalArgumentException) {
// This could happen if parameters for tool loading are somehow
// invalid.
// This is likely a fatal error and should not be retried.
logger.error("Invalid argument encountered during tool loading.", err);
return Mono.error(
new McpToolsetException.McpToolLoadingException(
"Invalid argument encountered during tool loading.", err));
}
long totalRetries = signal.totalRetries();
logger.error(
"Unexpected error during tool loading, retry attempt "
+ (totalRetries + 1),
err);
if (totalRetries < MAX_RETRIES) {
logger.info(
"Reinitializing MCP session before next retry for unexpected error.");
return Mono.just(err)
.delayElement(Duration.ofMillis(RETRY_DELAY_MILLIS));
} else {
logger.error(
"Failed to load tools after multiple retries due to unexpected error.",
err);
return Mono.error(
new McpToolsetException.McpToolLoadingException(
"Failed to load tools after multiple retries due to unexpected error.",
err));
}
})));
}

@Override
public void close() {
Mono<List<McpAsyncTool>> tools = this.mcpTools.getAndSet(null);
if (tools != null) {
tools
.flatMapIterable(it -> it)
.flatMap(
it ->
it.mcpSession
.closeGracefully()
.onErrorResume(
e -> {
logger.error("Failed to close MCP session", e);
// We don't throw an exception here, as closing is a cleanup operation
// and
// failing to close shouldn't prevent the program from continuing (or
// exiting).
// However, we log the error for debugging purposes.
return Mono.empty();
}))
.doOnComplete(() -> logger.debug("MCP session closed successfully."))
.subscribe();
}
}
}
29 changes: 4 additions & 25 deletions core/src/main/java/com/google/adk/tools/mcp/McpToolset.java
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
// This could happen if parameters for tool loading are somehow invalid.
// This is likely a fatal error and should not be retried.
logger.error("Invalid argument encountered during tool loading.", e);
throw new McpToolLoadingException(
throw new McpToolsetException.McpToolLoadingException(
"Invalid argument encountered during tool loading.", e);
} catch (RuntimeException e) { // Catch any other unexpected runtime exceptions
logger.error("Unexpected error during tool loading, retry attempt " + (i + 1), e);
Expand All @@ -194,21 +194,21 @@ public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
Thread.currentThread().interrupt();
logger.error(
"Interrupted during retry delay for loadTools (unexpected error).", ie);
throw new McpToolLoadingException(
throw new McpToolsetException.McpToolLoadingException(
"Interrupted during retry delay (unexpected error)", ie);
} catch (RuntimeException reinitE) {
logger.error(
"Failed to reinitialize session during retry (unexpected error).",
reinitE);
throw new McpInitializationException(
throw new McpToolsetException.McpInitializationException(
"Failed to reinitialize session during tool loading retry (unexpected"
+ " error).",
reinitE);
}
} else {
logger.error(
"Failed to load tools after multiple retries due to unexpected error.", e);
throw new McpToolLoadingException(
throw new McpToolsetException.McpToolLoadingException(
"Failed to load tools after multiple retries due to unexpected error.", e);
}
}
Expand Down Expand Up @@ -236,25 +236,4 @@ public void close() {
}
}
}

/** Base exception for all errors originating from {@code McpToolset}. */
public static class McpToolsetException extends RuntimeException {
public McpToolsetException(String message, Throwable cause) {
super(message, cause);
}
}

/** Exception thrown when there's an error during MCP session initialization. */
public static class McpInitializationException extends McpToolsetException {
public McpInitializationException(String message, Throwable cause) {
super(message, cause);
}
}

/** Exception thrown when there's an error during loading tools from the MCP server. */
public static class McpToolLoadingException extends McpToolsetException {
public McpToolLoadingException(String message, Throwable cause) {
super(message, cause);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.google.adk.tools.mcp;

/** Base exception for all errors originating from {@code McpToolset}. */
public class McpToolsetException extends RuntimeException {

public McpToolsetException(String message, Throwable cause) {
super(message, cause);
}

/** Exception thrown when there's an error during MCP session initialization. */
public static class McpInitializationException extends McpToolsetException {
public McpInitializationException(String message, Throwable cause) {
super(message, cause);
}
}

/** Exception thrown when there's an error during loading tools from the MCP server. */
public static class McpToolLoadingException extends McpToolsetException {
public McpToolLoadingException(String message, Throwable cause) {
super(message, cause);
}
}
}