Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 53 additions & 28 deletions core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.google.adk.agents;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Arrays.stream;

import com.google.adk.Telemetry;
import com.google.adk.agents.Callbacks.AfterAgentCallback;
Expand All @@ -36,7 +37,6 @@
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Stream;
import org.jspecify.annotations.Nullable;

/** Base class for all agents. */
Expand All @@ -59,8 +59,7 @@ public abstract class BaseAgent {

private final List<? extends BaseAgent> subAgents;

private final Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback;
private final Optional<List<? extends AfterAgentCallback>> afterAgentCallback;
protected final CallbackPlugin callbackPlugin;

/**
* Creates a new BaseAgent.
Expand All @@ -79,12 +78,35 @@ public BaseAgent(
List<? extends BaseAgent> subAgents,
List<? extends BeforeAgentCallback> beforeAgentCallback,
List<? extends AfterAgentCallback> afterAgentCallback) {
this(
name,
description,
subAgents,
CallbackPlugin.builder()
.addBeforeAgentCallbacks(beforeAgentCallback)
.addAfterAgentCallbacks(afterAgentCallback)
.build());
}

/**
* Creates a new BaseAgent.
*
* @param name Unique agent name. Cannot be "user" (reserved).
* @param description Agent purpose.
* @param subAgents Agents managed by this agent.
* @param callbackPlugin The callback plugin for this agent.
*/
protected BaseAgent(
String name,
String description,
List<? extends BaseAgent> subAgents,
CallbackPlugin callbackPlugin) {
this.name = name;
this.description = description;
this.parentAgent = null;
this.subAgents = subAgents != null ? subAgents : ImmutableList.of();
this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback);
this.afterAgentCallback = Optional.ofNullable(afterAgentCallback);
this.callbackPlugin =
callbackPlugin == null ? CallbackPlugin.builder().build() : callbackPlugin;

// Establish parent relationships for all sub-agents if needed.
for (BaseAgent subAgent : this.subAgents) {
Expand Down Expand Up @@ -172,11 +194,15 @@ public List<? extends BaseAgent> subAgents() {
}

public Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback() {
return beforeAgentCallback;
return Optional.of(callbackPlugin.getBeforeAgentCallback());
}

public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
return afterAgentCallback;
return Optional.of(callbackPlugin.getAfterAgentCallback());
}

public Plugin getPlugin() {
return callbackPlugin;
}

/**
Expand Down Expand Up @@ -221,8 +247,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
() ->
callCallback(
beforeCallbacksToFunctions(
invocationContext.pluginManager(),
beforeAgentCallback.orElse(ImmutableList.of())),
invocationContext.pluginManager(), callbackPlugin),
invocationContext)
.flatMapPublisher(
beforeEventOpt -> {
Expand All @@ -239,7 +264,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
callCallback(
afterCallbacksToFunctions(
invocationContext.pluginManager(),
afterAgentCallback.orElse(ImmutableList.of())),
callbackPlugin),
invocationContext)
.flatMapPublisher(Flowable::fromOptional));

Expand All @@ -251,30 +276,27 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
/**
* Converts before-agent callbacks to functions.
*
* @param callbacks Before-agent callbacks.
* @return callback functions.
*/
private ImmutableList<Function<CallbackContext, Maybe<Content>>> beforeCallbacksToFunctions(
Plugin pluginManager, List<? extends BeforeAgentCallback> callbacks) {
return Stream.concat(
Stream.of(ctx -> pluginManager.beforeAgentCallback(this, ctx)),
callbacks.stream()
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call))
Plugin... plugins) {
return stream(plugins)
.map(
p ->
(Function<CallbackContext, Maybe<Content>>) ctx -> p.beforeAgentCallback(this, ctx))
.collect(toImmutableList());
}

/**
* Converts after-agent callbacks to functions.
*
* @param callbacks After-agent callbacks.
* @return callback functions.
*/
private ImmutableList<Function<CallbackContext, Maybe<Content>>> afterCallbacksToFunctions(
Plugin pluginManager, List<? extends AfterAgentCallback> callbacks) {
return Stream.concat(
Stream.of(ctx -> pluginManager.afterAgentCallback(this, ctx)),
callbacks.stream()
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call))
Plugin... plugins) {
return stream(plugins)
.map(
p -> (Function<CallbackContext, Maybe<Content>>) ctx -> p.afterAgentCallback(this, ctx))
.collect(toImmutableList());
}

Expand Down Expand Up @@ -399,8 +421,11 @@ public abstract static class Builder<B extends Builder<B>> {
protected String name;
protected String description;
protected ImmutableList<BaseAgent> subAgents;
protected ImmutableList<BeforeAgentCallback> beforeAgentCallback;
protected ImmutableList<AfterAgentCallback> afterAgentCallback;
protected final CallbackPlugin.Builder callbackPluginBuilder = CallbackPlugin.builder();

protected CallbackPlugin.Builder callbackPluginBuilder() {
return callbackPluginBuilder;
}

/** This is a safe cast to the concrete builder type. */
@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -434,25 +459,25 @@ public B subAgents(BaseAgent... subAgents) {

@CanIgnoreReturnValue
public B beforeAgentCallback(BeforeAgentCallback beforeAgentCallback) {
this.beforeAgentCallback = ImmutableList.of(beforeAgentCallback);
callbackPluginBuilder.addBeforeAgentCallback(beforeAgentCallback);
return self();
}

@CanIgnoreReturnValue
public B beforeAgentCallback(List<Callbacks.BeforeAgentCallbackBase> beforeAgentCallback) {
this.beforeAgentCallback = CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback);
callbackPluginBuilder.addBeforeAgentCallbacks(beforeAgentCallback);
return self();
}

@CanIgnoreReturnValue
public B afterAgentCallback(AfterAgentCallback afterAgentCallback) {
this.afterAgentCallback = ImmutableList.of(afterAgentCallback);
callbackPluginBuilder.addAfterAgentCallback(afterAgentCallback);
return self();
}

@CanIgnoreReturnValue
public B afterAgentCallback(List<Callbacks.AfterAgentCallbackBase> afterAgentCallback) {
this.afterAgentCallback = CallbackUtil.getAfterAgentCallbacks(afterAgentCallback);
callbackPluginBuilder.addAfterAgentCallbacks(afterAgentCallback);
return self();
}

Expand Down
Loading