Skip to content

Commit c037893

Browse files
bjacotgcopybara-github
authored andcommitted
feat: Integrating Plugin with ADK
This change integrates the plugin system with ADK. PluginManager is attached to the invocation context similar to session/artifact/memory. It includes integrations with following ADK internal callbacks: * App callbacks: Integrated in the BaseRunner class, in run_async and run_live * On Message callbacks: Integrated in the BaseRunner class, triggers on run_async. * Agent callbacks: Integrated in the BaseAgent class. Leveraging the existing *callback functions * Model callbacks: Integrating in the BaseLlmFlow. * Tool callbacks: Integrated in Function.java, wrapped around the code for agent tool_callbacks The plugin integrations currently do not work with Bidi-streaming (live) mode. Sample code to use plugins: ```java # Add plugins to Runner Runner runner = Runner( agent, "my-app", artifact_service, session_service, memory_service, ImmutableList.of( MySamplePlugin(), LoggingPlugin()), ) ``` PiperOrigin-RevId: 808029712
1 parent dc29535 commit c037893

File tree

8 files changed

+674
-103
lines changed

8 files changed

+674
-103
lines changed

core/src/main/java/com/google/adk/agents/BaseAgent.java

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.google.adk.agents.Callbacks.AfterAgentCallback;
2323
import com.google.adk.agents.Callbacks.BeforeAgentCallback;
2424
import com.google.adk.events.Event;
25+
import com.google.adk.plugins.PluginManager;
2526
import com.google.common.collect.ImmutableList;
2627
import com.google.errorprone.annotations.DoNotCall;
2728
import com.google.genai.types.Content;
@@ -34,6 +35,7 @@
3435
import java.util.List;
3536
import java.util.Optional;
3637
import java.util.function.Function;
38+
import java.util.stream.Stream;
3739
import org.jspecify.annotations.Nullable;
3840

3941
/** Base class for all agents. */
@@ -208,11 +210,11 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
208210
InvocationContext invocationContext = createInvocationContext(parentContext);
209211

210212
Flowable<Event> executionFlowable =
211-
beforeAgentCallback
212-
.map(
213-
callback ->
214-
callCallback(beforeCallbacksToFunctions(callback), invocationContext))
215-
.orElse(Single.just(Optional.empty()))
213+
callCallback(
214+
beforeCallbacksToFunctions(
215+
invocationContext.pluginManager(),
216+
beforeAgentCallback.orElse(ImmutableList.of())),
217+
invocationContext)
216218
.flatMapPublisher(
217219
beforeEventOpt -> {
218220
if (invocationContext.endInvocation()) {
@@ -223,16 +225,14 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
223225
Flowable<Event> mainEvents =
224226
Flowable.defer(() -> runAsyncImpl(invocationContext));
225227
Flowable<Event> afterEvents =
226-
afterAgentCallback
227-
.map(
228-
callback ->
229-
Flowable.defer(
230-
() ->
231-
callCallback(
232-
afterCallbacksToFunctions(callback),
233-
invocationContext)
234-
.flatMapPublisher(Flowable::fromOptional)))
235-
.orElse(Flowable.empty());
228+
Flowable.defer(
229+
() ->
230+
callCallback(
231+
afterCallbacksToFunctions(
232+
invocationContext.pluginManager(),
233+
afterAgentCallback.orElse(ImmutableList.of())),
234+
invocationContext)
235+
.flatMapPublisher(Flowable::fromOptional));
236236

237237
return Flowable.concat(beforeEvents, mainEvents, afterEvents);
238238
});
@@ -248,9 +248,11 @@ public Flowable<Event> runAsync(InvocationContext parentContext) {
248248
* @return callback functions.
249249
*/
250250
private ImmutableList<Function<CallbackContext, Maybe<Content>>> beforeCallbacksToFunctions(
251-
List<BeforeAgentCallback> callbacks) {
252-
return callbacks.stream()
253-
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call)
251+
PluginManager pluginManager, List<BeforeAgentCallback> callbacks) {
252+
return Stream.concat(
253+
Stream.of(ctx -> pluginManager.runBeforeAgentCallback(this, ctx)),
254+
callbacks.stream()
255+
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call))
254256
.collect(toImmutableList());
255257
}
256258

