Skip to content

Commit acb3da4

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Introducing a CallbackPlugin to wrap the old style Callbacks
The goal is to unify the processing of Plugins and Callbacks. We should consider depercating and removing the old Callbacks. There are a bunch of cyclical dependencies caused by requests back to the agent to get specific Callbacks. The next step will be to augmet the InvocationContext's PluginManager with the appropriate agent specific callbacks PiperOrigin-RevId: 853407402
1 parent 864d606 commit acb3da4

File tree

10 files changed

+1192
-237
lines changed

10 files changed

+1192
-237
lines changed

core/src/main/java/com/google/adk/agents/BaseAgent.java

Lines changed: 64 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.google.adk.agents;
1818

1919
import static com.google.common.collect.ImmutableList.toImmutableList;
20+
import static java.util.Arrays.stream;
2021

2122
import com.google.adk.Telemetry;
2223
import com.google.adk.agents.Callbacks.AfterAgentCallback;
@@ -59,8 +60,7 @@ public abstract class BaseAgent {
5960

6061
private final List<? extends BaseAgent> subAgents;
6162

62-
private final Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback;
63-
private final Optional<List<? extends AfterAgentCallback>> afterAgentCallback;
63+
protected final CallbackPlugin callbackPlugin;
6464

6565
/**
6666
* Creates a new BaseAgent.
@@ -77,21 +77,53 @@ public BaseAgent(
7777
String name,
7878
String description,
7979
List<? extends BaseAgent> subAgents,
80-
List<? extends BeforeAgentCallback> beforeAgentCallback,
81-
List<? extends AfterAgentCallback> afterAgentCallback) {
80+
@Nullable List<? extends BeforeAgentCallback> beforeAgentCallback,
81+
@Nullable List<? extends AfterAgentCallback> afterAgentCallback) {
82+
this(
83+
name,
84+
description,
85+
subAgents,
86+
createCallbackPlugin(beforeAgentCallback, afterAgentCallback));
87+
}
88+
89+
/**
90+
* Creates a new BaseAgent.
91+
*
92+
* @param name Unique agent name. Cannot be "user" (reserved).
93+
* @param description Agent purpose.
94+
* @param subAgents Agents managed by this agent.
95+
* @param callbackPlugin The callback plugin for this agent.
96+
*/
97+
protected BaseAgent(
98+
String name,
99+
String description,
100+
List<? extends BaseAgent> subAgents,
101+
CallbackPlugin callbackPlugin) {
82102
this.name = name;
83103
this.description = description;
84104
this.parentAgent = null;
85105
this.subAgents = subAgents != null ? subAgents : ImmutableList.of();
86-
this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback);
87-
this.afterAgentCallback = Optional.ofNullable(afterAgentCallback);
106+
this.callbackPlugin =
107+
callbackPlugin == null ? CallbackPlugin.builder().build() : callbackPlugin;
88108

89109
// Establish parent relationships for all sub-agents if needed.
90110
for (BaseAgent subAgent : this.subAgents) {
91111
subAgent.parentAgent(this);
92112
}
93113
}
94114

115+
/** Creates a {@link CallbackPlugin} from lists of before and after agent callbacks. */
116+
private static CallbackPlugin createCallbackPlugin(
117+
@Nullable List<? extends BeforeAgentCallback> beforeAgentCallbacks,
118+
@Nullable List<? extends AfterAgentCallback> afterAgentCallbacks) {
119+
CallbackPlugin.Builder builder = CallbackPlugin.builder();
120+
Stream.ofNullable(beforeAgentCallbacks).flatMap(List::stream).forEach(builder::addCallback);
121+
Optional.ofNullable(afterAgentCallbacks).stream()
122+
.flatMap(List::stream)
123+
.forEach(builder::addCallback);
124+
return builder.build();
125+
}
126+
95127
/**
96128
* Gets the agent's unique name.
97129
*
@@ -172,11 +204,15 @@ public List<? extends BaseAgent> subAgents() {
172204
}
173205

174206
public Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback() {
175-
return beforeAgentCallback;
207+
return Optional.of(callbackPlugin.getBeforeAgentCallback());
176208
}
177209

178210
public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
179-
return afterAgentCallback;
211+
return Optional.of(callbackPlugin.getAfterAgentCallback());
212+
}
213+
214+
public Plugin getPlugin() {
215+
return callbackPlugin;
180216
}
181217

182218
/**
@@ -221,8 +257,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
221257
() ->
222258
callCallback(
223259
beforeCallbacksToFunctions(
224-
invocationContext.pluginManager(),
225-
beforeAgentCallback.orElse(ImmutableList.of())),
260+
invocationContext.pluginManager(), callbackPlugin),
226261
invocationContext)
227262
.flatMapPublisher(
228263
beforeEventOpt -> {
@@ -239,7 +274,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
239274
callCallback(
240275
afterCallbacksToFunctions(
241276
invocationContext.pluginManager(),
242-
afterAgentCallback.orElse(ImmutableList.of())),
277+
callbackPlugin),
243278
invocationContext)
244279
.flatMapPublisher(Flowable::fromOptional));
245280

@@ -251,30 +286,27 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
251286
/**
252287
* Converts before-agent callbacks to functions.
253288
*
254-
* @param callbacks Before-agent callbacks.
255289
* @return callback functions.
256290
*/
257291
private ImmutableList<Function<CallbackContext, Maybe<Content>>> beforeCallbacksToFunctions(
258-
Plugin pluginManager, List<? extends BeforeAgentCallback> callbacks) {
259-
return Stream.concat(
260-
Stream.of(ctx -> pluginManager.beforeAgentCallback(this, ctx)),
261-
callbacks.stream()
262-
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call))
292+
Plugin... plugins) {
293+
return stream(plugins)
294+
.map(
295+
p ->
296+
(Function<CallbackContext, Maybe<Content>>) ctx -> p.beforeAgentCallback(this, ctx))
263297
.collect(toImmutableList());
264298
}
265299

