Skip to content

Commit 7f2484d

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Using a new InvocationContext.combinedPlugin() to simplify code
PiperOrigin-RevId: 855427656
1 parent 229654e commit 7f2484d

File tree

6 files changed

+184
-174
lines changed

6 files changed

+184
-174
lines changed

core/src/main/java/com/google/adk/agents/BaseAgent.java

Lines changed: 34 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616

1717
package com.google.adk.agents;
1818

19-
import static com.google.common.collect.ImmutableList.toImmutableList;
20-
import static java.util.Arrays.stream;
21-
2219
import com.google.adk.Telemetry;
2320
import com.google.adk.agents.Callbacks.AfterAgentCallback;
2421
import com.google.adk.agents.Callbacks.BeforeAgentCallback;
@@ -255,10 +252,11 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
255252
spanContext,
256253
span,
257254
() ->
258-
callCallback(
259-
beforeCallbacksToFunctions(
260-
invocationContext.pluginManager(), callbackPlugin),
255+
processAgentCallbackResult(
256+
ctx -> invocationContext.combinedPlugin().beforeAgentCallback(this, ctx),
261257
invocationContext)
258+
.map(Optional::of)
259+
.switchIfEmpty(Single.just(Optional.empty()))
262260
.flatMapPublisher(
263261
beforeEventOpt -> {
264262
if (invocationContext.endInvocation()) {
@@ -271,11 +269,14 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
271269
Flowable<Event> afterEvents =
272270
Flowable.defer(
273271
() ->
274-
callCallback(
275-
afterCallbacksToFunctions(
276-
invocationContext.pluginManager(),
277-
callbackPlugin),
272+
processAgentCallbackResult(
273+
ctx ->
274+
invocationContext
275+
.combinedPlugin()
276+
.afterAgentCallback(this, ctx),
278277
invocationContext)
278+
.map(Optional::of)
279+
.switchIfEmpty(Single.just(Optional.empty()))
279280
.flatMapPublisher(Flowable::fromOptional));
280281

281282
return Flowable.concat(beforeEvents, mainEvents, afterEvents);
@@ -284,73 +285,32 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
284285
}
285286

286287
/**
287-
* Converts before-agent callbacks to functions.
288+
* Processes the result of an agent callback, creating an {@link Event} if necessary.
288289
*
289-
* @return callback functions.
290+
* @param agentCallback The callback function.
291+
* @param invocationContext The current invocation context.
292+
* @return A {@link Maybe} emitting an {@link Event} if one is produced, or empty otherwise.
290293
*/
291-
private ImmutableList<Function<CallbackContext, Maybe<Content>>> beforeCallbacksToFunctions(
292-
Plugin... plugins) {
293-
return stream(plugins)
294-
.map(
295-
p ->
296-
(Function<CallbackContext, Maybe<Content>>) ctx -> p.beforeAgentCallback(this, ctx))
297-
.collect(toImmutableList());
298-
}
299-
300-
/**
301-
* Converts after-agent callbacks to functions.
302-
*
303-
* @return callback functions.
304-
*/
305-
private ImmutableList<Function<CallbackContext, Maybe<Content>>> afterCallbacksToFunctions(
306-
Plugin... plugins) {
307-
return stream(plugins)
308-
.map(
309-
p -> (Function<CallbackContext, Maybe<Content>>) ctx -> p.afterAgentCallback(this, ctx))
310-
.collect(toImmutableList());
311-
}
312-
313-
/**
314-
* Calls agent callbacks and returns the first produced event, if any.
315-
*
316-
* @param agentCallbacks Callback functions.
317-
* @param invocationContext Current invocation context.
318-
* @return single emitting first event, or empty if none.
319-
*/
320-
private Single<Optional<Event>> callCallback(
321-
List<Function<CallbackContext, Maybe<Content>>> agentCallbacks,
294+
private Maybe<Event> processAgentCallbackResult(
295+
Function<CallbackContext, Maybe<Content>> agentCallback,
322296
InvocationContext invocationContext) {
323-
if (agentCallbacks == null || agentCallbacks.isEmpty()) {
324-
return Single.just(Optional.empty());
325-
}
326-
327-
CallbackContext callbackContext =
328-
new CallbackContext(invocationContext, /* eventActions= */ null);
329-
330-
return Flowable.fromIterable(agentCallbacks)
331-
.concatMap(
332-
callback -> {
333-
Maybe<Content> maybeContent = callback.apply(callbackContext);
334-
335-
return maybeContent
336-
.map(
337-
content -> {
338-
invocationContext.setEndInvocation(true);
339-
return Optional.of(
340-
Event.builder()
341-
.id(Event.generateEventId())
342-
.invocationId(invocationContext.invocationId())
343-
.author(name())
344-
.branch(invocationContext.branch())
345-
.actions(callbackContext.eventActions())
346-
.content(content)
347-
.build());
348-
})
349-
.toFlowable();
297+
var callbackContext = new CallbackContext(invocationContext, /* eventActions= */ null);
298+
return agentCallback
299+
.apply(callbackContext)
300+
.map(
301+
content -> {
302+
invocationContext.setEndInvocation(true);
303+
return Event.builder()
304+
.id(Event.generateEventId())
305+
.invocationId(invocationContext.invocationId())
306+
.author(name())
307+
.branch(invocationContext.branch())
308+
.actions(callbackContext.eventActions())
309+
.content(content)
310+
.build();
350311
})
351-
.firstElement()
352312
.switchIfEmpty(
353-
Single.defer(
313+
Maybe.defer(
354314
() -> {
355315
if (callbackContext.state().hasDelta()) {
356316
Event.Builder eventBuilder =
@@ -361,9 +321,9 @@ private Single<Optional<Event>> callCallback(
361321
.branch(invocationContext.branch())
362322
.actions(callbackContext.eventActions());
363323

364-
return Single.just(Optional.of(eventBuilder.build()));
324+
return Maybe.just(eventBuilder.build());
365325
} else {
366-
return Single.just(Optional.empty());
326+
return Maybe.empty();
367327
}
368328
}));
369329
}

core/src/main/java/com/google/adk/agents/InvocationContext.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import com.google.adk.plugins.PluginManager;
2626
import com.google.adk.sessions.BaseSessionService;
2727
import com.google.adk.sessions.Session;
28+
import com.google.common.collect.ImmutableList;
2829
import com.google.common.collect.ImmutableSet;
2930
import com.google.errorprone.annotations.CanIgnoreReturnValue;
3031
import com.google.errorprone.annotations.InlineMe;
@@ -44,6 +45,7 @@ public class InvocationContext {
4445
private final BaseArtifactService artifactService;
4546
private final BaseMemoryService memoryService;
4647
private final Plugin pluginManager;
48+
private final Plugin combinedPlugin;
4749
private final Optional<LiveRequestQueue> liveRequestQueue;
4850
private final Map<String, ActiveStreamingTool> activeStreamingTools;
4951
private final String invocationId;
@@ -73,6 +75,13 @@ protected InvocationContext(Builder builder) {
7375
this.endInvocation = builder.endInvocation;
7476
this.resumabilityConfig = builder.resumabilityConfig;
7577
this.invocationCostManager = builder.invocationCostManager;
78+
this.combinedPlugin =
79+
Optional.ofNullable(builder.agent)
80+
.map(BaseAgent::getPlugin)
81+
.map(
82+
agentPlugin ->
83+
(Plugin) new PluginManager(ImmutableList.of(pluginManager, agentPlugin)))
84+
.orElse(pluginManager);
7685
}
7786

7887
/**
@@ -235,6 +244,14 @@ public Plugin pluginManager() {
235244
return pluginManager;
236245
}
237246

247+
/**
248+
* Returns a {@link Plugin} that combines agent-specific plugins with framework-level plugins,
249+
* allowing tools from both to be invoked.
250+
*/
251+
public Plugin combinedPlugin() {
252+
return combinedPlugin;
253+
}
254+
238255
/** Returns a map of tool call IDs to active streaming tools for the current invocation. */
239256
public Map<String, ActiveStreamingTool> activeStreamingTools() {
240257
return activeStreamingTools;

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 8 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ private Flowable<LlmResponse> callLlm(
199199
.onErrorResumeNext(
200200
exception ->
201201
context
202-
.pluginManager()
202+
.combinedPlugin()
203203
.onModelErrorCallback(
204204
new CallbackContext(
205205
context, eventForCallbackUsage.actions()),
@@ -243,27 +243,9 @@ private Single<Optional<LlmResponse>> handleBeforeModelCallback(
243243
Event callbackEvent = modelResponseEvent.toBuilder().build();
244244
CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions());
245245

246-
Maybe<LlmResponse> pluginResult =
247-
context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder);
248-
249-
LlmAgent agent = (LlmAgent) context.agent();
250-
251-
Optional<List<? extends BeforeModelCallback>> callbacksOpt = agent.beforeModelCallback();
252-
if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
253-
return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty());
254-
}
255-
256-
List<? extends BeforeModelCallback> callbacks = callbacksOpt.get();
257-
258-
Maybe<LlmResponse> callbackResult =
259-
Maybe.defer(
260-
() ->
261-
Flowable.fromIterable(callbacks)
262-
.concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder))
263-
.firstElement());
264-
265-
return pluginResult
266-
.switchIfEmpty(callbackResult)
246+
return context
247+
.combinedPlugin()
248+
.beforeModelCallback(callbackContext, llmRequestBuilder)
267249
.map(Optional::of)
268250
.defaultIfEmpty(Optional.empty());
269251
}
@@ -279,24 +261,10 @@ private Single<LlmResponse> handleAfterModelCallback(
279261
Event callbackEvent = modelResponseEvent.toBuilder().build();
280262
CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions());
281263

282-
Maybe<LlmResponse> pluginResult =
283-
context.pluginManager().afterModelCallback(callbackContext, llmResponse);
284-
285-
LlmAgent agent = (LlmAgent) context.agent();
286-
Optional<List<? extends AfterModelCallback>> callbacksOpt = agent.afterModelCallback();
287-
288-
if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
289-
return pluginResult.defaultIfEmpty(llmResponse);
290-
}
291-
292-
Maybe<LlmResponse> callbackResult =
293-
Maybe.defer(
294-
() ->
295-
Flowable.fromIterable(callbacksOpt.get())
296-
.concatMapMaybe(callback -> callback.call(callbackContext, llmResponse))
297-
.firstElement());
298-
299-
return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse);
264+
return context
265+
.combinedPlugin()
266+
.afterModelCallback(callbackContext, llmResponse)
267+
.defaultIfEmpty(llmResponse);
300268
}
301269

302270
/**

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 5 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222

2323
import com.google.adk.Telemetry;
2424
import com.google.adk.agents.ActiveStreamingTool;
25-
import com.google.adk.agents.Callbacks.AfterToolCallback;
26-
import com.google.adk.agents.Callbacks.BeforeToolCallback;
2725
import com.google.adk.agents.InvocationContext;
28-
import com.google.adk.agents.LlmAgent;
2926
import com.google.adk.agents.RunConfig.ToolExecutionMode;
3027
import com.google.adk.events.Event;
3128
import com.google.adk.events.EventActions;
@@ -388,7 +385,7 @@ private static Maybe<Event> postProcessFunctionResult(
388385
.onErrorResumeNext(
389386
t ->
390387
invocationContext
391-
.pluginManager()
388+
.combinedPlugin()
392389
.onToolErrorCallback(tool, functionArgs, toolContext, t)
393390
.map(isLive ? Optional::ofNullable : Optional::of)
394391
.switchIfEmpty(Single.error(t)))
@@ -457,30 +454,7 @@ private static Maybe<Map<String, Object>> maybeInvokeBeforeToolCall(
457454
BaseTool tool,
458455
Map<String, Object> functionArgs,
459456
ToolContext toolContext) {
460-
if (invocationContext.agent() instanceof LlmAgent) {
461-
LlmAgent agent = (LlmAgent) invocationContext.agent();
462-
463-
Maybe<Map<String, Object>> pluginResult =
464-
invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext);
465-
466-
Optional<List<? extends BeforeToolCallback>> callbacksOpt = agent.beforeToolCallback();
467-
if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
468-
return pluginResult;
469-
}
470-
List<? extends BeforeToolCallback> callbacks = callbacksOpt.get();
471-
472-
Maybe<Map<String, Object>> callbackResult =
473-
Maybe.defer(
474-
() ->
475-
Flowable.fromIterable(callbacks)
476-
.concatMapMaybe(
477-
callback ->
478-
callback.call(invocationContext, tool, functionArgs, toolContext))
479-
.firstElement());
480-
481-
return pluginResult.switchIfEmpty(callbackResult);
482-
}
483-
return Maybe.empty();
457+
return invocationContext.combinedPlugin().beforeToolCallback(tool, functionArgs, toolContext);
484458
}
485459

486460
private static Maybe<Map<String, Object>> maybeInvokeAfterToolCall(
@@ -489,37 +463,9 @@ private static Maybe<Map<String, Object>> maybeInvokeAfterToolCall(
489463
Map<String, Object> functionArgs,
490464
ToolContext toolContext,
491465
Map<String, Object> functionResult) {
492-
if (invocationContext.agent() instanceof LlmAgent) {
493-
LlmAgent agent = (LlmAgent) invocationContext.agent();
494-
495-
Maybe<Map<String, Object>> pluginResult =
496-
invocationContext
497-
.pluginManager()
498-
.afterToolCallback(tool, functionArgs, toolContext, functionResult);
499-
500-
Optional<List<? extends AfterToolCallback>> callbacksOpt = agent.afterToolCallback();
501-
if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
502-
return pluginResult;
503-
}
504-
List<? extends AfterToolCallback> callbacks = callbacksOpt.get();
505-
506-
Maybe<Map<String, Object>> callbackResult =
507-
Maybe.defer(
508-
() ->
509-
Flowable.fromIterable(callbacks)
510-
.concatMapMaybe(
511-
callback ->
512-
callback.call(
513-
invocationContext,
514-
tool,
515-
functionArgs,
516-
toolContext,
517-
functionResult))
518-
.firstElement());
519-
520-
return pluginResult.switchIfEmpty(callbackResult);
521-
}
522-
return Maybe.empty();
466+
return invocationContext
467+
.combinedPlugin()
468+
.afterToolCallback(tool, functionArgs, toolContext, functionResult);
523469
}
524470

525471
private static Maybe<Map<String, Object>> callTool(

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ public Flowable<Event> runAsync(
514514
updatedSession,
515515
session);
516516
return contextWithUpdatedSession
517-
.pluginManager()
517+
.combinedPlugin()
518518
.onEventCallback(
519519
contextWithUpdatedSession,
520520
registeredEvent)

0 commit comments

Comments
 (0)