1717package com .google .adk .agents ;
1818
1919import static com .google .common .collect .ImmutableList .toImmutableList ;
20+ import static java .util .Arrays .stream ;
2021
2122import com .google .adk .Telemetry ;
2223import com .google .adk .agents .Callbacks .AfterAgentCallback ;
3637import java .util .List ;
3738import java .util .Optional ;
3839import java .util .function .Function ;
39- import java .util .stream .Stream ;
4040import org .jspecify .annotations .Nullable ;
4141
4242/** Base class for all agents. */
@@ -59,8 +59,7 @@ public abstract class BaseAgent {
5959
6060 private final List <? extends BaseAgent > subAgents ;
6161
62- private final Optional <List <? extends BeforeAgentCallback >> beforeAgentCallback ;
63- private final Optional <List <? extends AfterAgentCallback >> afterAgentCallback ;
62+ protected final CallbackPlugin callbackPlugin ;
6463
6564 /**
6665 * Creates a new BaseAgent.
@@ -79,12 +78,35 @@ public BaseAgent(
7978 List <? extends BaseAgent > subAgents ,
8079 List <? extends BeforeAgentCallback > beforeAgentCallback ,
8180 List <? extends AfterAgentCallback > afterAgentCallback ) {
81+ this (
82+ name ,
83+ description ,
84+ subAgents ,
85+ CallbackPlugin .builder ()
86+ .addBeforeAgentCallbacks (beforeAgentCallback )
87+ .addAfterAgentCallbacks (afterAgentCallback )
88+ .build ());
89+ }
90+
91+ /**
92+ * Creates a new BaseAgent.
93+ *
94+ * @param name Unique agent name. Cannot be "user" (reserved).
95+ * @param description Agent purpose.
96+ * @param subAgents Agents managed by this agent.
97+ * @param callbackPlugin The callback plugin for this agent.
98+ */
99+ protected BaseAgent (
100+ String name ,
101+ String description ,
102+ List <? extends BaseAgent > subAgents ,
103+ CallbackPlugin callbackPlugin ) {
82104 this .name = name ;
83105 this .description = description ;
84106 this .parentAgent = null ;
85107 this .subAgents = subAgents != null ? subAgents : ImmutableList .of ();
86- this .beforeAgentCallback = Optional . ofNullable ( beforeAgentCallback );
87- this . afterAgentCallback = Optional . ofNullable ( afterAgentCallback ) ;
108+ this .callbackPlugin =
109+ callbackPlugin == null ? CallbackPlugin . builder (). build () : callbackPlugin ;
88110
89111 // Establish parent relationships for all sub-agents if needed.
90112 for (BaseAgent subAgent : this .subAgents ) {
@@ -172,11 +194,15 @@ public List<? extends BaseAgent> subAgents() {
172194 }
173195
174196 public Optional <List <? extends BeforeAgentCallback >> beforeAgentCallback () {
175- return beforeAgentCallback ;
197+ return Optional . of ( callbackPlugin . getBeforeAgentCallback ()) ;
176198 }
177199
178200 public Optional <List <? extends AfterAgentCallback >> afterAgentCallback () {
179- return afterAgentCallback ;
201+ return Optional .of (callbackPlugin .getAfterAgentCallback ());
202+ }
203+
204+ public Plugin getPlugin () {
205+ return callbackPlugin ;
180206 }
181207
182208 /**
@@ -221,8 +247,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
221247 () ->
222248 callCallback (
223249 beforeCallbacksToFunctions (
224- invocationContext .pluginManager (),
225- beforeAgentCallback .orElse (ImmutableList .of ())),
250+ invocationContext .pluginManager (), callbackPlugin ),
226251 invocationContext )
227252 .flatMapPublisher (
228253 beforeEventOpt -> {
@@ -239,7 +264,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
239264 callCallback (
240265 afterCallbacksToFunctions (
241266 invocationContext .pluginManager (),
242- afterAgentCallback . orElse ( ImmutableList . of ()) ),
267+ callbackPlugin ),
243268 invocationContext )
244269 .flatMapPublisher (Flowable ::fromOptional ));
245270
@@ -251,30 +276,27 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
251276 /**
252277 * Converts before-agent callbacks to functions.
253278 *
254- * @param callbacks Before-agent callbacks.
255279 * @return callback functions.
256280 */
257281 private ImmutableList <Function <CallbackContext , Maybe <Content >>> beforeCallbacksToFunctions (
258- Plugin pluginManager , List <? extends BeforeAgentCallback > callbacks ) {
259- return Stream . concat (
260- Stream . of ( ctx -> pluginManager . beforeAgentCallback ( this , ctx )),
261- callbacks . stream ()
262- . map ( callback -> ( Function <CallbackContext , Maybe <Content >>) callback :: call ))
282+ Plugin ... plugins ) {
283+ return stream ( plugins )
284+ . map (
285+ p ->
286+ ( Function <CallbackContext , Maybe <Content >>) ctx -> p . beforeAgentCallback ( this , ctx ))
263287 .collect (toImmutableList ());
264288 }
265289
266290 /**
267291 * Converts after-agent callbacks to functions.
268292 *
269- * @param callbacks After-agent callbacks.
270293 * @return callback functions.
271294 */
272295 private ImmutableList <Function <CallbackContext , Maybe <Content >>> afterCallbacksToFunctions (
273- Plugin pluginManager , List <? extends AfterAgentCallback > callbacks ) {
274- return Stream .concat (
275- Stream .of (ctx -> pluginManager .afterAgentCallback (this , ctx )),
276- callbacks .stream ()
277- .map (callback -> (Function <CallbackContext , Maybe <Content >>) callback ::call ))
296+ Plugin ... plugins ) {
297+ return stream (plugins )
298+ .map (
299+ p -> (Function <CallbackContext , Maybe <Content >>) ctx -> p .afterAgentCallback (this , ctx ))
278300 .collect (toImmutableList ());
279301 }
280302
@@ -399,8 +421,11 @@ public abstract static class Builder<B extends Builder<B>> {
399421 protected String name ;
400422 protected String description ;
401423 protected ImmutableList <BaseAgent > subAgents ;
402- protected ImmutableList <BeforeAgentCallback > beforeAgentCallback ;
403- protected ImmutableList <AfterAgentCallback > afterAgentCallback ;
424+ protected final CallbackPlugin .Builder callbackPluginBuilder = CallbackPlugin .builder ();
425+
426+ protected CallbackPlugin .Builder callbackPluginBuilder () {
427+ return callbackPluginBuilder ;
428+ }
404429
405430 /** This is a safe cast to the concrete builder type. */
406431 @ SuppressWarnings ("unchecked" )
@@ -434,25 +459,25 @@ public B subAgents(BaseAgent... subAgents) {
434459
435460 @ CanIgnoreReturnValue
436461 public B beforeAgentCallback (BeforeAgentCallback beforeAgentCallback ) {
437- this . beforeAgentCallback = ImmutableList . of (beforeAgentCallback );
462+ callbackPluginBuilder . addBeforeAgentCallback (beforeAgentCallback );
438463 return self ();
439464 }
440465
441466 @ CanIgnoreReturnValue
442467 public B beforeAgentCallback (List <Callbacks .BeforeAgentCallbackBase > beforeAgentCallback ) {
443- this . beforeAgentCallback = CallbackUtil . getBeforeAgentCallbacks (beforeAgentCallback );
468+ callbackPluginBuilder . addBeforeAgentCallbacks (beforeAgentCallback );
444469 return self ();
445470 }
446471
447472 @ CanIgnoreReturnValue
448473 public B afterAgentCallback (AfterAgentCallback afterAgentCallback ) {
449- this . afterAgentCallback = ImmutableList . of (afterAgentCallback );
474+ callbackPluginBuilder . addAfterAgentCallback (afterAgentCallback );
450475 return self ();
451476 }
452477
453478 @ CanIgnoreReturnValue
454479 public B afterAgentCallback (List <Callbacks .AfterAgentCallbackBase > afterAgentCallback ) {
455- this . afterAgentCallback = CallbackUtil . getAfterAgentCallbacks (afterAgentCallback );
480+ callbackPluginBuilder . addAfterAgentCallbacks (afterAgentCallback );
456481 return self ();
457482 }
458483
0 commit comments