266300
/**
267301
* Converts after-agent callbacks to functions.
268302
*
269-
* @param callbacks After-agent callbacks.
270303
* @return callback functions.
271304
*/
272305
private ImmutableList<Function<CallbackContext, Maybe<Content>>> afterCallbacksToFunctions(
273-
Plugin pluginManager, List<? extends AfterAgentCallback> callbacks) {
274-
return Stream.concat(
275-
Stream.of(ctx -> pluginManager.afterAgentCallback(this, ctx)),
276-
callbacks.stream()
277-
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call))
306+
Plugin... plugins) {
307+
return stream(plugins)
308+
.map(
309+
p -> (Function<CallbackContext, Maybe<Content>>) ctx -> p.afterAgentCallback(this, ctx))
278310
.collect(toImmutableList());
279311
}
280312

@@ -399,8 +431,11 @@ public abstract static class Builder<B extends Builder<B>> {
399431
protected String name;
400432
protected String description;
401433
protected ImmutableList<BaseAgent> subAgents;
402-
protected ImmutableList<BeforeAgentCallback> beforeAgentCallback;
403-
protected ImmutableList<AfterAgentCallback> afterAgentCallback;
434+
protected final CallbackPlugin.Builder callbackPluginBuilder = CallbackPlugin.builder();
435+
436+
protected CallbackPlugin.Builder callbackPluginBuilder() {
437+
return callbackPluginBuilder;
438+
}
404439

405440
/** This is a safe cast to the concrete builder type. */
406441
@SuppressWarnings("unchecked")
@@ -434,25 +469,25 @@ public B subAgents(BaseAgent... subAgents) {
434469

435470
@CanIgnoreReturnValue
436471
public B beforeAgentCallback(BeforeAgentCallback beforeAgentCallback) {
437-
this.beforeAgentCallback = ImmutableList.of(beforeAgentCallback);
472+
callbackPluginBuilder.addBeforeAgentCallback(beforeAgentCallback);
438473
return self();
439474
}
440475

441476
@CanIgnoreReturnValue
442477
public B beforeAgentCallback(List<Callbacks.BeforeAgentCallbackBase> beforeAgentCallback) {
443-
this.beforeAgentCallback = CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback);
478+
beforeAgentCallback.forEach(callbackPluginBuilder::addCallback);
444479
return self();
445480
}
446481

447482
@CanIgnoreReturnValue
448483
public B afterAgentCallback(AfterAgentCallback afterAgentCallback) {
449-
this.afterAgentCallback = ImmutableList.of(afterAgentCallback);
484+
callbackPluginBuilder.addAfterAgentCallback(afterAgentCallback);
450485
return self();
451486
}
452487

453488
@CanIgnoreReturnValue
454489
public B afterAgentCallback(List<Callbacks.AfterAgentCallbackBase> afterAgentCallback) {
455-
this.afterAgentCallback = CallbackUtil.getAfterAgentCallbacks(afterAgentCallback);
490+
afterAgentCallback.forEach(callbackPluginBuilder::addCallback);
456491
return self();
457492
}
458493

0 commit comments

Comments
 (0)