Skip to content

Commit 11676d3

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Introducing a CallbackPlugin to wrap the old style Callbacks
The goal is to unify the processing of Plugins and Callbacks. We should consider depercating and removing the old Callbacks. There are a bunch of cyclical dependencies caused by requests back to the agent to get specific Callbacks. The next step will be to augmet the InvocationContext's PluginManager with the appropriate agent specific callbacks PiperOrigin-RevId: 853407402
1 parent dace210 commit 11676d3

File tree

10 files changed

+1306
-234
lines changed

10 files changed

+1306
-234
lines changed

core/src/main/java/com/google/adk/agents/BaseAgent.java

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.google.adk.agents;
1818

1919
import static com.google.common.collect.ImmutableList.toImmutableList;
20+
import static java.util.Arrays.stream;
2021

2122
import com.google.adk.Telemetry;
2223
import com.google.adk.agents.Callbacks.AfterAgentCallback;
@@ -36,7 +37,6 @@
3637
import java.util.List;
3738
import java.util.Optional;
3839
import java.util.function.Function;
39-
import java.util.stream.Stream;
4040
import org.jspecify.annotations.Nullable;
4141

4242
/** Base class for all agents. */
@@ -59,8 +59,7 @@ public abstract class BaseAgent {
5959

6060
private final List<? extends BaseAgent> subAgents;
6161

62-
private final Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback;
63-
private final Optional<List<? extends AfterAgentCallback>> afterAgentCallback;
62+
protected final CallbackPlugin callbackPlugin;
6463

6564
/**
6665
* Creates a new BaseAgent.
@@ -79,12 +78,35 @@ public BaseAgent(
7978
List<? extends BaseAgent> subAgents,
8079
List<? extends BeforeAgentCallback> beforeAgentCallback,
8180
List<? extends AfterAgentCallback> afterAgentCallback) {
81+
this(
82+
name,
83+
description,
84+
subAgents,
85+
CallbackPlugin.builder()
86+
.addBeforeAgentCallbacks(beforeAgentCallback)
87+
.addAfterAgentCallbacks(afterAgentCallback)
88+
.build());
89+
}
90+
91+
/**
92+
* Creates a new BaseAgent.
93+
*
94+
* @param name Unique agent name. Cannot be "user" (reserved).
95+
* @param description Agent purpose.
96+
* @param subAgents Agents managed by this agent.
97+
* @param callbackPlugin The callback plugin for this agent.
98+
*/
99+
protected BaseAgent(
100+
String name,
101+
String description,
102+
List<? extends BaseAgent> subAgents,
103+
CallbackPlugin callbackPlugin) {
82104
this.name = name;
83105
this.description = description;
84106
this.parentAgent = null;
85107
this.subAgents = subAgents != null ? subAgents : ImmutableList.of();
86-
this.beforeAgentCallback = Optional.ofNullable(beforeAgentCallback);
87-
this.afterAgentCallback = Optional.ofNullable(afterAgentCallback);
108+
this.callbackPlugin =
109+
callbackPlugin == null ? CallbackPlugin.builder().build() : callbackPlugin;
88110

89111
// Establish parent relationships for all sub-agents if needed.
90112
for (BaseAgent subAgent : this.subAgents) {
@@ -172,11 +194,15 @@ public List<? extends BaseAgent> subAgents() {
172194
}
173195

174196
public Optional<List<? extends BeforeAgentCallback>> beforeAgentCallback() {
175-
return beforeAgentCallback;
197+
return Optional.of(callbackPlugin.getBeforeAgentCallback());
176198
}
177199

178200
public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
179-
return afterAgentCallback;
201+
return Optional.of(callbackPlugin.getAfterAgentCallback());
202+
}
203+
204+
public Plugin getPlugin() {
205+
return callbackPlugin;
180206
}
181207

182208
/**
@@ -221,8 +247,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
221247
() ->
222248
callCallback(
223249
beforeCallbacksToFunctions(
224-
invocationContext.pluginManager(),
225-
beforeAgentCallback.orElse(ImmutableList.of())),
250+
invocationContext.pluginManager(), callbackPlugin),
226251
invocationContext)
227252
.flatMapPublisher(
228253
beforeEventOpt -> {
@@ -239,7 +264,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
239264
callCallback(
240265
afterCallbacksToFunctions(
241266
invocationContext.pluginManager(),
242-
afterAgentCallback.orElse(ImmutableList.of())),
267+
callbackPlugin),
243268
invocationContext)
244269
.flatMapPublisher(Flowable::fromOptional));
245270

@@ -251,30 +276,27 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
251276
/**
252277
* Converts before-agent callbacks to functions.
253278
*
254-
* @param callbacks Before-agent callbacks.
255279
* @return callback functions.
256280
*/
257281
private ImmutableList<Function<CallbackContext, Maybe<Content>>> beforeCallbacksToFunctions(
258-
Plugin pluginManager, List<? extends BeforeAgentCallback> callbacks) {
259-
return Stream.concat(
260-
Stream.of(ctx -> pluginManager.beforeAgentCallback(this, ctx)),
261-
callbacks.stream()
262-
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call))
282+
Plugin... plugins) {
283+
return stream(plugins)
284+
.map(
285+
p ->
286+
(Function<CallbackContext, Maybe<Content>>) ctx -> p.beforeAgentCallback(this, ctx))
263287
.collect(toImmutableList());
264288
}
265289

266290
/**
267291
* Converts after-agent callbacks to functions.
268292
*
269-
* @param callbacks After-agent callbacks.
270293
* @return callback functions.
271294
*/
272295
private ImmutableList<Function<CallbackContext, Maybe<Content>>> afterCallbacksToFunctions(
273-
Plugin pluginManager, List<? extends AfterAgentCallback> callbacks) {
274-
return Stream.concat(
275-
Stream.of(ctx -> pluginManager.afterAgentCallback(this, ctx)),
276-
callbacks.stream()
277-
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call))
296+
Plugin... plugins) {
297+
return stream(plugins)
298+
.map(
299+
p -> (Function<CallbackContext, Maybe<Content>>) ctx -> p.afterAgentCallback(this, ctx))
278300
.collect(toImmutableList());
279301
}
280302

