diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 8191ec8e6..adc973529 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -22,7 +22,7 @@ import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.Callbacks.BeforeAgentCallback; import com.google.adk.events.Event; -import com.google.adk.plugins.PluginManager; +import com.google.adk.plugins.Plugin; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; @@ -255,9 +255,9 @@ public Flowable runAsync(InvocationContext parentContext) { * @return callback functions. */ private ImmutableList>> beforeCallbacksToFunctions( - PluginManager pluginManager, List callbacks) { + Plugin pluginManager, List callbacks) { return Stream.concat( - Stream.of(ctx -> pluginManager.runBeforeAgentCallback(this, ctx)), + Stream.of(ctx -> pluginManager.beforeAgentCallback(this, ctx)), callbacks.stream() .map(callback -> (Function>) callback::call)) .collect(toImmutableList()); @@ -270,9 +270,9 @@ private ImmutableList>> beforeCallbacks * @return callback functions. */ private ImmutableList>> afterCallbacksToFunctions( - PluginManager pluginManager, List callbacks) { + Plugin pluginManager, List callbacks) { return Stream.concat( - Stream.of(ctx -> pluginManager.runAfterAgentCallback(this, ctx)), + Stream.of(ctx -> pluginManager.afterAgentCallback(this, ctx)), callbacks.stream() .map(callback -> (Function>) callback::call)) .collect(toImmutableList()); diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 9491353fd..9396403bb 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -21,6 +21,7 @@ import com.google.adk.flows.llmflows.ResumabilityConfig; import com.google.adk.memory.BaseMemoryService; import com.google.adk.models.LlmCallsLimitExceededException; +import com.google.adk.plugins.Plugin; import com.google.adk.plugins.PluginManager; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; @@ -42,7 +43,7 @@ public class InvocationContext { private final BaseSessionService sessionService; private final BaseArtifactService artifactService; private final BaseMemoryService memoryService; - private final PluginManager pluginManager; + private final Plugin pluginManager; private final Optional liveRequestQueue; private final Map activeStreamingTools = new ConcurrentHashMap<>(); private final String invocationId; @@ -80,7 +81,7 @@ public InvocationContext( BaseSessionService sessionService, BaseArtifactService artifactService, BaseMemoryService memoryService, - PluginManager pluginManager, + Plugin pluginManager, Optional liveRequestQueue, Optional branch, String invocationId, @@ -235,7 +236,7 @@ public BaseMemoryService memoryService() { } /** Returns the plugin manager for accessing tools and plugins. */ - public PluginManager pluginManager() { + public Plugin pluginManager() { return pluginManager; } @@ -376,7 +377,7 @@ public static class Builder { private BaseSessionService sessionService; private BaseArtifactService artifactService; private BaseMemoryService memoryService; - private PluginManager pluginManager = new PluginManager(); + private Plugin pluginManager = new PluginManager(); private Optional liveRequestQueue = Optional.empty(); private Optional branch = Optional.empty(); private String invocationId = newInvocationContextId(); @@ -430,7 +431,7 @@ public Builder memoryService(BaseMemoryService memoryService) { * @return this builder instance for chaining. */ @CanIgnoreReturnValue - public Builder pluginManager(PluginManager pluginManager) { + public Builder pluginManager(Plugin pluginManager) { this.pluginManager = pluginManager; return this; } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index 07c626af1..5e6331b79 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -200,7 +200,7 @@ private Flowable callLlm( exception -> context .pluginManager() - .runOnModelErrorCallback( + .onModelErrorCallback( new CallbackContext( context, eventForCallbackUsage.actions()), llmRequestBuilder, @@ -244,7 +244,7 @@ private Single> handleBeforeModelCallback( CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); Maybe pluginResult = - context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder); + context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); LlmAgent agent = (LlmAgent) context.agent(); @@ -280,7 +280,7 @@ private Single handleAfterModelCallback( CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions()); Maybe pluginResult = - context.pluginManager().runAfterModelCallback(callbackContext, llmResponse); + context.pluginManager().afterModelCallback(callbackContext, llmResponse); LlmAgent agent = (LlmAgent) context.agent(); Optional> callbacksOpt = agent.afterModelCallback(); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 86721e985..950d4c5eb 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -390,7 +390,7 @@ private static Maybe postProcessFunctionResult( t -> invocationContext .pluginManager() - .runOnToolErrorCallback(tool, functionArgs, toolContext, t) + .onToolErrorCallback(tool, functionArgs, toolContext, t) .map(isLive ? Optional::ofNullable : Optional::of) .switchIfEmpty(Single.error(t))) .flatMapMaybe( @@ -462,7 +462,7 @@ private static Maybe> maybeInvokeBeforeToolCall( LlmAgent agent = (LlmAgent) invocationContext.agent(); Maybe> pluginResult = - invocationContext.pluginManager().runBeforeToolCallback(tool, functionArgs, toolContext); + invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); Optional> callbacksOpt = agent.beforeToolCallback(); if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) { @@ -496,7 +496,7 @@ private static Maybe> maybeInvokeAfterToolCall( Maybe> pluginResult = invocationContext .pluginManager() - .runAfterToolCallback(tool, functionArgs, toolContext, functionResult); + .afterToolCallback(tool, functionArgs, toolContext, functionResult); Optional> callbacksOpt = agent.afterToolCallback(); if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) { diff --git a/core/src/main/java/com/google/adk/plugins/BasePlugin.java b/core/src/main/java/com/google/adk/plugins/BasePlugin.java index e9ee783b3..4ec936ad0 100644 --- a/core/src/main/java/com/google/adk/plugins/BasePlugin.java +++ b/core/src/main/java/com/google/adk/plugins/BasePlugin.java @@ -15,19 +15,6 @@ */ package com.google.adk.plugins; -import com.google.adk.agents.BaseAgent; -import com.google.adk.agents.CallbackContext; -import com.google.adk.agents.InvocationContext; -import com.google.adk.events.Event; -import com.google.adk.models.LlmRequest; -import com.google.adk.models.LlmResponse; -import com.google.adk.tools.BaseTool; -import com.google.adk.tools.ToolContext; -import com.google.genai.types.Content; -import io.reactivex.rxjava3.core.Completable; -import io.reactivex.rxjava3.core.Maybe; -import java.util.Map; - /** * Base class for creating plugins. * @@ -40,168 +27,15 @@ *

A plugin can implement one or more methods of callbacks, but should not implement the same * method of callback for multiple times. */ -public abstract class BasePlugin { +public abstract class BasePlugin implements Plugin { protected final String name; public BasePlugin(String name) { this.name = name; } + @Override public String getName() { return name; } - - /** - * Callback executed when a user message is received before an invocation starts. - * - * @param invocationContext The context for the entire invocation. - * @param userMessage The message content input by user. - * @return An optional Content to replace the user message. Returning Empty to proceed normally. - */ - public Maybe onUserMessageCallback( - InvocationContext invocationContext, Content userMessage) { - return Maybe.empty(); - } - - /** - * Callback executed before the ADK runner runs. - * - * @param invocationContext The context for the entire invocation. - * @return An optional Content to halt execution. Returning Empty to proceed normally. - */ - public Maybe beforeRunCallback(InvocationContext invocationContext) { - return Maybe.empty(); - } - - /** - * Callback executed after an event is yielded from runner. - * - * @param invocationContext The context for the entire invocation. - * @param event The event raised by the runner. - * @return An optional Event to modify or replace the response. Returning Empty to proceed - * normally. - */ - public Maybe onEventCallback(InvocationContext invocationContext, Event event) { - return Maybe.empty(); - } - - /** - * Callback executed after an ADK runner run has completed. - * - * @param invocationContext The context for the entire invocation. - */ - public Completable afterRunCallback(InvocationContext invocationContext) { - return Completable.complete(); - } - - /** - * Callback executed before an agent's primary logic is invoked. - * - * @param agent The agent that is about to run. - * @param callbackContext The context for the agent invocation. - * @return An optional Content object to bypass the agent's execution. Returning Empty to proceed - * normally. - */ - public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { - return Maybe.empty(); - } - - /** - * Callback executed after an agent's primary logic has completed. - * - * @param agent The agent that has just run. - * @param callbackContext The context for the agent invocation. - * @return An optional Content object to replace the agent's original result. Returning Empty to - * use the original result. - */ - public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { - return Maybe.empty(); - } - - /** - * Callback executed before a request is sent to the model. - * - * @param callbackContext The context for the current agent call. - * @param llmRequest The mutable request builder, allowing modification of the request before it - * is sent to the model. - * @return An optional LlmResponse to trigger an early exit. Returning Empty to proceed normally. - */ - public Maybe beforeModelCallback( - CallbackContext callbackContext, LlmRequest.Builder llmRequest) { - return Maybe.empty(); - } - - /** - * Callback executed after a response is received from the model. - * - * @param callbackContext The context for the current agent call. - * @param llmResponse The response object received from the model. - * @return An optional LlmResponse to modify or replace the response. Returning Empty to use the - * original response. - */ - public Maybe afterModelCallback( - CallbackContext callbackContext, LlmResponse llmResponse) { - return Maybe.empty(); - } - - /** - * Callback executed when a model call encounters an error. - * - * @param callbackContext The context for the current agent call. - * @param llmRequest The mutable request builder for the request that failed. - * @param error The exception that was raised. - * @return An optional LlmResponse to use instead of propagating the error. Returning Empty to - * allow the original error to be raised. - */ - public Maybe onModelErrorCallback( - CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { - return Maybe.empty(); - } - - /** - * Callback executed before a tool is called. - * - * @param tool The tool instance that is about to be executed. - * @param toolArgs The dictionary of arguments to be used for invoking the tool. - * @param toolContext The context specific to the tool execution. - * @return An optional Map to stop the tool execution and return this response immediately. - * Returning Empty to proceed normally. - */ - public Maybe> beforeToolCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext) { - return Maybe.empty(); - } - - /** - * Callback executed after a tool has been called. - * - * @param tool The tool instance that has just been executed. - * @param toolArgs The original arguments that were passed to the tool. - * @param toolContext The context specific to the tool execution. - * @param result The dictionary returned by the tool invocation. - * @return An optional Map to replace the original result from the tool. Returning Empty to use - * the original result. - */ - public Maybe> afterToolCallback( - BaseTool tool, - Map toolArgs, - ToolContext toolContext, - Map result) { - return Maybe.empty(); - } - - /** - * Callback executed when a tool call encounters an error. - * - * @param tool The tool instance that encountered an error. - * @param toolArgs The arguments that were passed to the tool. - * @param toolContext The context specific to the tool execution. - * @param error The exception that was raised during tool execution. - * @return An optional Map to be used as the tool response instead of propagating the error. - * Returning Empty to allow the original error to be raised. - */ - public Maybe> onToolErrorCallback( - BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { - return Maybe.empty(); - } } diff --git a/core/src/main/java/com/google/adk/plugins/Plugin.java b/core/src/main/java/com/google/adk/plugins/Plugin.java new file mode 100644 index 000000000..97a9038d4 --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/Plugin.java @@ -0,0 +1,200 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.plugins; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.CallbackContext; +import com.google.adk.agents.InvocationContext; +import com.google.adk.events.Event; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.genai.types.Content; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import java.util.Map; + +/** + * Interface for creating plugins. + * + *

Plugins provide a structured way to intercept and modify agent, tool, and LLM behaviors at + * critical execution points in a callback manner. While agent callbacks apply to a particular + * agent, plugins applies globally to all agents added in the runner. Plugins are best used for + * adding custom behaviors like logging, monitoring, caching, or modifying requests and responses at + * key stages. + * + *

A plugin can implement one or more methods of callbacks, but should not implement the same + * method of callback for multiple times. + */ +public interface Plugin { + + String getName(); + + /** + * Callback executed when a user message is received before an invocation starts. + * + * @param invocationContext The context for the entire invocation. + * @param userMessage The message content input by user. + * @return An optional Content to replace the user message. Returning Empty to proceed normally. + */ + default Maybe onUserMessageCallback( + InvocationContext invocationContext, Content userMessage) { + return Maybe.empty(); + } + + /** + * Callback executed before the ADK runner runs. + * + * @param invocationContext The context for the entire invocation. + * @return An optional Content to halt execution. Returning Empty to proceed normally. + */ + default Maybe beforeRunCallback(InvocationContext invocationContext) { + return Maybe.empty(); + } + + /** + * Callback executed after an event is yielded from runner. + * + * @param invocationContext The context for the entire invocation. + * @param event The event raised by the runner. + * @return An optional Event to modify or replace the response. Returning Empty to proceed + * normally. + */ + default Maybe onEventCallback(InvocationContext invocationContext, Event event) { + return Maybe.empty(); + } + + /** + * Callback executed after an ADK runner run has completed. + * + * @param invocationContext The context for the entire invocation. + */ + default Completable afterRunCallback(InvocationContext invocationContext) { + return Completable.complete(); + } + + /** + * Callback executed before an agent's primary logic is invoked. + * + * @param agent The agent that is about to run. + * @param callbackContext The context for the agent invocation. + * @return An optional Content object to bypass the agent's execution. Returning Empty to proceed + * normally. + */ + default Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return Maybe.empty(); + } + + /** + * Callback executed after an agent's primary logic has completed. + * + * @param agent The agent that has just run. + * @param callbackContext The context for the agent invocation. + * @return An optional Content object to replace the agent's original result. Returning Empty to + * use the original result. + */ + default Maybe afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return Maybe.empty(); + } + + /** + * Callback executed before a request is sent to the model. + * + * @param callbackContext The context for the current agent call. + * @param llmRequest The mutable request builder, allowing modification of the request before it + * is sent to the model. + * @return An optional LlmResponse to trigger an early exit. Returning Empty to proceed normally. + */ + default Maybe beforeModelCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { + return Maybe.empty(); + } + + /** + * Callback executed after a response is received from the model. + * + * @param callbackContext The context for the current agent call. + * @param llmResponse The response object received from the model. + * @return An optional LlmResponse to modify or replace the response. Returning Empty to use the + * original response. + */ + default Maybe afterModelCallback( + CallbackContext callbackContext, LlmResponse llmResponse) { + return Maybe.empty(); + } + + /** + * Callback executed when a model call encounters an error. + * + * @param callbackContext The context for the current agent call. + * @param llmRequest The mutable request builder for the request that failed. + * @param error The exception that was raised. + * @return An optional LlmResponse to use instead of propagating the error. Returning Empty to + * allow the original error to be raised. + */ + default Maybe onModelErrorCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { + return Maybe.empty(); + } + + /** + * Callback executed before a tool is called. + * + * @param tool The tool instance that is about to be executed. + * @param toolArgs The dictionary of arguments to be used for invoking the tool. + * @param toolContext The context specific to the tool execution. + * @return An optional Map to stop the tool execution and return this response immediately. + * Returning Empty to proceed normally. + */ + default Maybe> beforeToolCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext) { + return Maybe.empty(); + } + + /** + * Callback executed after a tool has been called. + * + * @param tool The tool instance that has just been executed. + * @param toolArgs The original arguments that were passed to the tool. + * @param toolContext The context specific to the tool execution. + * @param result The dictionary returned by the tool invocation. + * @return An optional Map to replace the original result from the tool. Returning Empty to use + * the original result. + */ + default Maybe> afterToolCallback( + BaseTool tool, + Map toolArgs, + ToolContext toolContext, + Map result) { + return Maybe.empty(); + } + + /** + * Callback executed when a tool call encounters an error. + * + * @param tool The tool instance that encountered an error. + * @param toolArgs The arguments that were passed to the tool. + * @param toolContext The context specific to the tool execution. + * @param error The exception that was raised during tool execution. + * @return An optional Map to be used as the tool response instead of propagating the error. + * Returning Empty to allow the original error to be raised. + */ + default Maybe> onToolErrorCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { + return Maybe.empty(); + } +} diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index e95a5c78e..cb284c70a 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -41,11 +41,11 @@ *

The PluginManager is an internal class that orchestrates the invocation of plugin callbacks at * key points in the SDK's execution lifecycle. */ -public class PluginManager { +public class PluginManager implements Plugin { private static final Logger logger = LoggerFactory.getLogger(PluginManager.class); - private final List plugins; + private final List plugins; - public PluginManager(List plugins) { + public PluginManager(List plugins) { this.plugins = new ArrayList<>(); if (plugins != null) { for (var plugin : plugins) { @@ -58,13 +58,18 @@ public PluginManager() { this(null); } + @Override + public String getName() { + return "PluginManager"; + } + /** * Registers a new plugin. * * @param plugin The plugin instance to register. * @throws IllegalArgumentException If a plugin with the same name is already registered. */ - public void registerPlugin(BasePlugin plugin) { + public void registerPlugin(Plugin plugin) { if (plugins.stream().anyMatch(p -> p.getName().equals(plugin.getName()))) { throw new IllegalArgumentException( "Plugin with name '" + plugin.getName() + "' already registered."); @@ -79,7 +84,7 @@ public void registerPlugin(BasePlugin plugin) { * @param pluginName The name of the plugin to retrieve. * @return The plugin instance if found, otherwise {@link Optional#empty()}. */ - public Optional getPlugin(String pluginName) { + public Optional getPlugin(String pluginName) { return plugins.stream().filter(p -> p.getName().equals(pluginName)).findFirst(); } @@ -87,17 +92,33 @@ public Optional getPlugin(String pluginName) { public Maybe runOnUserMessageCallback( InvocationContext invocationContext, Content userMessage) { + return onUserMessageCallback(invocationContext, userMessage); + } + + @Override + public Maybe onUserMessageCallback( + InvocationContext invocationContext, Content userMessage) { return runMaybeCallbacks( plugin -> plugin.onUserMessageCallback(invocationContext, userMessage), "onUserMessageCallback"); } public Maybe runBeforeRunCallback(InvocationContext invocationContext) { + return beforeRunCallback(invocationContext); + } + + @Override + public Maybe beforeRunCallback(InvocationContext invocationContext) { return runMaybeCallbacks( plugin -> plugin.beforeRunCallback(invocationContext), "beforeRunCallback"); } public Completable runAfterRunCallback(InvocationContext invocationContext) { + return afterRunCallback(invocationContext); + } + + @Override + public Completable afterRunCallback(InvocationContext invocationContext) { return Flowable.fromIterable(plugins) .concatMapCompletable( plugin -> @@ -112,34 +133,67 @@ public Completable runAfterRunCallback(InvocationContext invocationContext) { } public Maybe runOnEventCallback(InvocationContext invocationContext, Event event) { + return onEventCallback(invocationContext, event); + } + + @Override + public Maybe onEventCallback(InvocationContext invocationContext, Event event) { return runMaybeCallbacks( plugin -> plugin.onEventCallback(invocationContext, event), "onEventCallback"); } public Maybe runBeforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return beforeAgentCallback(agent, callbackContext); + } + + @Override + public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return runMaybeCallbacks( plugin -> plugin.beforeAgentCallback(agent, callbackContext), "beforeAgentCallback"); } public Maybe runAfterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { + return afterAgentCallback(agent, callbackContext); + } + + @Override + public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { return runMaybeCallbacks( plugin -> plugin.afterAgentCallback(agent, callbackContext), "afterAgentCallback"); } public Maybe runBeforeModelCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest) { + return beforeModelCallback(callbackContext, llmRequest); + } + + @Override + public Maybe beforeModelCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { return runMaybeCallbacks( plugin -> plugin.beforeModelCallback(callbackContext, llmRequest), "beforeModelCallback"); } public Maybe runAfterModelCallback( CallbackContext callbackContext, LlmResponse llmResponse) { + return afterModelCallback(callbackContext, llmResponse); + } + + @Override + public Maybe afterModelCallback( + CallbackContext callbackContext, LlmResponse llmResponse) { return runMaybeCallbacks( plugin -> plugin.afterModelCallback(callbackContext, llmResponse), "afterModelCallback"); } public Maybe runOnModelErrorCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { + return onModelErrorCallback(callbackContext, llmRequest, error); + } + + @Override + public Maybe onModelErrorCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { return runMaybeCallbacks( plugin -> plugin.onModelErrorCallback(callbackContext, llmRequest, error), "onModelErrorCallback"); @@ -147,6 +201,12 @@ public Maybe runOnModelErrorCallback( public Maybe> runBeforeToolCallback( BaseTool tool, Map toolArgs, ToolContext toolContext) { + return beforeToolCallback(tool, toolArgs, toolContext); + } + + @Override + public Maybe> beforeToolCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext) { return runMaybeCallbacks( plugin -> plugin.beforeToolCallback(tool, toolArgs, toolContext), "beforeToolCallback"); } @@ -156,6 +216,15 @@ public Maybe> runAfterToolCallback( Map toolArgs, ToolContext toolContext, Map result) { + return afterToolCallback(tool, toolArgs, toolContext, result); + } + + @Override + public Maybe> afterToolCallback( + BaseTool tool, + Map toolArgs, + ToolContext toolContext, + Map result) { return runMaybeCallbacks( plugin -> plugin.afterToolCallback(tool, toolArgs, toolContext, result), "afterToolCallback"); @@ -163,6 +232,12 @@ public Maybe> runAfterToolCallback( public Maybe> runOnToolErrorCallback( BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { + return onToolErrorCallback(tool, toolArgs, toolContext, error); + } + + @Override + public Maybe> onToolErrorCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { return runMaybeCallbacks( plugin -> plugin.onToolErrorCallback(tool, toolArgs, toolContext, error), "onToolErrorCallback"); @@ -176,7 +251,7 @@ public Maybe> runOnToolErrorCallback( * @return Maybe with the first non-empty result from a plugin, or Empty if all return Empty. */ private Maybe runMaybeCallbacks( - Function> callbackExecutor, String callbackName) { + Function> callbackExecutor, String callbackName) { return Flowable.fromIterable(this.plugins) .concatMapMaybe( diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index e0fafe703..f4b17c41f 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -400,8 +400,8 @@ public Flowable runAsync( Flowable.defer( () -> this.pluginManager - .runOnUserMessageCallback(initialContext, newMessage) - .switchIfEmpty(Single.just(newMessage)) + .onUserMessageCallback(initialContext, newMessage) + .defaultIfEmpty(newMessage) .flatMap( content -> (content != null) @@ -439,8 +439,7 @@ public Flowable runAsync( // Call beforeRunCallback with updated session Maybe beforeRunEvent = this.pluginManager - .runBeforeRunCallback( - contextWithUpdatedSession) + .beforeRunCallback(contextWithUpdatedSession) .map( content -> Event.builder() @@ -473,7 +472,7 @@ public Flowable runAsync( session); return contextWithUpdatedSession .pluginManager() - .runOnEventCallback( + .onEventCallback( contextWithUpdatedSession, registeredEvent) .defaultIfEmpty( diff --git a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java index fe115e46b..ae42bb27e 100644 --- a/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java +++ b/core/src/test/java/com/google/adk/plugins/PluginManagerTest.java @@ -55,8 +55,8 @@ public class PluginManagerTest { @Rule public MockitoRule mockitoRule = MockitoJUnit.rule(); private final PluginManager pluginManager = new PluginManager(); - @Mock private BasePlugin plugin1; - @Mock private BasePlugin plugin2; + @Mock private Plugin plugin1; + @Mock private Plugin plugin2; @Mock private InvocationContext mockInvocationContext; private final Content content = Content.builder().build(); private final Session session = Session.builder("session_id").build(); @@ -92,25 +92,25 @@ public void getPlugin_notFound() { } @Test - public void runOnUserMessageCallback_noPlugins() { - pluginManager.runOnUserMessageCallback(mockInvocationContext, content).test().assertResult(); + public void onUserMessageCallback_noPlugins() { + pluginManager.onUserMessageCallback(mockInvocationContext, content).test().assertResult(); } @Test - public void runOnUserMessageCallback_allReturnEmpty() { + public void onUserMessageCallback_allReturnEmpty() { when(plugin1.onUserMessageCallback(any(), any())).thenReturn(Maybe.empty()); when(plugin2.onUserMessageCallback(any(), any())).thenReturn(Maybe.empty()); pluginManager.registerPlugin(plugin1); pluginManager.registerPlugin(plugin2); - pluginManager.runOnUserMessageCallback(mockInvocationContext, content).test().assertResult(); + pluginManager.onUserMessageCallback(mockInvocationContext, content).test().assertResult(); verify(plugin1).onUserMessageCallback(mockInvocationContext, content); verify(plugin2).onUserMessageCallback(mockInvocationContext, content); } @Test - public void runOnUserMessageCallback_plugin1ReturnsValue_earlyExit() { + public void onUserMessageCallback_plugin1ReturnsValue_earlyExit() { Content expectedContent = Content.builder().build(); when(plugin1.onUserMessageCallback(any(), any())).thenReturn(Maybe.just(expectedContent)); when(plugin2.onUserMessageCallback(any(), any())).thenReturn(Maybe.empty()); @@ -118,7 +118,7 @@ public void runOnUserMessageCallback_plugin1ReturnsValue_earlyExit() { pluginManager.registerPlugin(plugin2); pluginManager - .runOnUserMessageCallback(mockInvocationContext, content) + .onUserMessageCallback(mockInvocationContext, content) .test() .assertResult(expectedContent); @@ -127,7 +127,7 @@ public void runOnUserMessageCallback_plugin1ReturnsValue_earlyExit() { } @Test - public void runOnUserMessageCallback_pluginOrderRespected() { + public void onUserMessageCallback_pluginOrderRespected() { Content expectedContent = Content.builder().build(); when(plugin1.onUserMessageCallback(any(), any())).thenReturn(Maybe.empty()); when(plugin2.onUserMessageCallback(any(), any())).thenReturn(Maybe.just(expectedContent)); @@ -135,7 +135,7 @@ public void runOnUserMessageCallback_pluginOrderRespected() { pluginManager.registerPlugin(plugin2); pluginManager - .runOnUserMessageCallback(mockInvocationContext, content) + .onUserMessageCallback(mockInvocationContext, content) .test() .assertResult(expectedContent); @@ -145,33 +145,33 @@ public void runOnUserMessageCallback_pluginOrderRespected() { } @Test - public void runAfterRunCallback_allComplete() { + public void afterRunCallback_allComplete() { when(plugin1.afterRunCallback(any())).thenReturn(Completable.complete()); when(plugin2.afterRunCallback(any())).thenReturn(Completable.complete()); pluginManager.registerPlugin(plugin1); pluginManager.registerPlugin(plugin2); - pluginManager.runAfterRunCallback(mockInvocationContext).test().assertResult(); + pluginManager.afterRunCallback(mockInvocationContext).test().assertResult(); verify(plugin1).afterRunCallback(mockInvocationContext); verify(plugin2).afterRunCallback(mockInvocationContext); } @Test - public void runAfterRunCallback_plugin1Fails() { + public void afterRunCallback_plugin1Fails() { RuntimeException testException = new RuntimeException("Test"); when(plugin1.afterRunCallback(any())).thenReturn(Completable.error(testException)); pluginManager.registerPlugin(plugin1); pluginManager.registerPlugin(plugin2); - pluginManager.runAfterRunCallback(mockInvocationContext).test().assertError(testException); + pluginManager.afterRunCallback(mockInvocationContext).test().assertError(testException); verify(plugin1).afterRunCallback(mockInvocationContext); verify(plugin2, never()).afterRunCallback(any()); } @Test - public void runBeforeAgentCallback_plugin2ReturnsValue() { + public void beforeAgentCallback_plugin2ReturnsValue() { BaseAgent mockAgent = mock(BaseAgent.class); CallbackContext mockCallbackContext = mock(CallbackContext.class); Content expectedContent = Content.builder().build(); @@ -182,7 +182,7 @@ public void runBeforeAgentCallback_plugin2ReturnsValue() { pluginManager.registerPlugin(plugin2); pluginManager - .runBeforeAgentCallback(mockAgent, mockCallbackContext) + .beforeAgentCallback(mockAgent, mockCallbackContext) .test() .assertResult(expectedContent); @@ -191,33 +191,30 @@ public void runBeforeAgentCallback_plugin2ReturnsValue() { } @Test - public void runBeforeRunCallback_singlePlugin() { + public void beforeRunCallback_singlePlugin() { Content expectedContent = Content.builder().build(); when(plugin1.beforeRunCallback(any())).thenReturn(Maybe.just(expectedContent)); pluginManager.registerPlugin(plugin1); - pluginManager.runBeforeRunCallback(mockInvocationContext).test().assertResult(expectedContent); + pluginManager.beforeRunCallback(mockInvocationContext).test().assertResult(expectedContent); verify(plugin1).beforeRunCallback(mockInvocationContext); } @Test - public void runOnEventCallback_singlePlugin() { + public void onEventCallback_singlePlugin() { Event mockEvent = mock(Event.class); when(plugin1.onEventCallback(any(), any())).thenReturn(Maybe.just(mockEvent)); pluginManager.registerPlugin(plugin1); - pluginManager - .runOnEventCallback(mockInvocationContext, mockEvent) - .test() - .assertResult(mockEvent); + pluginManager.onEventCallback(mockInvocationContext, mockEvent).test().assertResult(mockEvent); verify(plugin1).onEventCallback(mockInvocationContext, mockEvent); } @Test - public void runAfterAgentCallback_singlePlugin() { + public void afterAgentCallback_singlePlugin() { BaseAgent mockAgent = mock(BaseAgent.class); CallbackContext mockCallbackContext = mock(CallbackContext.class); Content expectedContent = Content.builder().build(); @@ -226,7 +223,7 @@ public void runAfterAgentCallback_singlePlugin() { pluginManager.registerPlugin(plugin1); pluginManager - .runAfterAgentCallback(mockAgent, mockCallbackContext) + .afterAgentCallback(mockAgent, mockCallbackContext) .test() .assertResult(expectedContent); @@ -234,7 +231,7 @@ public void runAfterAgentCallback_singlePlugin() { } @Test - public void runBeforeModelCallback_singlePlugin() { + public void beforeModelCallback_singlePlugin() { CallbackContext mockCallbackContext = mock(CallbackContext.class); LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); LlmResponse llmResponse = LlmResponse.builder().build(); @@ -243,7 +240,7 @@ public void runBeforeModelCallback_singlePlugin() { pluginManager.registerPlugin(plugin1); pluginManager - .runBeforeModelCallback(mockCallbackContext, llmRequestBuilder) + .beforeModelCallback(mockCallbackContext, llmRequestBuilder) .test() .assertResult(llmResponse); @@ -251,7 +248,7 @@ public void runBeforeModelCallback_singlePlugin() { } @Test - public void runAfterModelCallback_singlePlugin() { + public void afterModelCallback_singlePlugin() { CallbackContext mockCallbackContext = mock(CallbackContext.class); LlmResponse llmResponse = LlmResponse.builder().build(); @@ -259,7 +256,7 @@ public void runAfterModelCallback_singlePlugin() { pluginManager.registerPlugin(plugin1); pluginManager - .runAfterModelCallback(mockCallbackContext, llmResponse) + .afterModelCallback(mockCallbackContext, llmResponse) .test() .assertResult(llmResponse); @@ -267,7 +264,7 @@ public void runAfterModelCallback_singlePlugin() { } @Test - public void runOnModelErrorCallback_singlePlugin() { + public void onModelErrorCallback_singlePlugin() { CallbackContext mockCallbackContext = mock(CallbackContext.class); LlmRequest.Builder llmRequestBuilder = LlmRequest.builder(); Throwable mockThrowable = mock(Throwable.class); @@ -277,7 +274,7 @@ public void runOnModelErrorCallback_singlePlugin() { pluginManager.registerPlugin(plugin1); pluginManager - .runOnModelErrorCallback(mockCallbackContext, llmRequestBuilder, mockThrowable) + .onModelErrorCallback(mockCallbackContext, llmRequestBuilder, mockThrowable) .test() .assertResult(llmResponse); @@ -285,7 +282,7 @@ public void runOnModelErrorCallback_singlePlugin() { } @Test - public void runBeforeToolCallback_singlePlugin() { + public void beforeToolCallback_singlePlugin() { BaseTool mockTool = mock(BaseTool.class); ImmutableMap toolArgs = ImmutableMap.of(); ToolContext mockToolContext = mock(ToolContext.class); @@ -294,7 +291,7 @@ public void runBeforeToolCallback_singlePlugin() { pluginManager.registerPlugin(plugin1); pluginManager - .runBeforeToolCallback(mockTool, toolArgs, mockToolContext) + .beforeToolCallback(mockTool, toolArgs, mockToolContext) .test() .assertResult(toolArgs); @@ -302,7 +299,7 @@ public void runBeforeToolCallback_singlePlugin() { } @Test - public void runAfterToolCallback_singlePlugin() { + public void afterToolCallback_singlePlugin() { BaseTool mockTool = mock(BaseTool.class); ImmutableMap toolArgs = ImmutableMap.of(); ToolContext mockToolContext = mock(ToolContext.class); @@ -312,7 +309,7 @@ public void runAfterToolCallback_singlePlugin() { pluginManager.registerPlugin(plugin1); pluginManager - .runAfterToolCallback(mockTool, toolArgs, mockToolContext, result) + .afterToolCallback(mockTool, toolArgs, mockToolContext, result) .test() .assertResult(result); @@ -320,7 +317,7 @@ public void runAfterToolCallback_singlePlugin() { } @Test - public void runOnToolErrorCallback_singlePlugin() { + public void onToolErrorCallback_singlePlugin() { BaseTool mockTool = mock(BaseTool.class); ImmutableMap toolArgs = ImmutableMap.of(); ToolContext mockToolContext = mock(ToolContext.class); @@ -330,7 +327,7 @@ public void runOnToolErrorCallback_singlePlugin() { when(plugin1.onToolErrorCallback(any(), any(), any(), any())).thenReturn(Maybe.just(result)); pluginManager.registerPlugin(plugin1); pluginManager - .runOnToolErrorCallback(mockTool, toolArgs, mockToolContext, mockThrowable) + .onToolErrorCallback(mockTool, toolArgs, mockToolContext, mockThrowable) .test() .assertResult(result);