diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index 53a97897..8c9d1aa8 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -17,6 +17,7 @@ package com.google.adk.agents; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Arrays.stream; import com.google.adk.Telemetry; import com.google.adk.agents.Callbacks.AfterAgentCallback; @@ -36,7 +37,6 @@ import java.util.List; import java.util.Optional; import java.util.function.Function; -import java.util.stream.Stream; import org.jspecify.annotations.Nullable; /** Base class for all agents. */ @@ -59,8 +59,7 @@ public abstract class BaseAgent { private final List subAgents; - private final Optional> beforeAgentCallback; - private final Optional> afterAgentCallback; + protected final CallbackPlugin callbackPlugin; /** * Creates a new BaseAgent. @@ -79,12 +78,35 @@ public BaseAgent( List subAgents, List beforeAgentCallback, List afterAgentCallback) { + this( + name, + description, + subAgents, + CallbackPlugin.builder() + .addBeforeAgentCallbacks(beforeAgentCallback) + .addAfterAgentCallbacks(afterAgentCallback) + .build()); + } + + /** + * Creates a new BaseAgent. + * + * @param name Unique agent name. Cannot be "user" (reserved). + * @param description Agent purpose. + * @param subAgents Agents managed by this agent. + * @param callbackPlugin The callback plugin for this agent. + */ + protected BaseAgent( + String name, + String description, + List subAgents, + CallbackPlugin callbackPlugin) { this.name = name; this.description = description; this.parentAgent = null; this.subAgents = subAgents != null ? subAgents : ImmutableList.of(); - this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback); - this.afterAgentCallback = Optional.ofNullable(afterAgentCallback); + this.callbackPlugin = + callbackPlugin == null ? CallbackPlugin.builder().build() : callbackPlugin; // Establish parent relationships for all sub-agents if needed. for (BaseAgent subAgent : this.subAgents) { @@ -172,11 +194,15 @@ public List subAgents() { } public Optional> beforeAgentCallback() { - return beforeAgentCallback; + return Optional.of(callbackPlugin.getBeforeAgentCallback()); } public Optional> afterAgentCallback() { - return afterAgentCallback; + return Optional.of(callbackPlugin.getAfterAgentCallback()); + } + + public Plugin getPlugin() { + return callbackPlugin; } /** @@ -221,8 +247,7 @@ public Flowable runAsync(InvocationContext parentContext) { () -> callCallback( beforeCallbacksToFunctions( - invocationContext.pluginManager(), - beforeAgentCallback.orElse(ImmutableList.of())), + invocationContext.pluginManager(), callbackPlugin), invocationContext) .flatMapPublisher( beforeEventOpt -> { @@ -239,7 +264,7 @@ public Flowable runAsync(InvocationContext parentContext) { callCallback( afterCallbacksToFunctions( invocationContext.pluginManager(), - afterAgentCallback.orElse(ImmutableList.of())), + callbackPlugin), invocationContext) .flatMapPublisher(Flowable::fromOptional)); @@ -251,30 +276,27 @@ public Flowable runAsync(InvocationContext parentContext) { /** * Converts before-agent callbacks to functions. * - * @param callbacks Before-agent callbacks. * @return callback functions. */ private ImmutableList>> beforeCallbacksToFunctions( - Plugin pluginManager, List callbacks) { - return Stream.concat( - Stream.of(ctx -> pluginManager.beforeAgentCallback(this, ctx)), - callbacks.stream() - .map(callback -> (Function>) callback::call)) + Plugin... plugins) { + return stream(plugins) + .map( + p -> + (Function>) ctx -> p.beforeAgentCallback(this, ctx)) .collect(toImmutableList()); } /** * Converts after-agent callbacks to functions. * - * @param callbacks After-agent callbacks. * @return callback functions. */ private ImmutableList>> afterCallbacksToFunctions( - Plugin pluginManager, List callbacks) { - return Stream.concat( - Stream.of(ctx -> pluginManager.afterAgentCallback(this, ctx)), - callbacks.stream() - .map(callback -> (Function>) callback::call)) + Plugin... plugins) { + return stream(plugins) + .map( + p -> (Function>) ctx -> p.afterAgentCallback(this, ctx)) .collect(toImmutableList()); } @@ -399,8 +421,11 @@ public abstract static class Builder> { protected String name; protected String description; protected ImmutableList subAgents; - protected ImmutableList beforeAgentCallback; - protected ImmutableList afterAgentCallback; + protected final CallbackPlugin.Builder callbackPluginBuilder = CallbackPlugin.builder(); + + protected CallbackPlugin.Builder callbackPluginBuilder() { + return callbackPluginBuilder; + } /** This is a safe cast to the concrete builder type. */ @SuppressWarnings("unchecked") @@ -434,25 +459,25 @@ public B subAgents(BaseAgent... subAgents) { @CanIgnoreReturnValue public B beforeAgentCallback(BeforeAgentCallback beforeAgentCallback) { - this.beforeAgentCallback = ImmutableList.of(beforeAgentCallback); + callbackPluginBuilder.addBeforeAgentCallback(beforeAgentCallback); return self(); } @CanIgnoreReturnValue public B beforeAgentCallback(List beforeAgentCallback) { - this.beforeAgentCallback = CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback); + callbackPluginBuilder.addBeforeAgentCallbacks(beforeAgentCallback); return self(); } @CanIgnoreReturnValue public B afterAgentCallback(AfterAgentCallback afterAgentCallback) { - this.afterAgentCallback = ImmutableList.of(afterAgentCallback); + callbackPluginBuilder.addAfterAgentCallback(afterAgentCallback); return self(); } @CanIgnoreReturnValue public B afterAgentCallback(List afterAgentCallback) { - this.afterAgentCallback = CallbackUtil.getAfterAgentCallbacks(afterAgentCallback); + callbackPluginBuilder.addAfterAgentCallbacks(afterAgentCallback); return self(); } diff --git a/core/src/main/java/com/google/adk/agents/CallbackPlugin.java b/core/src/main/java/com/google/adk/agents/CallbackPlugin.java new file mode 100644 index 00000000..2fd40e3d --- /dev/null +++ b/core/src/main/java/com/google/adk/agents/CallbackPlugin.java @@ -0,0 +1,393 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.agents; + +import com.google.adk.agents.Callbacks.AfterAgentCallback; +import com.google.adk.agents.Callbacks.AfterAgentCallbackBase; +import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; +import com.google.adk.agents.Callbacks.AfterModelCallback; +import com.google.adk.agents.Callbacks.AfterModelCallbackBase; +import com.google.adk.agents.Callbacks.AfterModelCallbackSync; +import com.google.adk.agents.Callbacks.AfterToolCallback; +import com.google.adk.agents.Callbacks.AfterToolCallbackBase; +import com.google.adk.agents.Callbacks.AfterToolCallbackSync; +import com.google.adk.agents.Callbacks.BeforeAgentCallback; +import com.google.adk.agents.Callbacks.BeforeAgentCallbackBase; +import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; +import com.google.adk.agents.Callbacks.BeforeModelCallback; +import com.google.adk.agents.Callbacks.BeforeModelCallbackBase; +import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; +import com.google.adk.agents.Callbacks.BeforeToolCallback; +import com.google.adk.agents.Callbacks.BeforeToolCallbackBase; +import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.plugins.BasePlugin; +import com.google.adk.plugins.PluginManager; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.genai.types.Content; +import io.reactivex.rxjava3.core.Maybe; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A plugin that wraps callbacks and exposes them as a plugin. */ +public class CallbackPlugin extends PluginManager { + + private static final Logger logger = LoggerFactory.getLogger(CallbackPlugin.class); + + private final ImmutableListMultimap, Object> callbacks; + + private CallbackPlugin( + ImmutableList plugins, + ImmutableListMultimap, Object> callbacks) { + super(plugins); + this.callbacks = callbacks; + } + + @Override + public String getName() { + return "CallbackPlugin"; + } + + @SuppressWarnings("unchecked") // The builder ensures that the type is correct. + private ImmutableList getCallbacks(Class type) { + return callbacks.get(type).stream() + .map(callback -> (T) callback) + .collect(ImmutableList.toImmutableList()); + } + + public ImmutableList getBeforeAgentCallback() { + return getCallbacks(Callbacks.BeforeAgentCallback.class); + } + + public ImmutableList getAfterAgentCallback() { + return getCallbacks(Callbacks.AfterAgentCallback.class); + } + + public ImmutableList getBeforeModelCallback() { + return getCallbacks(Callbacks.BeforeModelCallback.class); + } + + public ImmutableList getAfterModelCallback() { + return getCallbacks(Callbacks.AfterModelCallback.class); + } + + public ImmutableList getBeforeToolCallback() { + return getCallbacks(Callbacks.BeforeToolCallback.class); + } + + public ImmutableList getAfterToolCallback() { + return getCallbacks(Callbacks.AfterToolCallback.class); + } + + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link CallbackPlugin}. */ + public static class Builder { + private final ImmutableList.Builder plugins = ImmutableList.builder(); + private final ListMultimap, Object> callbacks = ArrayListMultimap.create(); + + Builder() {} + + @CanIgnoreReturnValue + public Builder addBeforeAgentCallback(Callbacks.BeforeAgentCallback callback) { + callbacks.put(Callbacks.BeforeAgentCallback.class, callback); + plugins.add( + new BasePlugin("BeforeAgentCallback_" + callback.hashCode()) { + @Override + public Maybe beforeAgentCallback( + BaseAgent agent, CallbackContext callbackContext) { + return callback.call(callbackContext); + } + }); + return this; + } + + @CanIgnoreReturnValue + public Builder addBeforeAgentCallbackSync(Callbacks.BeforeAgentCallbackSync callback) { + return addBeforeAgentCallback( + callbackContext -> Maybe.fromOptional(callback.call(callbackContext))); + } + + @CanIgnoreReturnValue + public Builder addAfterAgentCallback(Callbacks.AfterAgentCallback callback) { + callbacks.put(Callbacks.AfterAgentCallback.class, callback); + plugins.add( + new BasePlugin("AfterAgentCallback_" + callback.hashCode()) { + @Override + public Maybe afterAgentCallback( + BaseAgent agent, CallbackContext callbackContext) { + return callback.call(callbackContext); + } + }); + return this; + } + + @CanIgnoreReturnValue + public Builder addAfterAgentCallbackSync(Callbacks.AfterAgentCallbackSync callback) { + return addAfterAgentCallback( + callbackContext -> Maybe.fromOptional(callback.call(callbackContext))); + } + + @CanIgnoreReturnValue + public Builder addBeforeModelCallback(Callbacks.BeforeModelCallback callback) { + callbacks.put(Callbacks.BeforeModelCallback.class, callback); + plugins.add( + new BasePlugin("BeforeModelCallback_" + callback.hashCode()) { + @Override + public Maybe beforeModelCallback( + CallbackContext callbackContext, LlmRequest.Builder llmRequest) { + return callback.call(callbackContext, llmRequest); + } + }); + return this; + } + + @CanIgnoreReturnValue + public Builder addBeforeModelCallbackSync(Callbacks.BeforeModelCallbackSync callback) { + return addBeforeModelCallback( + (callbackContext, llmRequest) -> + Maybe.fromOptional(callback.call(callbackContext, llmRequest))); + } + + @CanIgnoreReturnValue + public Builder addAfterModelCallback(Callbacks.AfterModelCallback callback) { + callbacks.put(Callbacks.AfterModelCallback.class, callback); + plugins.add( + new BasePlugin("AfterModelCallback_" + callback.hashCode()) { + @Override + public Maybe afterModelCallback( + CallbackContext callbackContext, LlmResponse llmResponse) { + return callback.call(callbackContext, llmResponse); + } + }); + return this; + } + + @CanIgnoreReturnValue + public Builder addAfterModelCallbackSync(Callbacks.AfterModelCallbackSync callback) { + return addAfterModelCallback( + (callbackContext, llmResponse) -> + Maybe.fromOptional(callback.call(callbackContext, llmResponse))); + } + + @CanIgnoreReturnValue + public Builder addBeforeToolCallback(Callbacks.BeforeToolCallback callback) { + callbacks.put(Callbacks.BeforeToolCallback.class, callback); + plugins.add( + new BasePlugin("BeforeToolCallback_" + callback.hashCode()) { + @Override + public Maybe> beforeToolCallback( + BaseTool tool, Map toolArgs, ToolContext toolContext) { + return callback.call(toolContext.invocationContext(), tool, toolArgs, toolContext); + } + }); + return this; + } + + @CanIgnoreReturnValue + public Builder addBeforeToolCallbackSync(Callbacks.BeforeToolCallbackSync callback) { + return addBeforeToolCallback( + (invocationContext, tool, toolArgs, toolContext) -> + Maybe.fromOptional(callback.call(invocationContext, tool, toolArgs, toolContext))); + } + + @CanIgnoreReturnValue + public Builder addAfterToolCallback(Callbacks.AfterToolCallback callback) { + callbacks.put(Callbacks.AfterToolCallback.class, callback); + plugins.add( + new BasePlugin("AfterToolCallback_" + callback.hashCode()) { + @Override + public Maybe> afterToolCallback( + BaseTool tool, + Map toolArgs, + ToolContext toolContext, + Map result) { + return callback.call( + toolContext.invocationContext(), tool, toolArgs, toolContext, result); + } + }); + return this; + } + + @CanIgnoreReturnValue + public Builder addAfterToolCallbackSync(Callbacks.AfterToolCallbackSync callback) { + return addAfterToolCallback( + (invocationContext, tool, toolArgs, toolContext, result) -> + Maybe.fromOptional( + callback.call(invocationContext, tool, toolArgs, toolContext, result))); + } + + @CanIgnoreReturnValue + public Builder addCallback(BeforeAgentCallbackBase callback) { + if (callback instanceof BeforeAgentCallback beforeAgentCallbackInstance) { + addBeforeAgentCallback(beforeAgentCallbackInstance); + } else if (callback instanceof BeforeAgentCallbackSync beforeAgentCallbackSyncInstance) { + addBeforeAgentCallbackSync(beforeAgentCallbackSyncInstance); + } else { + logger.warn( + "Invalid beforeAgentCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addCallback(AfterAgentCallbackBase callback) { + if (callback instanceof AfterAgentCallback afterAgentCallbackInstance) { + addAfterAgentCallback(afterAgentCallbackInstance); + } else if (callback instanceof AfterAgentCallbackSync afterAgentCallbackSyncInstance) { + addAfterAgentCallbackSync(afterAgentCallbackSyncInstance); + } else { + logger.warn( + "Invalid afterAgentCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addCallback(BeforeModelCallbackBase callback) { + if (callback instanceof BeforeModelCallback beforeModelCallbackInstance) { + addBeforeModelCallback(beforeModelCallbackInstance); + } else if (callback instanceof BeforeModelCallbackSync beforeModelCallbackSyncInstance) { + addBeforeModelCallbackSync(beforeModelCallbackSyncInstance); + } else { + logger.warn( + "Invalid beforeModelCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addCallback(AfterModelCallbackBase callback) { + if (callback instanceof AfterModelCallback afterModelCallbackInstance) { + addAfterModelCallback(afterModelCallbackInstance); + } else if (callback instanceof AfterModelCallbackSync afterModelCallbackSyncInstance) { + addAfterModelCallbackSync(afterModelCallbackSyncInstance); + } else { + logger.warn( + "Invalid afterModelCallback callback type: %s. Ignoring this callback.", + callback.getClass().getName()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addCallback(BeforeToolCallbackBase callback) { + if (callback instanceof BeforeToolCallback beforeToolCallbackInstance) { + addBeforeToolCallback(beforeToolCallbackInstance); + } else if (callback instanceof BeforeToolCallbackSync beforeToolCallbackSyncInstance) { + addBeforeToolCallbackSync(beforeToolCallbackSyncInstance); + } else { + logger.warn( + "Invalid beforeToolCallback callback type: {}. Ignoring this callback.", + callback.getClass().getName()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addCallback(AfterToolCallbackBase callback) { + if (callback instanceof AfterToolCallback afterToolCallbackInstance) { + addAfterToolCallback(afterToolCallbackInstance); + } else if (callback instanceof AfterToolCallbackSync afterToolCallbackSyncInstance) { + addAfterToolCallbackSync(afterToolCallbackSyncInstance); + } else { + logger.warn( + "Invalid afterToolCallback callback type: {}. Ignoring this callback.", + callback.getClass().getName()); + } + return this; + } + + @CanIgnoreReturnValue + public Builder addBeforeAgentCallbacks( + @Nullable List callbacks) { + if (callbacks == null || callbacks.isEmpty()) { + return this; + } + callbacks.forEach(this::addCallback); + return this; + } + + @CanIgnoreReturnValue + public Builder addAfterAgentCallbacks( + @Nullable List callbacks) { + if (callbacks == null || callbacks.isEmpty()) { + return this; + } + callbacks.forEach(this::addCallback); + return this; + } + + @CanIgnoreReturnValue + public Builder addBeforeModelCallbacks( + @Nullable List callbacks) { + if (callbacks == null || callbacks.isEmpty()) { + return this; + } + callbacks.forEach(this::addCallback); + return this; + } + + @CanIgnoreReturnValue + public Builder addAfterModelCallbacks( + @Nullable List callbacks) { + if (callbacks == null || callbacks.isEmpty()) { + return this; + } + callbacks.forEach(this::addCallback); + return this; + } + + @CanIgnoreReturnValue + public Builder addBeforeToolCallbacks( + @Nullable List callbacks) { + if (callbacks == null || callbacks.isEmpty()) { + return this; + } + callbacks.forEach(this::addCallback); + return this; + } + + @CanIgnoreReturnValue + public Builder addAfterToolCallbacks( + @Nullable List callbacks) { + if (callbacks == null || callbacks.isEmpty()) { + return this; + } + callbacks.forEach(this::addCallback); + return this; + } + + public CallbackPlugin build() { + return new CallbackPlugin(plugins.build(), ImmutableListMultimap.copyOf(callbacks)); + } + } +} diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index ab3ee7ed..ed224927 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -56,7 +56,6 @@ import com.google.genai.types.Part; import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Flowable; -import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.List; @@ -95,10 +94,6 @@ public enum IncludeContents { private final Optional maxSteps; private final boolean disallowTransferToParent; private final boolean disallowTransferToPeers; - private final Optional> beforeModelCallback; - private final Optional> afterModelCallback; - private final Optional> beforeToolCallback; - private final Optional> afterToolCallback; private final Optional inputSchema; private final Optional outputSchema; private final Optional executor; @@ -113,8 +108,7 @@ protected LlmAgent(Builder builder) { builder.name, builder.description, builder.subAgents, - builder.beforeAgentCallback, - builder.afterAgentCallback); + builder.callbackPluginBuilder.build()); this.model = Optional.ofNullable(builder.model); this.instruction = builder.instruction == null ? new Instruction.Static("") : builder.instruction; @@ -128,10 +122,6 @@ protected LlmAgent(Builder builder) { this.maxSteps = Optional.ofNullable(builder.maxSteps); this.disallowTransferToParent = builder.disallowTransferToParent; this.disallowTransferToPeers = builder.disallowTransferToPeers; - this.beforeModelCallback = Optional.ofNullable(builder.beforeModelCallback); - this.afterModelCallback = Optional.ofNullable(builder.afterModelCallback); - this.beforeToolCallback = Optional.ofNullable(builder.beforeToolCallback); - this.afterToolCallback = Optional.ofNullable(builder.afterToolCallback); this.inputSchema = Optional.ofNullable(builder.inputSchema); this.outputSchema = Optional.ofNullable(builder.outputSchema); this.executor = Optional.ofNullable(builder.executor); @@ -173,10 +163,6 @@ public static class Builder extends BaseAgent.Builder { private Integer maxSteps; private Boolean disallowTransferToParent; private Boolean disallowTransferToPeers; - private ImmutableList beforeModelCallback; - private ImmutableList afterModelCallback; - private ImmutableList beforeToolCallback; - private ImmutableList afterToolCallback; private Schema inputSchema; private Schema outputSchema; private Executor executor; @@ -290,200 +276,86 @@ public Builder disallowTransferToPeers(boolean disallowTransferToPeers) { @CanIgnoreReturnValue public Builder beforeModelCallback(BeforeModelCallback beforeModelCallback) { - this.beforeModelCallback = ImmutableList.of(beforeModelCallback); + callbackPluginBuilder().addBeforeModelCallback(beforeModelCallback); return this; } @CanIgnoreReturnValue public Builder beforeModelCallback(List beforeModelCallback) { - if (beforeModelCallback == null) { - this.beforeModelCallback = null; - } else if (beforeModelCallback.isEmpty()) { - this.beforeModelCallback = ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (BeforeModelCallbackBase callback : beforeModelCallback) { - if (callback instanceof BeforeModelCallback beforeModelCallbackInstance) { - builder.add(beforeModelCallbackInstance); - } else if (callback instanceof BeforeModelCallbackSync beforeModelCallbackSyncInstance) { - builder.add( - (BeforeModelCallback) - (callbackContext, llmRequestBuilder) -> - Maybe.fromOptional( - beforeModelCallbackSyncInstance.call( - callbackContext, llmRequestBuilder))); - } else { - logger.warn( - "Invalid beforeModelCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - } - this.beforeModelCallback = builder.build(); - } - + callbackPluginBuilder().addBeforeModelCallbacks(beforeModelCallback); return this; } @CanIgnoreReturnValue public Builder beforeModelCallbackSync(BeforeModelCallbackSync beforeModelCallbackSync) { - this.beforeModelCallback = - ImmutableList.of( - (callbackContext, llmRequestBuilder) -> - Maybe.fromOptional( - beforeModelCallbackSync.call(callbackContext, llmRequestBuilder))); + callbackPluginBuilder().addBeforeModelCallbackSync(beforeModelCallbackSync); return this; } @CanIgnoreReturnValue public Builder afterModelCallback(AfterModelCallback afterModelCallback) { - this.afterModelCallback = ImmutableList.of(afterModelCallback); + callbackPluginBuilder().addAfterModelCallback(afterModelCallback); return this; } @CanIgnoreReturnValue public Builder afterModelCallback(List afterModelCallback) { - if (afterModelCallback == null) { - this.afterModelCallback = null; - } else if (afterModelCallback.isEmpty()) { - this.afterModelCallback = ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (AfterModelCallbackBase callback : afterModelCallback) { - if (callback instanceof AfterModelCallback afterModelCallbackInstance) { - builder.add(afterModelCallbackInstance); - } else if (callback instanceof AfterModelCallbackSync afterModelCallbackSyncInstance) { - builder.add( - (AfterModelCallback) - (callbackContext, llmResponse) -> - Maybe.fromOptional( - afterModelCallbackSyncInstance.call(callbackContext, llmResponse))); - } else { - logger.warn( - "Invalid afterModelCallback callback type: %s. Ignoring this callback.", - callback.getClass().getName()); - } - } - this.afterModelCallback = builder.build(); - } - + callbackPluginBuilder().addAfterModelCallbacks(afterModelCallback); return this; } @CanIgnoreReturnValue public Builder afterModelCallbackSync(AfterModelCallbackSync afterModelCallbackSync) { - this.afterModelCallback = - ImmutableList.of( - (callbackContext, llmResponse) -> - Maybe.fromOptional(afterModelCallbackSync.call(callbackContext, llmResponse))); + callbackPluginBuilder().addAfterModelCallbackSync(afterModelCallbackSync); return this; } @CanIgnoreReturnValue public Builder beforeAgentCallbackSync(BeforeAgentCallbackSync beforeAgentCallbackSync) { - this.beforeAgentCallback = - ImmutableList.of( - (callbackContext) -> - Maybe.fromOptional(beforeAgentCallbackSync.call(callbackContext))); + callbackPluginBuilder().addBeforeAgentCallbackSync(beforeAgentCallbackSync); return this; } @CanIgnoreReturnValue public Builder afterAgentCallbackSync(AfterAgentCallbackSync afterAgentCallbackSync) { - this.afterAgentCallback = - ImmutableList.of( - (callbackContext) -> - Maybe.fromOptional(afterAgentCallbackSync.call(callbackContext))); + callbackPluginBuilder().addAfterAgentCallbackSync(afterAgentCallbackSync); return this; } @CanIgnoreReturnValue public Builder beforeToolCallback(BeforeToolCallback beforeToolCallback) { - this.beforeToolCallback = ImmutableList.of(beforeToolCallback); + callbackPluginBuilder().addBeforeToolCallback(beforeToolCallback); return this; } @CanIgnoreReturnValue public Builder beforeToolCallback( @Nullable List beforeToolCallbacks) { - if (beforeToolCallbacks == null) { - this.beforeToolCallback = null; - } else if (beforeToolCallbacks.isEmpty()) { - this.beforeToolCallback = ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (BeforeToolCallbackBase callback : beforeToolCallbacks) { - if (callback instanceof BeforeToolCallback beforeToolCallbackInstance) { - builder.add(beforeToolCallbackInstance); - } else if (callback instanceof BeforeToolCallbackSync beforeToolCallbackSyncInstance) { - builder.add( - (invocationContext, baseTool, input, toolContext) -> - Maybe.fromOptional( - beforeToolCallbackSyncInstance.call( - invocationContext, baseTool, input, toolContext))); - } else { - logger.warn( - "Invalid beforeToolCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - } - this.beforeToolCallback = builder.build(); - } + callbackPluginBuilder().addBeforeToolCallbacks(beforeToolCallbacks); return this; } @CanIgnoreReturnValue public Builder beforeToolCallbackSync(BeforeToolCallbackSync beforeToolCallbackSync) { - this.beforeToolCallback = - ImmutableList.of( - (invocationContext, baseTool, input, toolContext) -> - Maybe.fromOptional( - beforeToolCallbackSync.call( - invocationContext, baseTool, input, toolContext))); + callbackPluginBuilder().addBeforeToolCallbackSync(beforeToolCallbackSync); return this; } @CanIgnoreReturnValue public Builder afterToolCallback(AfterToolCallback afterToolCallback) { - this.afterToolCallback = ImmutableList.of(afterToolCallback); + callbackPluginBuilder().addAfterToolCallback(afterToolCallback); return this; } @CanIgnoreReturnValue public Builder afterToolCallback(@Nullable List afterToolCallbacks) { - if (afterToolCallbacks == null) { - this.afterToolCallback = null; - } else if (afterToolCallbacks.isEmpty()) { - this.afterToolCallback = ImmutableList.of(); - } else { - ImmutableList.Builder builder = ImmutableList.builder(); - for (AfterToolCallbackBase callback : afterToolCallbacks) { - if (callback instanceof AfterToolCallback afterToolCallbackInstance) { - builder.add(afterToolCallbackInstance); - } else if (callback instanceof AfterToolCallbackSync afterToolCallbackSyncInstance) { - builder.add( - (invocationContext, baseTool, input, toolContext, response) -> - Maybe.fromOptional( - afterToolCallbackSyncInstance.call( - invocationContext, baseTool, input, toolContext, response))); - } else { - logger.warn( - "Invalid afterToolCallback callback type: {}. Ignoring this callback.", - callback.getClass().getName()); - } - } - this.afterToolCallback = builder.build(); - } + callbackPluginBuilder().addAfterToolCallbacks(afterToolCallbacks); return this; } @CanIgnoreReturnValue public Builder afterToolCallbackSync(AfterToolCallbackSync afterToolCallbackSync) { - this.afterToolCallback = - ImmutableList.of( - (invocationContext, baseTool, input, toolContext, response) -> - Maybe.fromOptional( - afterToolCallbackSync.call( - invocationContext, baseTool, input, toolContext, response))); + callbackPluginBuilder().addAfterToolCallbackSync(afterToolCallbackSync); return this; } @@ -757,19 +629,19 @@ public boolean disallowTransferToPeers() { } public Optional> beforeModelCallback() { - return beforeModelCallback; + return Optional.of(callbackPlugin.getBeforeModelCallback()); } public Optional> afterModelCallback() { - return afterModelCallback; + return Optional.of(callbackPlugin.getAfterModelCallback()); } public Optional> beforeToolCallback() { - return beforeToolCallback; + return Optional.of(callbackPlugin.getBeforeToolCallback()); } public Optional> afterToolCallback() { - return afterToolCallback; + return Optional.of(callbackPlugin.getAfterToolCallback()); } public Optional inputSchema() { @@ -830,8 +702,8 @@ private Model resolveModelInternal() { } BaseAgent current = this.parentAgent(); while (current != null) { - if (current instanceof LlmAgent) { - return ((LlmAgent) current).resolvedModel(); + if (current instanceof LlmAgent llmAgent) { + return llmAgent.resolvedModel(); } current = current.parentAgent(); } diff --git a/core/src/main/java/com/google/adk/agents/LoopAgent.java b/core/src/main/java/com/google/adk/agents/LoopAgent.java index d9d049f8..921ef368 100644 --- a/core/src/main/java/com/google/adk/agents/LoopAgent.java +++ b/core/src/main/java/com/google/adk/agents/LoopAgent.java @@ -46,16 +46,13 @@ public class LoopAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private LoopAgent( - String name, - String description, - List subAgents, - Optional maxIterations, - List beforeAgentCallback, - List afterAgentCallback) { - - super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); - this.maxIterations = maxIterations; + private LoopAgent(Builder builder) { + super( + builder.name, + builder.description, + builder.subAgents, + builder.callbackPluginBuilder.build()); + this.maxIterations = builder.maxIterations; } /** Builder for {@link LoopAgent}. */ @@ -76,9 +73,7 @@ public Builder maxIterations(Optional maxIterations) { @Override public LoopAgent build() { - // TODO(b/410859954): Add validation for required fields like name. - return new LoopAgent( - name, description, subAgents, maxIterations, beforeAgentCallback, afterAgentCallback); + return new LoopAgent(this); } } diff --git a/core/src/main/java/com/google/adk/agents/ParallelAgent.java b/core/src/main/java/com/google/adk/agents/ParallelAgent.java index f30d951a..583bfffc 100644 --- a/core/src/main/java/com/google/adk/agents/ParallelAgent.java +++ b/core/src/main/java/com/google/adk/agents/ParallelAgent.java @@ -45,14 +45,12 @@ public class ParallelAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private ParallelAgent( - String name, - String description, - List subAgents, - List beforeAgentCallback, - List afterAgentCallback) { - - super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); + private ParallelAgent(Builder builder) { + super( + builder.name, + builder.description, + builder.subAgents, + builder.callbackPluginBuilder.build()); } /** Builder for {@link ParallelAgent}. */ @@ -60,8 +58,7 @@ public static class Builder extends BaseAgent.Builder { @Override public ParallelAgent build() { - return new ParallelAgent( - name, description, subAgents, beforeAgentCallback, afterAgentCallback); + return new ParallelAgent(this); } } diff --git a/core/src/main/java/com/google/adk/agents/ReadonlyContext.java b/core/src/main/java/com/google/adk/agents/ReadonlyContext.java index 7d3a5acb..dc7480f5 100644 --- a/core/src/main/java/com/google/adk/agents/ReadonlyContext.java +++ b/core/src/main/java/com/google/adk/agents/ReadonlyContext.java @@ -34,6 +34,11 @@ public ReadonlyContext(InvocationContext invocationContext) { this.invocationContext = invocationContext; } + /** Returns the invocation context. */ + public InvocationContext invocationContext() { + return invocationContext; + } + /** Returns the user content that initiated this invocation. */ public Optional userContent() { return invocationContext.userContent(); diff --git a/core/src/main/java/com/google/adk/agents/SequentialAgent.java b/core/src/main/java/com/google/adk/agents/SequentialAgent.java index b0b45a0e..aa4b76fb 100644 --- a/core/src/main/java/com/google/adk/agents/SequentialAgent.java +++ b/core/src/main/java/com/google/adk/agents/SequentialAgent.java @@ -18,7 +18,6 @@ import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.events.Event; import io.reactivex.rxjava3.core.Flowable; -import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -36,14 +35,12 @@ public class SequentialAgent extends BaseAgent { * @param beforeAgentCallback Optional callback before the agent runs. * @param afterAgentCallback Optional callback after the agent runs. */ - private SequentialAgent( - String name, - String description, - List subAgents, - List beforeAgentCallback, - List afterAgentCallback) { - - super(name, description, subAgents, beforeAgentCallback, afterAgentCallback); + private SequentialAgent(Builder builder) { + super( + builder.name, + builder.description, + builder.subAgents, + builder.callbackPluginBuilder.build()); } /** Builder for {@link SequentialAgent}. */ @@ -51,9 +48,7 @@ public static class Builder extends BaseAgent.Builder { @Override public SequentialAgent build() { - // TODO(b/410859954): Add validation for required fields like name. - return new SequentialAgent( - name, description, subAgents, beforeAgentCallback, afterAgentCallback); + return new SequentialAgent(this); } } diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index 6e06a34a..92ec8791 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -20,14 +20,12 @@ import com.google.adk.events.Event; import com.google.adk.testing.TestBaseAgent; +import com.google.adk.testing.TestCallback; import com.google.adk.testing.TestUtils; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; -import io.reactivex.rxjava3.core.Flowable; -import io.reactivex.rxjava3.core.Maybe; import java.util.List; -import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -51,37 +49,97 @@ public void constructor_setsNameAndDescription() { @Test public void runAsync_beforeAgentCallbackReturnsContent_endsInvocationAndSkipsRunAsyncImplAndAfterCallback() { - AtomicBoolean runAsyncImplCalled = new AtomicBoolean(false); - AtomicBoolean afterAgentCallbackCalled = new AtomicBoolean(false); + var runAsyncImpl = TestCallback.returningEmpty(); Content callbackContent = Content.fromParts(Part.fromText("before_callback_output")); - Callbacks.BeforeAgentCallback beforeCallback = (callbackContext) -> Maybe.just(callbackContent); - Callbacks.AfterAgentCallback afterCallback = - (callbackContext) -> { - afterAgentCallbackCalled.set(true); - return Maybe.empty(); - }; + var beforeCallback = TestCallback.returning(callbackContent); + var afterCallback = TestCallback.returningEmpty(); TestBaseAgent agent = new TestBaseAgent( TEST_AGENT_NAME, TEST_AGENT_DESCRIPTION, - ImmutableList.of(beforeCallback), - ImmutableList.of(afterCallback), - () -> - Flowable.defer( - () -> { - runAsyncImplCalled.set(true); - return Flowable.just( - Event.builder() - .content(Content.fromParts(Part.fromText("main_output"))) - .build()); - })); + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runAsyncImpl.asRunAsyncImplSupplier("main_output")); InvocationContext invocationContext = TestUtils.createInvocationContext(agent); List results = agent.runAsync(invocationContext).toList().blockingGet(); assertThat(results).hasSize(1); assertThat(results.get(0).content()).hasValue(callbackContent); - assertThat(runAsyncImplCalled.get()).isFalse(); - assertThat(afterAgentCallbackCalled.get()).isFalse(); + assertThat(runAsyncImpl.wasCalled()).isFalse(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isFalse(); + } + + @Test + public void runAsync_noCallbacks_invokesRunAsyncImpl() { + var runAsyncImpl = TestCallback.returningEmpty(); + Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + /* beforeAgentCallbacks= */ ImmutableList.of(), + /* afterAgentCallbacks= */ ImmutableList.of(), + runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(runAsyncImplContent); + assertThat(runAsyncImpl.wasCalled()).isTrue(); + } + + @Test + public void + runAsync_beforeCallbackReturnsEmptyAndAfterCallbackReturnsEmpty_invokesRunAsyncImplAndAfterCallbacks() { + var runAsyncImpl = TestCallback.returningEmpty(); + Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); + var beforeCallback = TestCallback.returningEmpty(); + var afterCallback = TestCallback.returningEmpty(); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(1); + assertThat(results.get(0).content()).hasValue(runAsyncImplContent); + assertThat(runAsyncImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); + } + + @Test + public void + runAsync_afterCallbackReturnsContent_invokesRunAsyncImplAndAfterCallbacksAndReturnsAllContent() { + var runAsyncImpl = TestCallback.returningEmpty(); + Content runAsyncImplContent = Content.fromParts(Part.fromText("main_output")); + Content afterCallbackContent = Content.fromParts(Part.fromText("after_callback_output")); + var beforeCallback = TestCallback.returningEmpty(); + var afterCallback = TestCallback.returning(afterCallbackContent); + TestBaseAgent agent = + new TestBaseAgent( + TEST_AGENT_NAME, + TEST_AGENT_DESCRIPTION, + ImmutableList.of(beforeCallback.asBeforeAgentCallback()), + ImmutableList.of(afterCallback.asAfterAgentCallback()), + runAsyncImpl.asRunAsyncImplSupplier(runAsyncImplContent)); + InvocationContext invocationContext = TestUtils.createInvocationContext(agent); + + List results = agent.runAsync(invocationContext).toList().blockingGet(); + + assertThat(results).hasSize(2); + assertThat(results.get(0).content()).hasValue(runAsyncImplContent); + assertThat(results.get(1).content()).hasValue(afterCallbackContent); + assertThat(runAsyncImpl.wasCalled()).isTrue(); + assertThat(beforeCallback.wasCalled()).isTrue(); + assertThat(afterCallback.wasCalled()).isTrue(); } } diff --git a/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java b/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java new file mode 100644 index 00000000..fd90e2b4 --- /dev/null +++ b/core/src/test/java/com/google/adk/agents/CallbackPluginTest.java @@ -0,0 +1,560 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.agents; + +import static com.google.adk.testing.TestUtils.createInvocationContext; +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.agents.Callbacks.AfterAgentCallback; +import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; +import com.google.adk.agents.Callbacks.AfterModelCallback; +import com.google.adk.agents.Callbacks.AfterModelCallbackSync; +import com.google.adk.agents.Callbacks.AfterToolCallback; +import com.google.adk.agents.Callbacks.AfterToolCallbackSync; +import com.google.adk.agents.Callbacks.BeforeAgentCallback; +import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; +import com.google.adk.agents.Callbacks.BeforeModelCallback; +import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; +import com.google.adk.agents.Callbacks.BeforeToolCallback; +import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; +import com.google.adk.events.EventActions; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.testing.TestCallback; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Maybe; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public final class CallbackPluginTest { + + @Rule public final MockitoRule mockito = MockitoJUnit.rule(); + @Mock private BaseAgent agent; + @Mock private BaseTool tool; + @Mock private ToolContext toolContext; + private InvocationContext invocationContext; + private CallbackContext callbackContext; + + @Before + public void setUp() { + invocationContext = createInvocationContext(agent); + callbackContext = + new CallbackContext( + invocationContext, + EventActions.builder().stateDelta(new ConcurrentHashMap<>()).build()); + } + + @Test + public void build_empty_successful() { + CallbackPlugin plugin = CallbackPlugin.builder().build(); + assertThat(plugin.getName()).isEqualTo("CallbackPlugin"); + assertThat(plugin.getBeforeAgentCallback()).isEmpty(); + assertThat(plugin.getAfterAgentCallback()).isEmpty(); + assertThat(plugin.getBeforeModelCallback()).isEmpty(); + assertThat(plugin.getAfterModelCallback()).isEmpty(); + assertThat(plugin.getBeforeToolCallback()).isEmpty(); + assertThat(plugin.getAfterToolCallback()).isEmpty(); + } + + @Test + public void addBeforeAgentCallback_isReturnedAndInvoked() { + Content expectedContent = Content.fromParts(Part.fromText("test")); + var testCallback = TestCallback.returning(expectedContent); + BeforeAgentCallback callback = testCallback.asBeforeAgentCallback(); + + CallbackPlugin plugin = CallbackPlugin.builder().addBeforeAgentCallback(callback).build(); + + assertThat(plugin.getBeforeAgentCallback()).containsExactly(callback); + + Content result = plugin.beforeAgentCallback(agent, callbackContext).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedContent); + } + + @Test + public void addBeforeAgentCallbackSync_isReturnedAndInvoked() { + Content expectedContent = Content.fromParts(Part.fromText("test")); + var testCallback = TestCallback.returning(expectedContent); + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeAgentCallbackSync(testCallback.asBeforeAgentCallbackSync()) + .build(); + + assertThat(plugin.getBeforeAgentCallback()).hasSize(1); + + Content result = plugin.beforeAgentCallback(agent, callbackContext).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedContent); + } + + @Test + public void addAfterAgentCallback_isReturnedAndInvoked() { + Content expectedContent = Content.fromParts(Part.fromText("test")); + var testCallback = TestCallback.returning(expectedContent); + AfterAgentCallback callback = testCallback.asAfterAgentCallback(); + + CallbackPlugin plugin = CallbackPlugin.builder().addAfterAgentCallback(callback).build(); + + assertThat(plugin.getAfterAgentCallback()).containsExactly(callback); + + Content result = plugin.afterAgentCallback(agent, callbackContext).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedContent); + } + + @Test + public void addAfterAgentCallbackSync_isReturnedAndInvoked() { + Content expectedContent = Content.fromParts(Part.fromText("test")); + var testCallback = TestCallback.returning(expectedContent); + CallbackPlugin plugin = + CallbackPlugin.builder() + .addAfterAgentCallbackSync(testCallback.asAfterAgentCallbackSync()) + .build(); + + assertThat(plugin.getAfterAgentCallback()).hasSize(1); + + Content result = plugin.afterAgentCallback(agent, callbackContext).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedContent); + } + + @Test + public void addBeforeModelCallback_isReturnedAndInvoked() { + LlmResponse expectedResponse = LlmResponse.builder().build(); + var testCallback = TestCallback.returning(expectedResponse); + BeforeModelCallback callback = testCallback.asBeforeModelCallback(); + + CallbackPlugin plugin = CallbackPlugin.builder().addBeforeModelCallback(callback).build(); + + assertThat(plugin.getBeforeModelCallback()).containsExactly(callback); + + LlmResponse result = + plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addBeforeModelCallbackSync_isReturnedAndInvoked() { + LlmResponse expectedResponse = LlmResponse.builder().build(); + var testCallback = TestCallback.returning(expectedResponse); + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeModelCallbackSync(testCallback.asBeforeModelCallbackSync()) + .build(); + + assertThat(plugin.getBeforeModelCallback()).hasSize(1); + + LlmResponse result = + plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addAfterModelCallback_isReturnedAndInvoked() { + LlmResponse initialResponse = LlmResponse.builder().build(); + LlmResponse expectedResponse = + LlmResponse.builder().content(Content.fromParts(Part.fromText("test"))).build(); + var testCallback = TestCallback.returning(expectedResponse); + AfterModelCallback callback = testCallback.asAfterModelCallback(); + + CallbackPlugin plugin = CallbackPlugin.builder().addAfterModelCallback(callback).build(); + + assertThat(plugin.getAfterModelCallback()).containsExactly(callback); + + LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addAfterModelCallbackSync_isReturnedAndInvoked() { + LlmResponse initialResponse = LlmResponse.builder().build(); + LlmResponse expectedResponse = + LlmResponse.builder().content(Content.fromParts(Part.fromText("test"))).build(); + var testCallback = TestCallback.returning(expectedResponse); + AfterModelCallbackSync callback = testCallback.asAfterModelCallbackSync(); + + CallbackPlugin plugin = CallbackPlugin.builder().addAfterModelCallbackSync(callback).build(); + + assertThat(plugin.getAfterModelCallback()).hasSize(1); + + LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addBeforeToolCallback_isReturnedAndInvoked() { + ImmutableMap expectedResult = ImmutableMap.of("key", "value"); + var testCallback = TestCallback.returning(expectedResult); + BeforeToolCallback callback = testCallback.asBeforeToolCallback(); + + CallbackPlugin plugin = CallbackPlugin.builder().addBeforeToolCallback(callback).build(); + + assertThat(plugin.getBeforeToolCallback()).containsExactly(callback); + + Map result = + plugin.beforeToolCallback(tool, ImmutableMap.of(), toolContext).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResult); + } + + @Test + public void addBeforeToolCallbackSync_isReturnedAndInvoked() { + ImmutableMap expectedResult = ImmutableMap.of("key", "value"); + var testCallback = TestCallback.returning(expectedResult); + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeToolCallbackSync(testCallback.asBeforeToolCallbackSync()) + .build(); + + assertThat(plugin.getBeforeToolCallback()).hasSize(1); + + Map result = + plugin.beforeToolCallback(tool, ImmutableMap.of(), toolContext).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResult); + } + + @Test + public void addAfterToolCallback_isReturnedAndInvoked() { + ImmutableMap initialResult = ImmutableMap.of("initial", "result"); + ImmutableMap expectedResult = ImmutableMap.of("key", "value"); + var testCallback = TestCallback.returning(expectedResult); + AfterToolCallback callback = testCallback.asAfterToolCallback(); + + CallbackPlugin plugin = CallbackPlugin.builder().addAfterToolCallback(callback).build(); + + assertThat(plugin.getAfterToolCallback()).containsExactly(callback); + + Map result = + plugin.afterToolCallback(tool, ImmutableMap.of(), toolContext, initialResult).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResult); + } + + @Test + public void addAfterToolCallbackSync_isReturnedAndInvoked() { + ImmutableMap initialResult = ImmutableMap.of("initial", "result"); + ImmutableMap expectedResult = ImmutableMap.of("key", "value"); + var testCallback = TestCallback.returning(expectedResult); + AfterToolCallbackSync callback = testCallback.asAfterToolCallbackSync(); + + CallbackPlugin plugin = CallbackPlugin.builder().addAfterToolCallbackSync(callback).build(); + + assertThat(plugin.getAfterToolCallback()).hasSize(1); + + Map result = + plugin.afterToolCallback(tool, ImmutableMap.of(), toolContext, initialResult).blockingGet(); + assertThat(testCallback.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResult); + } + + @Test + public void addCallback_beforeAgentCallback() { + BeforeAgentCallback callback = ctx -> Maybe.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getBeforeAgentCallback()).containsExactly(callback); + } + + @Test + public void addCallback_beforeAgentCallbackSync() { + BeforeAgentCallbackSync callback = ctx -> Optional.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getBeforeAgentCallback()).hasSize(1); + } + + @Test + public void addCallback_afterAgentCallback() { + AfterAgentCallback callback = ctx -> Maybe.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getAfterAgentCallback()).containsExactly(callback); + } + + @Test + public void addCallback_afterAgentCallbackSync() { + AfterAgentCallbackSync callback = ctx -> Optional.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getAfterAgentCallback()).hasSize(1); + } + + @Test + public void addCallback_beforeModelCallback() { + BeforeModelCallback callback = (ctx, req) -> Maybe.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getBeforeModelCallback()).containsExactly(callback); + } + + @Test + public void addCallback_beforeModelCallbackSync() { + BeforeModelCallbackSync callback = (ctx, req) -> Optional.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getBeforeModelCallback()).hasSize(1); + } + + @Test + public void addCallback_afterModelCallback() { + AfterModelCallback callback = (ctx, res) -> Maybe.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getAfterModelCallback()).containsExactly(callback); + } + + @Test + public void addCallback_afterModelCallbackSync() { + AfterModelCallbackSync callback = (ctx, res) -> Optional.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getAfterModelCallback()).hasSize(1); + } + + @Test + public void addCallback_beforeToolCallback() { + BeforeToolCallback callback = (invCtx, tool, toolArgs, toolCtx) -> Maybe.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getBeforeToolCallback()).containsExactly(callback); + } + + @Test + public void addCallback_beforeToolCallbackSync() { + BeforeToolCallbackSync callback = (invCtx, tool, toolArgs, toolCtx) -> Optional.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getBeforeToolCallback()).hasSize(1); + } + + @Test + public void addCallback_afterToolCallback() { + AfterToolCallback callback = (invCtx, tool, toolArgs, toolCtx, res) -> Maybe.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getAfterToolCallback()).containsExactly(callback); + } + + @Test + public void addCallback_afterToolCallbackSync() { + AfterToolCallbackSync callback = (invCtx, tool, toolArgs, toolCtx, res) -> Optional.empty(); + CallbackPlugin plugin = CallbackPlugin.builder().addCallback(callback).build(); + assertThat(plugin.getAfterToolCallback()).hasSize(1); + } + + @Test + public void addMultipleBeforeModelCallbacks_invokedInOrder() { + LlmResponse expectedResponse = LlmResponse.builder().build(); + var testCallback1 = TestCallback.returningEmpty(); + var testCallback2 = TestCallback.returning(expectedResponse); + BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); + BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); + + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeModelCallback(callback1) + .addBeforeModelCallback(callback2) + .build(); + + assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); + + LlmResponse result = + plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); + assertThat(testCallback1.wasCalled()).isTrue(); + assertThat(testCallback2.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addMultipleBeforeModelCallbacks_shortCircuit() { + LlmResponse expectedResponse = LlmResponse.builder().build(); + var testCallback1 = TestCallback.returning(expectedResponse); + var testCallback2 = TestCallback.returningEmpty(); + BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); + BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); + + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeModelCallback(callback1) + .addBeforeModelCallback(callback2) + .build(); + + assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); + + LlmResponse result = + plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); + assertThat(testCallback1.wasCalled()).isTrue(); + assertThat(testCallback2.wasCalled()).isFalse(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addMultipleAfterModelCallbacks_shortCircuit() { + LlmResponse initialResponse = LlmResponse.builder().build(); + LlmResponse expectedResponse = + LlmResponse.builder().content(Content.fromParts(Part.fromText("response"))).build(); + var testCallback1 = TestCallback.returning(expectedResponse); + var testCallback2 = TestCallback.returningEmpty(); + AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); + AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); + CallbackPlugin plugin = + CallbackPlugin.builder() + .addAfterModelCallback(callback1) + .addAfterModelCallback(callback2) + .build(); + + assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); + LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); + assertThat(testCallback1.wasCalled()).isTrue(); + assertThat(testCallback2.wasCalled()).isFalse(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addMultipleAfterModelCallbacks_invokedInOrder() { + LlmResponse initialResponse = LlmResponse.builder().build(); + LlmResponse expectedResponse = + LlmResponse.builder().content(Content.fromParts(Part.fromText("second"))).build(); + var testCallback1 = TestCallback.returningEmpty(); + var testCallback2 = TestCallback.returning(expectedResponse); + AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); + AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); + + CallbackPlugin plugin = + CallbackPlugin.builder() + .addAfterModelCallback(callback1) + .addAfterModelCallback(callback2) + .build(); + + assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); + + LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); + assertThat(testCallback1.wasCalled()).isTrue(); + assertThat(testCallback2.wasCalled()).isTrue(); + assertThat(result).isEqualTo(expectedResponse); + } + + @Test + public void addMultipleBeforeModelCallbacks_bothEmpty_returnsEmpty() { + var testCallback1 = TestCallback.returningEmpty(); + var testCallback2 = TestCallback.returningEmpty(); + BeforeModelCallback callback1 = testCallback1.asBeforeModelCallback(); + BeforeModelCallback callback2 = testCallback2.asBeforeModelCallback(); + + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeModelCallback(callback1) + .addBeforeModelCallback(callback2) + .build(); + + assertThat(plugin.getBeforeModelCallback()).containsExactly(callback1, callback2).inOrder(); + + LlmResponse result = + plugin.beforeModelCallback(callbackContext, LlmRequest.builder()).blockingGet(); + assertThat(testCallback1.wasCalled()).isTrue(); + assertThat(testCallback2.wasCalled()).isTrue(); + assertThat(result).isNull(); + } + + @Test + public void addMultipleAfterModelCallbacks_bothEmpty_returnsEmpty() { + LlmResponse initialResponse = LlmResponse.builder().build(); + var testCallback1 = TestCallback.returningEmpty(); + var testCallback2 = TestCallback.returningEmpty(); + AfterModelCallback callback1 = testCallback1.asAfterModelCallback(); + AfterModelCallback callback2 = testCallback2.asAfterModelCallback(); + CallbackPlugin plugin = + CallbackPlugin.builder() + .addAfterModelCallback(callback1) + .addAfterModelCallback(callback2) + .build(); + + assertThat(plugin.getAfterModelCallback()).containsExactly(callback1, callback2).inOrder(); + LlmResponse result = plugin.afterModelCallback(callbackContext, initialResponse).blockingGet(); + assertThat(testCallback1.wasCalled()).isTrue(); + assertThat(testCallback2.wasCalled()).isTrue(); + assertThat(result).isNull(); + } + + @Test + public void addBeforeAgentCallbacks_nullOrEmpty() { + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeAgentCallbacks(null) + .addBeforeAgentCallbacks(ImmutableList.of()) + .build(); + assertThat(plugin.getBeforeAgentCallback()).isEmpty(); + } + + @Test + public void addAfterAgentCallbacks_nullOrEmpty() { + CallbackPlugin plugin = + CallbackPlugin.builder() + .addAfterAgentCallbacks(null) + .addAfterAgentCallbacks(ImmutableList.of()) + .build(); + assertThat(plugin.getAfterAgentCallback()).isEmpty(); + } + + @Test + public void addBeforeModelCallbacks_nullOrEmpty() { + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeModelCallbacks(null) + .addBeforeModelCallbacks(ImmutableList.of()) + .build(); + assertThat(plugin.getBeforeModelCallback()).isEmpty(); + } + + @Test + public void addAfterModelCallbacks_nullOrEmpty() { + CallbackPlugin plugin = + CallbackPlugin.builder() + .addAfterModelCallbacks(null) + .addAfterModelCallbacks(ImmutableList.of()) + .build(); + assertThat(plugin.getAfterModelCallback()).isEmpty(); + } + + @Test + public void addBeforeToolCallbacks_nullOrEmpty() { + CallbackPlugin plugin = + CallbackPlugin.builder() + .addBeforeToolCallbacks(null) + .addBeforeToolCallbacks(ImmutableList.of()) + .build(); + assertThat(plugin.getBeforeToolCallback()).isEmpty(); + } + + @Test + public void addAfterToolCallbacks_nullOrEmpty() { + CallbackPlugin plugin = + CallbackPlugin.builder() + .addAfterToolCallbacks(null) + .addAfterToolCallbacks(ImmutableList.of()) + .build(); + assertThat(plugin.getAfterToolCallback()).isEmpty(); + } +} diff --git a/core/src/test/java/com/google/adk/testing/TestCallback.java b/core/src/test/java/com/google/adk/testing/TestCallback.java new file mode 100644 index 00000000..04f83ed9 --- /dev/null +++ b/core/src/test/java/com/google/adk/testing/TestCallback.java @@ -0,0 +1,164 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.adk.testing; + +import com.google.adk.agents.Callbacks.AfterAgentCallback; +import com.google.adk.agents.Callbacks.AfterAgentCallbackSync; +import com.google.adk.agents.Callbacks.AfterModelCallback; +import com.google.adk.agents.Callbacks.AfterModelCallbackSync; +import com.google.adk.agents.Callbacks.AfterToolCallback; +import com.google.adk.agents.Callbacks.AfterToolCallbackSync; +import com.google.adk.agents.Callbacks.BeforeAgentCallback; +import com.google.adk.agents.Callbacks.BeforeAgentCallbackSync; +import com.google.adk.agents.Callbacks.BeforeModelCallback; +import com.google.adk.agents.Callbacks.BeforeModelCallbackSync; +import com.google.adk.agents.Callbacks.BeforeToolCallback; +import com.google.adk.agents.Callbacks.BeforeToolCallbackSync; +import com.google.adk.events.Event; +import com.google.adk.models.LlmResponse; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +/** + * A test helper that wraps an {@link AtomicBoolean} and provides factory methods for creating + * callbacks that update the boolean when called. + * + * @param The type of the result returned by the callback. + */ +public final class TestCallback { + private final AtomicBoolean called = new AtomicBoolean(false); + private final Optional result; + + private TestCallback(Optional result) { + this.result = result; + } + + /** Creates a {@link TestCallback} that returns the given result. */ + public static TestCallback returning(T result) { + return new TestCallback<>(Optional.of(result)); + } + + /** Creates a {@link TestCallback} that returns an empty result. */ + public static TestCallback returningEmpty() { + return new TestCallback<>(Optional.empty()); + } + + /** Returns true if the callback was called. */ + public boolean wasCalled() { + return called.get(); + } + + /** Marks the callback as called. */ + public void markAsCalled() { + called.set(true); + } + + private Maybe callMaybe() { + called.set(true); + return result.map(Maybe::just).orElseGet(Maybe::empty); + } + + private Optional callOptional() { + called.set(true); + return result; + } + + /** + * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} + * with an event containing the given content. + */ + public Supplier> asRunAsyncImplSupplier(Content content) { + return () -> + Flowable.defer( + () -> { + markAsCalled(); + return Flowable.just(Event.builder().content(content).build()); + }); + } + + /** + * Returns a {@link Supplier} that marks this callback as called and returns a {@link Flowable} + */ + public Supplier> asRunAsyncImplSupplier(String contentText) { + return asRunAsyncImplSupplier(Content.fromParts(Part.fromText(contentText))); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Content. + public BeforeAgentCallback asBeforeAgentCallback() { + return ctx -> (Maybe) callMaybe(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Content. + public BeforeAgentCallbackSync asBeforeAgentCallbackSync() { + return ctx -> (Optional) callOptional(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Content. + public AfterAgentCallback asAfterAgentCallback() { + return ctx -> (Maybe) callMaybe(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Content. + public AfterAgentCallbackSync asAfterAgentCallbackSync() { + return ctx -> (Optional) callOptional(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. + public BeforeModelCallback asBeforeModelCallback() { + return (ctx, req) -> (Maybe) callMaybe(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. + public BeforeModelCallbackSync asBeforeModelCallbackSync() { + return (ctx, req) -> (Optional) callOptional(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. + public AfterModelCallback asAfterModelCallback() { + return (ctx, res) -> (Maybe) callMaybe(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is LlmResponse. + public AfterModelCallbackSync asAfterModelCallbackSync() { + return (ctx, res) -> (Optional) callOptional(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Map. + public BeforeToolCallback asBeforeToolCallback() { + return (invCtx, tool, toolArgs, toolCtx) -> (Maybe>) callMaybe(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Map. + public BeforeToolCallbackSync asBeforeToolCallbackSync() { + return (invCtx, tool, toolArgs, toolCtx) -> (Optional>) callOptional(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Map. + public AfterToolCallback asAfterToolCallback() { + return (invCtx, tool, toolArgs, toolCtx, res) -> (Maybe>) callMaybe(); + } + + @SuppressWarnings("unchecked") // This cast is safe if T is Map. + public AfterToolCallbackSync asAfterToolCallbackSync() { + return (invCtx, tool, toolArgs, toolCtx, res) -> (Optional>) callOptional(); + } +}