diff --git a/core/src/main/java/com/google/adk/tools/NamedToolPredicate.java b/core/src/main/java/com/google/adk/tools/NamedToolPredicate.java new file mode 100644 index 00000000..f943de36 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/NamedToolPredicate.java @@ -0,0 +1,23 @@ +package com.google.adk.tools; + +import com.google.adk.agents.ReadonlyContext; +import java.util.List; +import java.util.Optional; + +public class NamedToolPredicate implements ToolPredicate { + + private final List toolNames; + + public NamedToolPredicate(List toolNames) { + this.toolNames = List.copyOf(toolNames); + } + + public NamedToolPredicate(String... toolNames) { + this.toolNames = List.of(toolNames); + } + + @Override + public boolean test(BaseTool tool, Optional readonlyContext) { + return toolNames.contains(tool.name()); + } +} diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java b/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java new file mode 100644 index 00000000..d71525cf --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java @@ -0,0 +1,274 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.mcp; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; +import com.google.adk.agents.ReadonlyContext; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.NamedToolPredicate; +import com.google.adk.tools.ToolPredicate; +import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.client.transport.ServerParameters; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.time.Duration; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.util.retry.RetrySpec; + +/** + * Connects to a MCP Server, and retrieves MCP Tools into ADK Tools. + * + *

Attributes: + * + *

    + *
  • {@code connectionParams}: The connection parameters to the MCP server. Can be either {@code + * ServerParameters} or {@code SseServerParameters}. + *
  • {@code session}: The MCP session being initialized with the connection. + *
+ */ +public class McpAsyncToolset implements BaseToolset { + + private static final Logger logger = LoggerFactory.getLogger(McpAsyncToolset.class); + + private static final int MAX_RETRIES = 3; + private static final long RETRY_DELAY_MILLIS = 100; + + private final McpSessionManager mcpSessionManager; + private final ObjectMapper objectMapper; + private final ToolPredicate toolFilter; + private final AtomicReference>> mcpTools = new AtomicReference<>(); + + /** Builder for McpAsyncToolset */ + public static class Builder { + private Object connectionParams = null; + private ObjectMapper objectMapper = null; + private ToolPredicate toolFilter = null; + + public Builder connectionParams(ServerParameters connectionParams) { + this.connectionParams = connectionParams; + return this; + } + + public Builder connectionParams(SseServerParameters connectionParams) { + this.connectionParams = connectionParams; + return this; + } + + public Builder objectMapper(ObjectMapper objectMapper) { + this.objectMapper = objectMapper; + return this; + } + + public Builder toolFilter(ToolPredicate toolFilter) { + this.toolFilter = toolFilter; + return this; + } + + public Builder toolFilter(List toolNames) { + this.toolFilter = new NamedToolPredicate(toolNames); + return this; + } + + public McpAsyncToolset build() { + if (objectMapper == null) { + objectMapper = JsonBaseModel.getMapper(); + } + if (toolFilter == null) { + toolFilter = (tool, context) -> true; + } + if (connectionParams instanceof ServerParameters setSelectedParams) { + return new McpAsyncToolset(setSelectedParams, objectMapper, toolFilter); + } else if (connectionParams instanceof SseServerParameters sseServerParameters) { + return new McpAsyncToolset(sseServerParameters, objectMapper, toolFilter); + } else { + throw new IllegalArgumentException( + "connectionParams must be either ServerParameters or SseServerParameters"); + } + } + } + + /** + * Initializes the McpAsyncToolset with SSE server parameters. + * + * @param connectionParams The SSE connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolFilter null or an implement for {@link ToolPredicate}, {@link + * com.google.adk.tools.NamedToolPredicate} + */ + public McpAsyncToolset( + SseServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolFilter) { + Objects.requireNonNull(connectionParams); + Objects.requireNonNull(objectMapper); + this.objectMapper = objectMapper; + this.mcpSessionManager = new McpSessionManager(connectionParams); + this.toolFilter = toolFilter; + } + + /** + * Initializes the McpAsyncToolset with local server parameters. + * + * @param connectionParams The local server connection parameters to the MCP server. + * @param objectMapper An ObjectMapper instance for parsing schemas. + * @param toolFilter null or an implement for {@link ToolPredicate}, {@link + * com.google.adk.tools.NamedToolPredicate} + */ + public McpAsyncToolset( + ServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolFilter) { + Objects.requireNonNull(connectionParams); + Objects.requireNonNull(objectMapper); + this.objectMapper = objectMapper; + this.mcpSessionManager = new McpSessionManager(connectionParams); + this.toolFilter = toolFilter; + } + + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return Maybe.defer(() -> Maybe.fromCompletionStage(this.initAndGetTools().toFuture())) + .defaultIfEmpty(List.of()) + .map( + tools -> + tools.stream() + .filter( + tool -> + isToolSelected( + tool, + Optional.of(toolFilter), + Optional.ofNullable(readonlyContext))) + .toList()) + .onErrorResumeNext( + err -> { + if (err instanceof McpToolsetException) { + return Single.error(err); + } else { + return Single.error( + new McpToolsetException.McpInitializationException( + "Failed to reinitialize session during tool loading retry (unexpected" + + " error).", + err)); + } + }) + .flattenAsFlowable(it -> it); + } + + private Mono> initAndGetTools() { + return this.mcpTools.accumulateAndGet( + null, + (prev, _ignore) -> { + if (prev == null) { + // lazy init and cache tools + return this.initTools().cache(); + } + return prev; + }); + } + + private Mono> initTools() { + return Mono.defer( + () -> { + McpAsyncClient mcpSession = this.mcpSessionManager.createAsyncSession(); + return mcpSession + .initialize() + .doOnSuccess( + initResult -> logger.debug("Initialize Client Result: {}", initResult)) + .thenReturn(mcpSession); + }) + .flatMap( + mcpSession -> + mcpSession + .listTools() + .map( + toolsResponse -> + toolsResponse.tools().stream() + .map( + tool -> + new McpAsyncTool( + tool, + mcpSession, // move mcpSession to McpAsyncTool + this.mcpSessionManager, + this.objectMapper)) + .toList())) + .retryWhen( + RetrySpec.from( + retrySignal -> + retrySignal.flatMap( + signal -> { + Throwable err = signal.failure(); + if (err instanceof IllegalArgumentException) { + // This could happen if parameters for tool loading are somehow + // invalid. + // This is likely a fatal error and should not be retried. + logger.error("Invalid argument encountered during tool loading.", err); + return Mono.error( + new McpToolsetException.McpToolLoadingException( + "Invalid argument encountered during tool loading.", err)); + } + long totalRetries = signal.totalRetries(); + logger.error( + "Unexpected error during tool loading, retry attempt " + + (totalRetries + 1), + err); + if (totalRetries < MAX_RETRIES) { + logger.info( + "Reinitializing MCP session before next retry for unexpected error."); + return Mono.just(err) + .delayElement(Duration.ofMillis(RETRY_DELAY_MILLIS)); + } else { + logger.error( + "Failed to load tools after multiple retries due to unexpected error.", + err); + return Mono.error( + new McpToolsetException.McpToolLoadingException( + "Failed to load tools after multiple retries due to unexpected error.", + err)); + } + }))); + } + + @Override + public void close() { + Mono> tools = this.mcpTools.getAndSet(null); + if (tools != null) { + tools + .flatMapIterable(it -> it) + .flatMap( + it -> + it.mcpSession + .closeGracefully() + .onErrorResume( + e -> { + logger.error("Failed to close MCP session", e); + // We don't throw an exception here, as closing is a cleanup operation + // and + // failing to close shouldn't prevent the program from continuing (or + // exiting). + // However, we log the error for debugging purposes. + return Mono.empty(); + })) + .doOnComplete(() -> logger.debug("MCP session closed successfully.")) + .subscribe(); + } + } +} diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java index 8e502c3c..a91539e5 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java @@ -175,7 +175,7 @@ public Flowable getTools(ReadonlyContext readonlyContext) { // This could happen if parameters for tool loading are somehow invalid. // This is likely a fatal error and should not be retried. logger.error("Invalid argument encountered during tool loading.", e); - throw new McpToolLoadingException( + throw new McpToolsetException.McpToolLoadingException( "Invalid argument encountered during tool loading.", e); } catch (RuntimeException e) { // Catch any other unexpected runtime exceptions logger.error("Unexpected error during tool loading, retry attempt " + (i + 1), e); @@ -194,13 +194,13 @@ public Flowable getTools(ReadonlyContext readonlyContext) { Thread.currentThread().interrupt(); logger.error( "Interrupted during retry delay for loadTools (unexpected error).", ie); - throw new McpToolLoadingException( + throw new McpToolsetException.McpToolLoadingException( "Interrupted during retry delay (unexpected error)", ie); } catch (RuntimeException reinitE) { logger.error( "Failed to reinitialize session during retry (unexpected error).", reinitE); - throw new McpInitializationException( + throw new McpToolsetException.McpInitializationException( "Failed to reinitialize session during tool loading retry (unexpected" + " error).", reinitE); @@ -208,7 +208,7 @@ public Flowable getTools(ReadonlyContext readonlyContext) { } else { logger.error( "Failed to load tools after multiple retries due to unexpected error.", e); - throw new McpToolLoadingException( + throw new McpToolsetException.McpToolLoadingException( "Failed to load tools after multiple retries due to unexpected error.", e); } } @@ -236,25 +236,4 @@ public void close() { } } } - - /** Base exception for all errors originating from {@code McpToolset}. */ - public static class McpToolsetException extends RuntimeException { - public McpToolsetException(String message, Throwable cause) { - super(message, cause); - } - } - - /** Exception thrown when there's an error during MCP session initialization. */ - public static class McpInitializationException extends McpToolsetException { - public McpInitializationException(String message, Throwable cause) { - super(message, cause); - } - } - - /** Exception thrown when there's an error during loading tools from the MCP server. */ - public static class McpToolLoadingException extends McpToolsetException { - public McpToolLoadingException(String message, Throwable cause) { - super(message, cause); - } - } } diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpToolsetException.java b/core/src/main/java/com/google/adk/tools/mcp/McpToolsetException.java new file mode 100644 index 00000000..062fe5e2 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/mcp/McpToolsetException.java @@ -0,0 +1,23 @@ +package com.google.adk.tools.mcp; + +/** Base exception for all errors originating from {@code McpToolset}. */ +public class McpToolsetException extends RuntimeException { + + public McpToolsetException(String message, Throwable cause) { + super(message, cause); + } + + /** Exception thrown when there's an error during MCP session initialization. */ + public static class McpInitializationException extends McpToolsetException { + public McpInitializationException(String message, Throwable cause) { + super(message, cause); + } + } + + /** Exception thrown when there's an error during loading tools from the MCP server. */ + public static class McpToolLoadingException extends McpToolsetException { + public McpToolLoadingException(String message, Throwable cause) { + super(message, cause); + } + } +}