Skip to content

Commit cf64fac

Browse files
google-genai-botcopybara-github
authored andcommitted
refactor: Let LlmAgent.tools() return an async type instead of blocking
PiperOrigin-RevId: 843824506
1 parent a77971a commit cf64fac

File tree

4 files changed

+68
-40
lines changed

4 files changed

+68
-40
lines changed

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,8 @@ public IncludeContents includeContents() {
736736
return includeContents;
737737
}
738738

739-
public List<BaseTool> tools() {
740-
return canonicalTools().toList().blockingGet();
739+
public Single<List<BaseTool>> tools() {
740+
return canonicalTools().toList();
741741
}
742742

743743
public List<Object> toolsUnion() {

core/src/main/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessor.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,16 @@ private Maybe<Event> assembleEvent(
106106
InvocationContext invocationContext,
107107
Collection<FunctionCall> functionCalls,
108108
Map<String, ToolConfirmation> toolConfirmations) {
109-
ImmutableMap.Builder<String, BaseTool> toolsBuilder = ImmutableMap.builder();
109+
Single<ImmutableMap<String, BaseTool>> toolsMapSingle;
110110
if (invocationContext.agent() instanceof LlmAgent llmAgent) {
111-
for (BaseTool tool : llmAgent.tools()) {
112-
toolsBuilder.put(tool.name(), tool);
113-
}
111+
toolsMapSingle =
112+
llmAgent
113+
.tools()
114+
.map(
115+
toolList ->
116+
toolList.stream().collect(toImmutableMap(BaseTool::name, tool -> tool)));
117+
} else {
118+
toolsMapSingle = Single.just(ImmutableMap.of());
114119
}
115120

116121
var functionCallEvent =
@@ -124,8 +129,10 @@ private Maybe<Event> assembleEvent(
124129
.build())
125130
.build();
126131

127-
return Functions.handleFunctionCalls(
128-
invocationContext, functionCallEvent, toolsBuilder.buildOrThrow(), toolConfirmations);
132+
return toolsMapSingle.flatMapMaybe(
133+
toolsMap ->
134+
Functions.handleFunctionCalls(
135+
invocationContext, functionCallEvent, toolsMap, toolConfirmations));
129136
}
130137

131138
private ImmutableMap<String, ToolConfirmation> filterRequestConfirmationFunctionResponses(

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@
5050
import io.reactivex.rxjava3.core.Flowable;
5151
import io.reactivex.rxjava3.core.Maybe;
5252
import io.reactivex.rxjava3.core.Single;
53-
import java.lang.reflect.Parameter;
5453
import java.util.ArrayList;
54+
import java.util.Arrays;
5555
import java.util.Collections;
5656
import java.util.List;
5757
import java.util.Map;
@@ -616,35 +616,39 @@ public Flowable<Event> runLive(
616616
try {
617617
InvocationContext invocationContext =
618618
newInvocationContextForLive(session, Optional.of(liveRequestQueue), runConfig);
619-
if (invocationContext.agent() instanceof LlmAgent) {
620-
LlmAgent agent = (LlmAgent) invocationContext.agent();
621-
for (BaseTool tool : agent.tools()) {
622-
if (tool instanceof FunctionTool functionTool) {
623-
for (Parameter parameter : functionTool.func().getParameters()) {
624-
if (parameter.getType().equals(LiveRequestQueue.class)) {
625-
invocationContext
626-
.activeStreamingTools()
627-
.put(functionTool.name(), new ActiveStreamingTool(new LiveRequestQueue()));
628-
}
629-
}
630-
}
631-
}
619+
620+
Single<InvocationContext> invocationContextSingle;
621+
if (invocationContext.agent() instanceof LlmAgent agent) {
622+
invocationContextSingle =
623+
agent
624+
.tools()
625+
.map(
626+
tools -> {
627+
this.addActiveStreamingTools(invocationContext, tools);
628+
return invocationContext;
629+
});
630+
} else {
631+
invocationContextSingle = Single.just(invocationContext);
632632
}
633-
return Telemetry.traceFlowable(
634-
spanContext,
635-
span,
636-
() ->
637-
invocationContext
638-
.agent()
639-
.runLive(invocationContext)
640-
.doOnNext(event -> this.sessionService.appendEvent(session, event))
641-
.onErrorResumeNext(
642-
throwable -> {
643-
span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution");
644-
span.recordException(throwable);
645-
span.end();
646-
return Flowable.error(throwable);
647-
}));
633+
634+
return invocationContextSingle.flatMapPublisher(
635+
updatedInvocationContext ->
636+
Telemetry.traceFlowable(
637+
spanContext,
638+
span,
639+
() ->
640+
updatedInvocationContext
641+
.agent()
642+
.runLive(updatedInvocationContext)
643+
.doOnNext(event -> this.sessionService.appendEvent(session, event))
644+
.onErrorResumeNext(
645+
throwable -> {
646+
span.setStatus(
647+
StatusCode.ERROR, "Error in runLive Flowable execution");
648+
span.recordException(throwable);
649+
span.end();
650+
return Flowable.error(throwable);
651+
})));
648652
} catch (Throwable t) {
649653
span.setStatus(StatusCode.ERROR, "Error during runLive synchronous setup");
650654
span.recordException(t);
@@ -740,5 +744,22 @@ private BaseAgent findAgentToRun(Session session, BaseAgent rootAgent) {
740744
return rootAgent;
741745
}
742746

747+
private void addActiveStreamingTools(InvocationContext invocationContext, List<BaseTool> tools) {
748+
tools.stream()
749+
.filter(FunctionTool.class::isInstance)
750+
.map(FunctionTool.class::cast)
751+
.filter(this::hasLiveRequestQueueParameter)
752+
.forEach(
753+
tool ->
754+
invocationContext
755+
.activeStreamingTools()
756+
.put(tool.name(), new ActiveStreamingTool(new LiveRequestQueue())));
757+
}
758+
759+
private boolean hasLiveRequestQueueParameter(FunctionTool functionTool) {
760+
return Arrays.stream(functionTool.func().getParameters())
761+
.anyMatch(parameter -> parameter.getType().equals(LiveRequestQueue.class));
762+
}
763+
743764
// TODO: run statelessly
744765
}

core/src/test/java/com/google/adk/agents/ConfigAgentUtilsTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,8 @@ public void fromConfig_withBuiltInTool_loadsTool() throws IOException, Configura
258258

259259
assertThat(agent).isInstanceOf(LlmAgent.class);
260260
LlmAgent llmAgent = (LlmAgent) agent;
261-
assertThat(llmAgent.tools()).hasSize(1);
262-
assertThat(llmAgent.tools().get(0).name()).isEqualTo("google_search");
261+
assertThat(llmAgent.tools().blockingGet()).hasSize(1);
262+
assertThat(llmAgent.tools().blockingGet().get(0).name()).isEqualTo("google_search");
263263
}
264264

265265
@Test
@@ -784,7 +784,7 @@ public void fromConfig_withIncludeContentsAndOtherFields_parsesAllFieldsCorrectl
784784
assertThat(llmAgent.outputKey()).hasValue("testOutput");
785785
assertThat(llmAgent.disallowTransferToParent()).isTrue();
786786
assertThat(llmAgent.disallowTransferToPeers()).isFalse();
787-
assertThat(llmAgent.tools()).hasSize(1);
787+
assertThat(llmAgent.tools().blockingGet()).hasSize(1);
788788
assertThat(llmAgent.model()).isPresent();
789789
}
790790

0 commit comments

Comments
 (0)