@@ -261,9 +263,11 @@ private ImmutableList<Function<CallbackContext, Maybe<Content>>> beforeCallbacks
261263
* @return callback functions.
262264
*/
263265
private ImmutableList<Function<CallbackContext, Maybe<Content>>> afterCallbacksToFunctions(
264-
List<AfterAgentCallback> callbacks) {
265-
return callbacks.stream()
266-
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call)
266+
PluginManager pluginManager, List<AfterAgentCallback> callbacks) {
267+
return Stream.concat(
268+
Stream.of(ctx -> pluginManager.runAfterAgentCallback(this, ctx)),
269+
callbacks.stream()
270+
.map(callback -> (Function<CallbackContext, Maybe<Content>>) callback::call))
267271
.collect(toImmutableList());
268272
}
269273

core/src/main/java/com/google/adk/agents/InvocationContext.java

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import com.google.adk.artifacts.BaseArtifactService;
2020
import com.google.adk.memory.BaseMemoryService;
2121
import com.google.adk.models.LlmCallsLimitExceededException;
22+
import com.google.adk.plugins.PluginManager;
2223
import com.google.adk.sessions.BaseSessionService;
2324
import com.google.adk.sessions.Session;
2425
import com.google.errorprone.annotations.InlineMe;
@@ -36,6 +37,7 @@ public class InvocationContext {
3637
private final BaseSessionService sessionService;
3738
private final BaseArtifactService artifactService;
3839
private final BaseMemoryService memoryService;
40+
private final PluginManager pluginManager;
3941
private final Optional<LiveRequestQueue> liveRequestQueue;
4042
private final Map<String, ActiveStreamingTool> activeStreamingTools = new ConcurrentHashMap<>();
4143

@@ -53,6 +55,7 @@ public InvocationContext(
5355
BaseSessionService sessionService,
5456
BaseArtifactService artifactService,
5557
BaseMemoryService memoryService,
58+
PluginManager pluginManager,
5659
Optional<LiveRequestQueue> liveRequestQueue,
5760
Optional<String> branch,
5861
String invocationId,
@@ -64,6 +67,7 @@ public InvocationContext(
6467
this.sessionService = sessionService;
6568
this.artifactService = artifactService;
6669
this.memoryService = memoryService;
70+
this.pluginManager = pluginManager;
6771
this.liveRequestQueue = liveRequestQueue;
6872
this.branch = branch;
6973
this.invocationId = invocationId;
@@ -74,15 +78,56 @@ public InvocationContext(
7478
this.endInvocation = endInvocation;
7579
}
7680

81+
/**
82+
* @deprecated Use the {@link #InvocationContext} constructor with PluginManager directly instead
83+
*/
84+
@InlineMe(
85+
replacement =
86+
"this(sessionService, artifactService, memoryService, new"
87+
+ " PluginManager(), liveRequestQueue, branch, invocationId, agent,"
88+
+ " session, userContent, runConfig, endInvocation)",
89+
imports = "com.google.adk.plugins.PluginManager")
90+
@Deprecated
91+
public InvocationContext(
92+
BaseSessionService sessionService,
93+
BaseArtifactService artifactService,
94+
BaseMemoryService memoryService,
95+
Optional<LiveRequestQueue> liveRequestQueue,
96+
Optional<String> branch,
97+
String invocationId,
98+
BaseAgent agent,
99+
Session session,
100+
Optional<Content> userContent,
101+
RunConfig runConfig,
102+
boolean endInvocation) {
103+
this(
104+
sessionService,
105+
artifactService,
106+
memoryService,
107+
new PluginManager(),
108+
liveRequestQueue,
109+
branch,
110+
invocationId,
111+
agent,
112+
session,
113+
userContent,
114+
runConfig,
115+
endInvocation);
116+
}
117+
77118
/**
78119
* @deprecated Use the {@link #InvocationContext} constructor directly instead
79120
*/
80121
@InlineMe(
81122
replacement =
82-
"new InvocationContext(sessionService, artifactService, null, Optional.empty(),"
83-
+ " Optional.empty(), invocationId, agent, session, Optional.ofNullable(userContent),"
84-
+ " runConfig, false)",
85-
imports = {"com.google.adk.agents.InvocationContext", "java.util.Optional"})
123+
"new InvocationContext(sessionService, artifactService, null, new PluginManager(),"
124+
+ " Optional.empty(), Optional.empty(), invocationId, agent, session,"
125+
+ " Optional.ofNullable(userContent), runConfig, false)",
126+
imports = {
127+
"com.google.adk.agents.InvocationContext",
128+
"com.google.adk.plugins.PluginManager",
129+
"java.util.Optional"
130+
})
86131
@Deprecated
87132
public static InvocationContext create(
88133
BaseSessionService sessionService,
@@ -96,6 +141,7 @@ public static InvocationContext create(
96141
sessionService,
97142
artifactService,
98143
/* memoryService= */ null,
144+
new PluginManager(),
99145
/* liveRequestQueue= */ Optional.empty(),
100146
/* branch= */ Optional.empty(),
101147
invocationId,
@@ -111,11 +157,15 @@ public static InvocationContext create(
111157
*/
112158
@InlineMe(
113159
replacement =
114-
"new InvocationContext(sessionService, artifactService, null,"
160+
"new InvocationContext(sessionService, artifactService, null, new PluginManager(),"
115161
+ " Optional.ofNullable(liveRequestQueue), Optional.empty(),"
116162
+ " InvocationContext.newInvocationContextId(), agent, session, Optional.empty(),"
117163
+ " runConfig, false)",
118-
imports = {"com.google.adk.agents.InvocationContext", "java.util.Optional"})
164+
imports = {
165+
"com.google.adk.agents.InvocationContext",
166+
"com.google.adk.plugins.PluginManager",
167+
"java.util.Optional"
168+
})
119169
@Deprecated
120170
public static InvocationContext create(
121171
BaseSessionService sessionService,
@@ -128,6 +178,7 @@ public static InvocationContext create(
128178
sessionService,
129179
artifactService,
130180
/* memoryService= */ null,
181+
new PluginManager(),
131182
Optional.ofNullable(liveRequestQueue),
132183
/* branch= */ Optional.empty(),
133184
InvocationContext.newInvocationContextId(),
@@ -144,6 +195,7 @@ public static InvocationContext copyOf(InvocationContext other) {
144195
other.sessionService,
145196
other.artifactService,
146197
other.memoryService,
198+
other.pluginManager,
147199
other.liveRequestQueue,
148200
other.branch,
149201
other.invocationId,
@@ -168,6 +220,10 @@ public BaseMemoryService memoryService() {
168220
return memoryService;
169221
}
170222

223+
public PluginManager pluginManager() {
224+
return pluginManager;
225+
}
226+
171227
public Map<String, ActiveStreamingTool> activeStreamingTools() {
172228
return activeStreamingTools;
173229
}
@@ -260,6 +316,7 @@ public boolean equals(Object o) {
260316
&& Objects.equals(sessionService, that.sessionService)
261317
&& Objects.equals(artifactService, that.artifactService)
262318
&& Objects.equals(memoryService, that.memoryService)
319+
&& Objects.equals(pluginManager, that.pluginManager)
263320
&& Objects.equals(liveRequestQueue, that.liveRequestQueue)
264321
&& Objects.equals(activeStreamingTools, that.activeStreamingTools)
265322
&& Objects.equals(branch, that.branch)
@@ -276,6 +333,7 @@ public int hashCode() {
276333
sessionService,
277334
artifactService,
278335
memoryService,
336+
pluginManager,
279337
liveRequestQueue,
280338
activeStreamingTools,
281339
branch,

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,17 @@ private Flowable<LlmResponse> callLlm(
230230
return llm.generateContent(
231231
llmRequestBuilder.build(),
232232
context.runConfig().streamingMode() == StreamingMode.SSE)
233+
.onErrorResumeNext(
234+
exception ->
235+
context
236+
.pluginManager()
237+
.runOnModelErrorCallback(
238+
new CallbackContext(
239+
context, eventForCallbackUsage.actions()),
240+
llmRequest,
241+
exception)
242+
.switchIfEmpty(Single.error(exception))
243+
.toFlowable())
233244
.doOnNext(
234245
llmResp -> {
235246
try (Scope innerScope = llmCallSpan.makeCurrent()) {
@@ -260,29 +271,32 @@ private Flowable<LlmResponse> callLlm(
260271
*/
261272
private Single<Optional<LlmResponse>> handleBeforeModelCallback(
262273
InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) {
274+
Event callbackEvent = modelResponseEvent.toBuilder().build();
275+
CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions());
276+
277+
Maybe<LlmResponse> pluginResult =
278+
context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder.build());
279+
263280
LlmAgent agent = (LlmAgent) context.agent();
264281

265282
Optional<List<BeforeModelCallback>> callbacksOpt = agent.beforeModelCallback();
266283
if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
267-
return Single.just(Optional.empty());
284+
return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty());
268285
}
269286

270-
Event callbackEvent = modelResponseEvent.toBuilder().build();
271287
List<BeforeModelCallback> callbacks = callbacksOpt.get();
272288

273-
return Flowable.fromIterable(callbacks)
274-
.concatMapSingle(
275-
callback -> {
276-
CallbackContext callbackContext =
277-
new CallbackContext(context, callbackEvent.actions());
278-
return callback
279-
.call(callbackContext, llmRequestBuilder)
280-
.map(Optional::of)
281-
.defaultIfEmpty(Optional.empty());
282-
})
283-
.filter(Optional::isPresent)
284-
.firstElement()
285-
.switchIfEmpty(Single.just(Optional.empty()));
289+
Maybe<LlmResponse> callbackResult =
290+
Maybe.defer(
291+
() ->
292+
Flowable.fromIterable(callbacks)
293+
.concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder))
294+
.firstElement());
295+
296+
return pluginResult
297+
.switchIfEmpty(callbackResult)
298+
.map(Optional::of)
299+
.defaultIfEmpty(Optional.empty());
286300
}
287301

288302
/**
@@ -293,30 +307,27 @@ private Single<Optional<LlmResponse>> handleBeforeModelCallback(
293307
*/
294308
private Single<LlmResponse> handleAfterModelCallback(
295309
InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) {
310+
Event callbackEvent = modelResponseEvent.toBuilder().build();
311+
CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions());
312+
313+
Maybe<LlmResponse> pluginResult =
314+
context.pluginManager().runAfterModelCallback(callbackContext, llmResponse);
315+
296316
LlmAgent agent = (LlmAgent) context.agent();
297317
Optional<List<AfterModelCallback>> callbacksOpt = agent.afterModelCallback();
298318

299319
if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
300-
return Single.just(llmResponse);
320+
return pluginResult.defaultIfEmpty(llmResponse);
301321
}
302322

303-
Event callbackEvent = modelResponseEvent.toBuilder().content(llmResponse.content()).build();
304-
List<AfterModelCallback> callbacks = callbacksOpt.get();
305-
306-
return Flowable.fromIterable(callbacks)
307-
.concatMapSingle(
308-
callback -> {
309-
CallbackContext callbackContext =
310-
new CallbackContext(context, callbackEvent.actions());
311-
return callback
312-
.call(callbackContext, llmResponse)
313-
.map(Optional::of)
314-
.defaultIfEmpty(Optional.empty());
315-
})
316-
.filter(Optional::isPresent)
317-
.firstElement()
318-
.map(Optional::get)
319-
.switchIfEmpty(Single.just(llmResponse));
323+
Maybe<LlmResponse> callbackResult =
324+
Maybe.defer(
325+
() ->
326+
Flowable.fromIterable(callbacksOpt.get())
327+
.concatMapMaybe(callback -> callback.call(callbackContext, llmResponse))
328+
.firstElement());
329+
330+
return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse);
320331
}
321332

322333
/**

0 commit comments

Comments
 (0)