Skip to content

Commit b7e0772

Browse files
Mateusz Krawieccopybara-github
authored andcommitted
fix: race condition while appending Agent's events to SessionService
Agent events were processed by agentEvents.flatMap(e -> session.appendEvent(e)), which has no ordering guarantees, so the events could have been appended to session out of order. PiperOrigin-RevId: 853618840
1 parent e7380f1 commit b7e0772

File tree

7 files changed

+179
-52
lines changed

7 files changed

+179
-52
lines changed

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
import com.google.adk.models.Model;
4949
import com.google.adk.tools.BaseTool;
5050
import com.google.adk.tools.BaseToolset;
51-
import com.google.adk.tools.ToolMarker;
5251
import com.google.common.base.Preconditions;
5352
import com.google.common.collect.ImmutableList;
5453
import com.google.errorprone.annotations.CanIgnoreReturnValue;
@@ -85,7 +84,7 @@ public enum IncludeContents {
8584
private final Optional<Model> model;
8685
private final Instruction instruction;
8786
private final Instruction globalInstruction;
88-
private final List<ToolMarker> toolsUnion;
87+
private final List<Object> toolsUnion;
8988
private final ImmutableList<BaseToolset> toolsets;
9089
private final Optional<GenerateContentConfig> generateContentConfig;
9190
// TODO: Remove exampleProvider field - examples should only be provided via ExampleTool
@@ -153,7 +152,7 @@ public static Builder builder() {
153152
}
154153

155154
/** Extracts BaseToolset instances from the toolsUnion list. */
156-
private static ImmutableList<BaseToolset> extractToolsets(List<ToolMarker> toolsUnion) {
155+
private static ImmutableList<BaseToolset> extractToolsets(List<Object> toolsUnion) {
157156
return toolsUnion.stream()
158157
.filter(obj -> obj instanceof BaseToolset)
159158
.map(obj -> (BaseToolset) obj)
@@ -166,7 +165,7 @@ public static class Builder extends BaseAgent.Builder<Builder> {
166165

167166
private Instruction instruction;
168167
private Instruction globalInstruction;
169-
private ImmutableList<ToolMarker> toolsUnion;
168+
private ImmutableList<Object> toolsUnion;
170169
private GenerateContentConfig generateContentConfig;
171170
private BaseExampleProvider exampleProvider;
172171
private IncludeContents includeContents;
@@ -222,13 +221,13 @@ public Builder globalInstruction(String globalInstruction) {
222221
}
223222

224223
@CanIgnoreReturnValue
225-
public Builder tools(List<? extends ToolMarker> tools) {
224+
public Builder tools(List<?> tools) {
226225
this.toolsUnion = ImmutableList.copyOf(tools);
227226
return this;
228227
}
229228

230229
@CanIgnoreReturnValue
231-
public Builder tools(ToolMarker... tools) {
230+
public Builder tools(Object... tools) {
232231
this.toolsUnion = ImmutableList.copyOf(tools);
233232
return this;
234233
}
@@ -680,7 +679,7 @@ public Single<Map.Entry<String, Boolean>> canonicalGlobalInstruction(ReadonlyCon
680679
*/
681680
public Flowable<BaseTool> canonicalTools(Optional<ReadonlyContext> context) {
682681
List<Flowable<BaseTool>> toolFlowables = new ArrayList<>();
683-
for (ToolMarker toolOrToolset : toolsUnion) {
682+
for (Object toolOrToolset : toolsUnion) {
684683
if (toolOrToolset instanceof BaseTool baseTool) {
685684
toolFlowables.add(Flowable.just(baseTool));
686685
} else if (toolOrToolset instanceof BaseToolset baseToolset) {
@@ -741,7 +740,7 @@ public Single<List<BaseTool>> tools() {
741740
return canonicalTools().toList();
742741
}
743742

744-
public List<ToolMarker> toolsUnion() {
743+
public List<Object> toolsUnion() {
745744
return toolsUnion;
746745
}
747746

core/src/main/java/com/google/adk/agents/ToolResolver.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import com.google.adk.tools.BaseTool.ToolArgsConfig;
2424
import com.google.adk.tools.BaseTool.ToolConfig;
2525
import com.google.adk.tools.BaseToolset;
26-
import com.google.adk.tools.ToolMarker;
2726
import com.google.adk.utils.ComponentRegistry;
2827
import com.google.common.collect.ImmutableList;
2928
import java.lang.reflect.Constructor;
@@ -57,14 +56,14 @@ private ToolResolver() {}
5756
* @throws ConfigurationException if any tool configuration is invalid (e.g., missing name), if a
5857
* tool cannot be found by its name or class, or if tool instantiation fails.
5958
*/
60-
static ImmutableList<ToolMarker> resolveToolsAndToolsets(
59+
static ImmutableList<Object> resolveToolsAndToolsets(
6160
List<ToolConfig> toolConfigs, String configAbsPath) throws ConfigurationException {
6261

6362
if (toolConfigs == null || toolConfigs.isEmpty()) {
6463
return ImmutableList.of();
6564
}
6665

67-
ImmutableList.Builder<ToolMarker> resolvedItems = ImmutableList.builder();
66+
ImmutableList.Builder<Object> resolvedItems = ImmutableList.builder();
6867

6968
for (ToolConfig toolConfig : toolConfigs) {
7069
try {

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -459,29 +459,27 @@ public Flowable<Event> runAsync(
459459
contextWithUpdatedSession
460460
.agent()
461461
.runAsync(contextWithUpdatedSession)
462-
.flatMap(
462+
.concatMap(
463463
agentEvent ->
464464
this.sessionService
465465
.appendEvent(
466466
updatedSession, agentEvent)
467-
.flatMap(
468-
registeredEvent -> {
469-
// TODO: remove this hack
470-
// after
471-
// deprecating runAsync with
472-
// Session.
473-
copySessionStates(
474-
updatedSession,
475-
session);
476-
return contextWithUpdatedSession
477-
.pluginManager()
478-
.onEventCallback(
479-
contextWithUpdatedSession,
480-
registeredEvent)
481-
.defaultIfEmpty(
482-
registeredEvent);
483-
})
484-
.toFlowable());
467+
.toFlowable())
468+
.concatMap(
469+
registeredEvent -> {
470+
// TODO: remove this hack after
471+
// deprecating runAsync with
472+
// Session.
473+
copySessionStates(
474+
updatedSession, session);
475+
return contextWithUpdatedSession
476+
.pluginManager()
477+
.onEventCallback(
478+
contextWithUpdatedSession,
479+
registeredEvent)
480+
.defaultIfEmpty(registeredEvent)
481+
.toFlowable();
482+
});
485483

486484
// If beforeRunCallback returns content, emit it and
487485
// skip

core/src/main/java/com/google/adk/tools/BaseTool.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
import org.slf4j.LoggerFactory;
4545

4646
/** The base class for all ADK tools. */
47-
public abstract class BaseTool implements ToolMarker {
47+
public abstract class BaseTool {
4848
private final String name;
4949
private final String description;
5050
private final boolean isLongRunning;

core/src/main/java/com/google/adk/tools/BaseToolset.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import java.util.Optional;
77

88
/** Base interface for toolsets. */
9-
public interface BaseToolset extends AutoCloseable, ToolMarker {
9+
public interface BaseToolset extends AutoCloseable {
1010

1111
/**
1212
* Return all tools in the toolset based on the provided context.

core/src/main/java/com/google/adk/tools/ToolMarker.java

Lines changed: 0 additions & 19 deletions
This file was deleted.

core/src/test/java/com/google/adk/runner/RunnerTest.java

Lines changed: 151 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import static org.mockito.Mockito.when;
3030

3131
import com.google.adk.Telemetry;
32+
import com.google.adk.agents.BaseAgent;
3233
import com.google.adk.agents.InvocationContext;
3334
import com.google.adk.agents.LiveRequestQueue;
3435
import com.google.adk.agents.LlmAgent;
@@ -37,6 +38,10 @@
3738
import com.google.adk.flows.llmflows.ResumabilityConfig;
3839
import com.google.adk.models.LlmResponse;
3940
import com.google.adk.plugins.BasePlugin;
41+
import com.google.adk.sessions.BaseSessionService;
42+
import com.google.adk.sessions.GetSessionConfig;
43+
import com.google.adk.sessions.ListEventsResponse;
44+
import com.google.adk.sessions.ListSessionsResponse;
4045
import com.google.adk.sessions.Session;
4146
import com.google.adk.testing.TestLlm;
4247
import com.google.adk.testing.TestUtils;
@@ -53,11 +58,19 @@
5358
import io.reactivex.rxjava3.core.Completable;
5459
import io.reactivex.rxjava3.core.Flowable;
5560
import io.reactivex.rxjava3.core.Maybe;
61+
import io.reactivex.rxjava3.core.Single;
62+
import io.reactivex.rxjava3.schedulers.Schedulers;
5663
import io.reactivex.rxjava3.subscribers.TestSubscriber;
5764
import java.util.List;
5865
import java.util.Objects;
5966
import java.util.Optional;
67+
import java.util.UUID;
6068
import java.util.concurrent.ConcurrentHashMap;
69+
import java.util.concurrent.ConcurrentMap;
70+
71+
import java.util.function.Consumer;
72+
import java.util.function.Supplier;
73+
import javax.annotation.Nullable;
6174
import org.junit.After;
6275
import org.junit.Before;
6376
import org.junit.Rule;
@@ -639,7 +652,7 @@ public void beforeRunCallback_withStateDelta_seesMergedState() {
639652
assertThat(sessionInCallback.state()).containsEntry("number", 123);
640653
}
641654

642-
private Content createContent(String text) {
655+
static Content createContent(String text) {
643656
return Content.builder().parts(Part.builder().text(text).build()).build();
644657
}
645658

@@ -791,4 +804,141 @@ public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() {
791804
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
792805
assertThat(contextCaptor.getValue().isResumable()).isFalse();
793806
}
807+
808+
static class TestAgent extends BaseAgent {
809+
private final Supplier<Flowable<Event>> eventSupplier;
810+
811+
public TestAgent(String name, String description, Supplier<Flowable<Event>> eventSupplier) {
812+
super(name, description, ImmutableList.of(), ImmutableList.of(), ImmutableList.of());
813+
this.eventSupplier = eventSupplier;
814+
}
815+
816+
@Override
817+
public Flowable<Event> runAsyncImpl(InvocationContext context) {
818+
return eventSupplier.get();
819+
}
820+
821+
@Override
822+
public Flowable<Event> runLiveImpl(InvocationContext context) {
823+
throw new UnsupportedOperationException("runLiveImpl not supported in this test");
824+
}
825+
}
826+
827+
static class FakeSessionService implements BaseSessionService {
828+
private static final String SESSION_ID = "1234";
829+
private final Session session;
830+
private final Consumer<Event> onAppendEventFn;
831+
832+
FakeSessionService(Consumer<Event> onAppendEventFn) {
833+
this.session = Session.builder(SESSION_ID).build();
834+
this.onAppendEventFn = onAppendEventFn;
835+
}
836+
837+
@Override
838+
public Single<Event> appendEvent(Session session, Event event) {
839+
return Flowable.defer(
840+
() -> {
841+
this.onAppendEventFn.accept(event);
842+
843+
synchronized (this) {
844+
session.events().add(event);
845+
}
846+
return Flowable.just(event);
847+
})
848+
.firstElement()
849+
.toSingle()
850+
// Run this in a separate thread, to unblock the main thread that processes the events.
851+
.subscribeOn(Schedulers.io());
852+
}
853+
854+
@Override
855+
public Maybe<Session> getSession(
856+
String appName, String userId, String sessionId, Optional<GetSessionConfig> configOpt) {
857+
if (sessionId.equals(SESSION_ID)) {
858+
return Maybe.just(session);
859+
}
860+
return Maybe.empty();
861+
}
862+
863+
@Override
864+
public Single<ListSessionsResponse> listSessions(String appName, String userId) {
865+
return Single.just(
866+
ListSessionsResponse.builder().sessions(ImmutableList.of(this.session)).build());
867+
}
868+
869+
@Override
870+
public Single<ListEventsResponse> listEvents(String appName, String userId, String sessionId) {
871+
return getSession(appName, userId, sessionId, Optional.empty())
872+
.map(
873+
session ->
874+
ListEventsResponse.builder()
875+
.events(ImmutableList.copyOf(session.events()))
876+
.build())
877+
.toSingle();
878+
}
879+
880+
@Override
881+
public Single<Session> createSession(
882+
String appName,
883+
String userId,
884+
@Nullable ConcurrentMap<String, Object> state,
885+
@Nullable String sessionId) {
886+
throw new UnsupportedOperationException("createSession not supported in this test");
887+
}
888+
889+
@Override
890+
public Completable deleteSession(String appName, String userId, String sessionId) {
891+
throw new UnsupportedOperationException("deleteSession not supported in this test");
892+
}
893+
}
894+
895+
@Test
896+
public void runAsync_sessionService_appendsEventsInCorrectOrder() throws Exception {
897+
// Arrange
898+
String invocationId = UUID.randomUUID().toString();
899+
Event ev1 =
900+
Event.builder()
901+
.id("1")
902+
.invocationId(invocationId)
903+
.author("model")
904+
.content(Optional.of(createContent("event 1")))
905+
.build();
906+
Event ev2 =
907+
Event.builder()
908+
.id("2")
909+
.invocationId(invocationId)
910+
.author("model")
911+
.content(Optional.of(createContent("event 2")))
912+
.build();
913+
TestAgent testAgent = new TestAgent("test agent", "description", () -> Flowable.just(ev1, ev2));
914+
FakeSessionService laggingSessionService =
915+
new FakeSessionService(
916+
(event) -> {
917+
if (event.id().equals(ev1.id())) {
918+
try {
919+
// Lags completion of appendEvent() on the first event (id=1) from an agent, while
920+
// the second event (id=2) will be immediately appended.
921+
Thread.sleep(2000);
922+
} catch (InterruptedException e) {
923+
throw new RuntimeException(e);
924+
}
925+
}
926+
});
927+
Runner runner =
928+
Runner.builder()
929+
.agent(testAgent)
930+
.appName("test")
931+
.sessionService(laggingSessionService)
932+
.build();
933+
Session session = laggingSessionService.session;
934+
935+
// Act
936+
var unused =
937+
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
938+
939+
// Assert that events are stored in the correct order.
940+
assertThat(session.events()).hasSize(3);
941+
assertThat(session.events().get(1).id()).isEqualTo(ev1.id());
942+
assertThat(session.events().get(2).id()).isEqualTo(ev2.id());
943+
}
794944
}

0 commit comments

Comments
 (0)