Skip to content

Commit a99c75b

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: Make FunctionResponses respect the order of FunctionCalls
Switched Flowable<Event> to Observable<Event> because it's immediately used as .toList(), so using Flowable<> brings no extra value. PiperOrigin-RevId: 853162382
1 parent d432c64 commit a99c75b

File tree

2 files changed

+82
-15
lines changed

2 files changed

+82
-15
lines changed

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import io.opentelemetry.context.Scope;
4747
import io.reactivex.rxjava3.core.Flowable;
4848
import io.reactivex.rxjava3.core.Maybe;
49+
import io.reactivex.rxjava3.core.Observable;
4950
import io.reactivex.rxjava3.core.Single;
5051
import io.reactivex.rxjava3.disposables.Disposable;
5152
import io.reactivex.rxjava3.functions.Function;
@@ -152,15 +153,16 @@ public static Maybe<Event> handleFunctionCalls(
152153
Function<FunctionCall, Maybe<Event>> functionCallMapper =
153154
getFunctionCallMapper(invocationContext, tools, toolConfirmations, false);
154155

155-
Flowable<Event> functionResponseEventsFlowable;
156+
Observable<Event> functionResponseEventsObservable;
156157
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
157-
functionResponseEventsFlowable =
158-
Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper);
158+
functionResponseEventsObservable =
159+
Observable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper);
159160
} else {
160-
functionResponseEventsFlowable =
161-
Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper);
161+
functionResponseEventsObservable =
162+
Observable.fromIterable(functionCalls)
163+
.concatMapEager(call -> functionCallMapper.apply(call).toObservable());
162164
}
163-
return functionResponseEventsFlowable
165+
return functionResponseEventsObservable
164166
.toList()
165167
.flatMapMaybe(
166168
events -> {
@@ -217,16 +219,17 @@ public static Maybe<Event> handleFunctionCallsLive(
217219
Function<FunctionCall, Maybe<Event>> functionCallMapper =
218220
getFunctionCallMapper(invocationContext, tools, toolConfirmations, true);
219221

220-
Flowable<Event> responseEventsFlowable;
222+
Observable<Event> responseEventsObservable;
221223
if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) {
222-
responseEventsFlowable =
223-
Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper);
224+
responseEventsObservable =
225+
Observable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper);
224226
} else {
225-
responseEventsFlowable =
226-
Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper);
227+
responseEventsObservable =
228+
Observable.fromIterable(functionCalls)
229+
.concatMapEager(call -> functionCallMapper.apply(call).toObservable());
227230
}
228231

229-
return responseEventsFlowable
232+
return responseEventsObservable
230233
.toList()
231234
.flatMapMaybe(
232235
events -> {

core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import static org.junit.Assert.assertThrows;
2424

2525
import com.google.adk.agents.InvocationContext;
26+
import com.google.adk.agents.RunConfig;
27+
import com.google.adk.agents.RunConfig.ToolExecutionMode;
2628
import com.google.adk.events.Event;
2729
import com.google.adk.testing.TestUtils;
2830
import com.google.common.collect.ImmutableList;
@@ -151,8 +153,11 @@ public void handleFunctionCalls_singleFunctionCall() {
151153
}
152154

153155
@Test
154-
public void handleFunctionCalls_multipleFunctionCalls() {
155-
InvocationContext invocationContext = createInvocationContext(createRootAgent());
156+
public void handleFunctionCalls_multipleFunctionCalls_parallel() {
157+
InvocationContext invocationContext =
158+
createInvocationContext(
159+
createRootAgent(),
160+
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build());
156161
ImmutableMap<String, Object> args1 = ImmutableMap.<String, Object>of("key1", "value2");
157162
ImmutableMap<String, Object> args2 = ImmutableMap.<String, Object>of("key2", "value2");
158163
Event event =
@@ -201,7 +206,66 @@ public void handleFunctionCalls_multipleFunctionCalls() {
201206
.name("echo_tool")
202207
.response(ImmutableMap.of("result", args2))
203208
.build())
204-
.build());
209+
.build())
210+
.inOrder();
211+
}
212+
213+
@Test
214+
public void handleFunctionCalls_multipleFunctionCalls_sequential() {
215+
InvocationContext invocationContext =
216+
createInvocationContext(
217+
createRootAgent(),
218+
RunConfig.builder().setToolExecutionMode(ToolExecutionMode.SEQUENTIAL).build());
219+
ImmutableMap<String, Object> args1 = ImmutableMap.<String, Object>of("key1", "value2");
220+
ImmutableMap<String, Object> args2 = ImmutableMap.<String, Object>of("key2", "value2");
221+
Event event =
222+
createEvent("event").toBuilder()
223+
.content(
224+
Content.fromParts(
225+
Part.fromText("..."),
226+
Part.builder()
227+
.functionCall(
228+
FunctionCall.builder()
229+
.id("function_call_id1")
230+
.name("echo_tool")
231+
.args(args1)
232+
.build())
233+
.build(),
234+
Part.builder()
235+
.functionCall(
236+
FunctionCall.builder()
237+
.id("function_call_id2")
238+
.name("echo_tool")
239+
.args(args2)
240+
.build())
241+
.build()))
242+
.build();
243+
244+
Event functionResponseEvent =
245+
Functions.handleFunctionCalls(
246+
invocationContext, event, ImmutableMap.of("echo_tool", new TestUtils.EchoTool()))
247+
.blockingGet();
248+
249+
assertThat(functionResponseEvent).isNotNull();
250+
assertThat(functionResponseEvent.content().get().parts().get())
251+
.containsExactly(
252+
Part.builder()
253+
.functionResponse(
254+
FunctionResponse.builder()
255+
.id("function_call_id1")
256+
.name("echo_tool")
257+
.response(ImmutableMap.of("result", args1))
258+
.build())
259+
.build(),
260+
Part.builder()
261+
.functionResponse(
262+
FunctionResponse.builder()
263+
.id("function_call_id2")
264+
.name("echo_tool")
265+
.response(ImmutableMap.of("result", args2))
266+
.build())
267+
.build())
268+
.inOrder();
205269
}
206270

207271
@Test

0 commit comments

Comments
 (0)