1717package com .google .adk .agents ;
1818
1919import static com .google .common .collect .ImmutableList .toImmutableList ;
20+ import static java .util .Arrays .stream ;
2021
2122import com .google .adk .Telemetry ;
2223import com .google .adk .agents .Callbacks .AfterAgentCallback ;
@@ -59,8 +60,7 @@ public abstract class BaseAgent {
5960
6061 private final List <? extends BaseAgent > subAgents ;
6162
62- private final Optional <List <? extends BeforeAgentCallback >> beforeAgentCallback ;
63- private final Optional <List <? extends AfterAgentCallback >> afterAgentCallback ;
63+ protected final CallbackPlugin callbackPlugin ;
6464
6565 /**
6666 * Creates a new BaseAgent.
@@ -77,21 +77,53 @@ public BaseAgent(
7777 String name ,
7878 String description ,
7979 List <? extends BaseAgent > subAgents ,
80- List <? extends BeforeAgentCallback > beforeAgentCallback ,
81- List <? extends AfterAgentCallback > afterAgentCallback ) {
80+ @ Nullable List <? extends BeforeAgentCallback > beforeAgentCallback ,
81+ @ Nullable List <? extends AfterAgentCallback > afterAgentCallback ) {
82+ this (
83+ name ,
84+ description ,
85+ subAgents ,
86+ createCallbackPlugin (beforeAgentCallback , afterAgentCallback ));
87+ }
88+
89+ /**
90+ * Creates a new BaseAgent.
91+ *
92+ * @param name Unique agent name. Cannot be "user" (reserved).
93+ * @param description Agent purpose.
94+ * @param subAgents Agents managed by this agent.
95+ * @param callbackPlugin The callback plugin for this agent.
96+ */
97+ protected BaseAgent (
98+ String name ,
99+ String description ,
100+ List <? extends BaseAgent > subAgents ,
101+ CallbackPlugin callbackPlugin ) {
82102 this .name = name ;
83103 this .description = description ;
84104 this .parentAgent = null ;
85105 this .subAgents = subAgents != null ? subAgents : ImmutableList .of ();
86- this .beforeAgentCallback = Optional . ofNullable ( beforeAgentCallback );
87- this . afterAgentCallback = Optional . ofNullable ( afterAgentCallback ) ;
106+ this .callbackPlugin =
107+ callbackPlugin == null ? CallbackPlugin . builder (). build () : callbackPlugin ;
88108
89109 // Establish parent relationships for all sub-agents if needed.
90110 for (BaseAgent subAgent : this .subAgents ) {
91111 subAgent .parentAgent (this );
92112 }
93113 }
94114
115+ /** Creates a {@link CallbackPlugin} from lists of before and after agent callbacks. */
116+ private static CallbackPlugin createCallbackPlugin (
117+ @ Nullable List <? extends BeforeAgentCallback > beforeAgentCallbacks ,
118+ @ Nullable List <? extends AfterAgentCallback > afterAgentCallbacks ) {
119+ CallbackPlugin .Builder builder = CallbackPlugin .builder ();
120+ Stream .ofNullable (beforeAgentCallbacks ).flatMap (List ::stream ).forEach (builder ::addCallback );
121+ Optional .ofNullable (afterAgentCallbacks ).stream ()
122+ .flatMap (List ::stream )
123+ .forEach (builder ::addCallback );
124+ return builder .build ();
125+ }
126+
95127 /**
96128 * Gets the agent's unique name.
97129 *
@@ -172,11 +204,15 @@ public List<? extends BaseAgent> subAgents() {
172204 }
173205
174206 public Optional <List <? extends BeforeAgentCallback >> beforeAgentCallback () {
175- return beforeAgentCallback ;
207+ return Optional . of ( callbackPlugin . getBeforeAgentCallback ()) ;
176208 }
177209
178210 public Optional <List <? extends AfterAgentCallback >> afterAgentCallback () {
179- return afterAgentCallback ;
211+ return Optional .of (callbackPlugin .getAfterAgentCallback ());
212+ }
213+
214+ public Plugin getPlugin () {
215+ return callbackPlugin ;
180216 }
181217
182218 /**
@@ -221,8 +257,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
221257 () ->
222258 callCallback (
223259 beforeCallbacksToFunctions (
224- invocationContext .pluginManager (),
225- beforeAgentCallback .orElse (ImmutableList .of ())),
260+ invocationContext .pluginManager (), callbackPlugin ),
226261 invocationContext )
227262 .flatMapPublisher (
228263 beforeEventOpt -> {
@@ -239,7 +274,7 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
239274 callCallback (
240275 afterCallbacksToFunctions (
241276 invocationContext .pluginManager (),
242- afterAgentCallback . orElse ( ImmutableList . of ()) ),
277+ callbackPlugin ),
243278 invocationContext )
244279 .flatMapPublisher (Flowable ::fromOptional ));
245280
@@ -251,30 +286,27 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
251286 /**
252287 * Converts before-agent callbacks to functions.
253288 *
254- * @param callbacks Before-agent callbacks.
255289 * @return callback functions.
256290 */
257291 private ImmutableList <Function <CallbackContext , Maybe <Content >>> beforeCallbacksToFunctions (
258- Plugin pluginManager , List <? extends BeforeAgentCallback > callbacks ) {
259- return Stream . concat (
260- Stream . of ( ctx -> pluginManager . beforeAgentCallback ( this , ctx )),
261- callbacks . stream ()
262- . map ( callback -> ( Function <CallbackContext , Maybe <Content >>) callback :: call ))
292+ Plugin ... plugins ) {
293+ return stream ( plugins )
294+ . map (
295+ p ->
296+ ( Function <CallbackContext , Maybe <Content >>) ctx -> p . beforeAgentCallback ( this , ctx ))
263297 .collect (toImmutableList ());
264298 }
265299
266300 /**
267301 * Converts after-agent callbacks to functions.
268302 *
269- * @param callbacks After-agent callbacks.
270303 * @return callback functions.
271304 */
272305 private ImmutableList <Function <CallbackContext , Maybe <Content >>> afterCallbacksToFunctions (
273- Plugin pluginManager , List <? extends AfterAgentCallback > callbacks ) {
274- return Stream .concat (
275- Stream .of (ctx -> pluginManager .afterAgentCallback (this , ctx )),
276- callbacks .stream ()
277- .map (callback -> (Function <CallbackContext , Maybe <Content >>) callback ::call ))
306+ Plugin ... plugins ) {
307+ return stream (plugins )
308+ .map (
309+ p -> (Function <CallbackContext , Maybe <Content >>) ctx -> p .afterAgentCallback (this , ctx ))
278310 .collect (toImmutableList ());
279311 }
280312
@@ -399,8 +431,11 @@ public abstract static class Builder<B extends Builder<B>> {
399431 protected String name ;
400432 protected String description ;
401433 protected ImmutableList <BaseAgent > subAgents ;
402- protected ImmutableList <BeforeAgentCallback > beforeAgentCallback ;
403- protected ImmutableList <AfterAgentCallback > afterAgentCallback ;
434+ protected final CallbackPlugin .Builder callbackPluginBuilder = CallbackPlugin .builder ();
435+
436+ protected CallbackPlugin .Builder callbackPluginBuilder () {
437+ return callbackPluginBuilder ;
438+ }
404439
405440 /** This is a safe cast to the concrete builder type. */
406441 @ SuppressWarnings ("unchecked" )
@@ -434,25 +469,25 @@ public B subAgents(BaseAgent... subAgents) {
434469
435470 @ CanIgnoreReturnValue
436471 public B beforeAgentCallback (BeforeAgentCallback beforeAgentCallback ) {
437- this . beforeAgentCallback = ImmutableList . of (beforeAgentCallback );
472+ callbackPluginBuilder . addBeforeAgentCallback (beforeAgentCallback );
438473 return self ();
439474 }
440475
441476 @ CanIgnoreReturnValue
442477 public B beforeAgentCallback (List <Callbacks .BeforeAgentCallbackBase > beforeAgentCallback ) {
443- this . beforeAgentCallback = CallbackUtil . getBeforeAgentCallbacks ( beforeAgentCallback );
478+ beforeAgentCallback . forEach ( callbackPluginBuilder :: addCallback );
444479 return self ();
445480 }
446481
447482 @ CanIgnoreReturnValue
448483 public B afterAgentCallback (AfterAgentCallback afterAgentCallback ) {
449- this . afterAgentCallback = ImmutableList . of (afterAgentCallback );
484+ callbackPluginBuilder . addAfterAgentCallback (afterAgentCallback );
450485 return self ();
451486 }
452487
453488 @ CanIgnoreReturnValue
454489 public B afterAgentCallback (List <Callbacks .AfterAgentCallbackBase > afterAgentCallback ) {
455- this . afterAgentCallback = CallbackUtil . getAfterAgentCallbacks ( afterAgentCallback );
490+ afterAgentCallback . forEach ( callbackPluginBuilder :: addCallback );
456491 return self ();
457492 }
458493
0 commit comments