@@ -399,8 +421,11 @@ public abstract static class Builder<B extends Builder<B>> {
399421
protected String name;
400422
protected String description;
401423
protected ImmutableList<BaseAgent> subAgents;
402-
protected ImmutableList<BeforeAgentCallback> beforeAgentCallback;
403-
protected ImmutableList<AfterAgentCallback> afterAgentCallback;
424+
protected final CallbackPlugin.Builder callbackPluginBuilder = CallbackPlugin.builder();
425+
426+
protected CallbackPlugin.Builder callbackPluginBuilder() {
427+
return callbackPluginBuilder;
428+
}
404429

405430
/** This is a safe cast to the concrete builder type. */
406431
@SuppressWarnings("unchecked")
@@ -434,25 +459,25 @@ public B subAgents(BaseAgent... subAgents) {
434459

435460
@CanIgnoreReturnValue
436461
public B beforeAgentCallback(BeforeAgentCallback beforeAgentCallback) {
437-
this.beforeAgentCallback = ImmutableList.of(beforeAgentCallback);
462+
callbackPluginBuilder.addBeforeAgentCallback(beforeAgentCallback);
438463
return self();
439464
}
440465

441466
@CanIgnoreReturnValue
442467
public B beforeAgentCallback(List<Callbacks.BeforeAgentCallbackBase> beforeAgentCallback) {
443-
this.beforeAgentCallback = CallbackUtil.getBeforeAgentCallbacks(beforeAgentCallback);
468+
callbackPluginBuilder.addBeforeAgentCallbacks(beforeAgentCallback);
444469
return self();
445470
}
446471

447472
@CanIgnoreReturnValue
448473
public B afterAgentCallback(AfterAgentCallback afterAgentCallback) {
449-
this.afterAgentCallback = ImmutableList.of(afterAgentCallback);
474+
callbackPluginBuilder.addAfterAgentCallback(afterAgentCallback);
450475
return self();
451476
}
452477

453478
@CanIgnoreReturnValue
454479
public B afterAgentCallback(List<Callbacks.AfterAgentCallbackBase> afterAgentCallback) {
455-
this.afterAgentCallback = CallbackUtil.getAfterAgentCallbacks(afterAgentCallback);
480+
callbackPluginBuilder.addAfterAgentCallbacks(afterAgentCallback);
456481
return self();
457482
}
458483

0 commit comments

Comments
 (0)