diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index c774ff361..cbf16370f 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -459,29 +459,27 @@ public Flowable runAsync( contextWithUpdatedSession .agent() .runAsync(contextWithUpdatedSession) - .flatMap( + .concatMap( agentEvent -> this.sessionService .appendEvent( updatedSession, agentEvent) - .flatMap( - registeredEvent -> { - // TODO: remove this hack - // after - // deprecating runAsync with - // Session. - copySessionStates( - updatedSession, - session); - return contextWithUpdatedSession - .pluginManager() - .onEventCallback( - contextWithUpdatedSession, - registeredEvent) - .defaultIfEmpty( - registeredEvent); - }) - .toFlowable()); + .toFlowable()) + .concatMap( + registeredEvent -> { + // TODO: remove this hack after + // deprecating runAsync with + // Session. + copySessionStates( + updatedSession, session); + return contextWithUpdatedSession + .pluginManager() + .onEventCallback( + contextWithUpdatedSession, + registeredEvent) + .defaultIfEmpty(registeredEvent) + .toFlowable(); + }); // If beforeRunCallback returns content, emit it and // skip diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 52218a0e9..7c224603e 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -29,6 +29,7 @@ import static org.mockito.Mockito.when; import com.google.adk.Telemetry; +import com.google.adk.agents.BaseAgent; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LiveRequestQueue; import com.google.adk.agents.LlmAgent; @@ -37,6 +38,10 @@ import com.google.adk.flows.llmflows.ResumabilityConfig; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; +import com.google.adk.sessions.BaseSessionService; +import com.google.adk.sessions.GetSessionConfig; +import com.google.adk.sessions.ListEventsResponse; +import com.google.adk.sessions.ListSessionsResponse; import com.google.adk.sessions.Session; import com.google.adk.testing.TestLlm; import com.google.adk.testing.TestUtils; @@ -53,11 +58,19 @@ import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.schedulers.Schedulers; import io.reactivex.rxjava3.subscribers.TestSubscriber; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import java.util.function.Consumer; +import java.util.function.Supplier; +import javax.annotation.Nullable; import org.junit.After; import org.junit.Before; import org.junit.Rule; @@ -639,7 +652,7 @@ public void beforeRunCallback_withStateDelta_seesMergedState() { assertThat(sessionInCallback.state()).containsEntry("number", 123); } - private Content createContent(String text) { + static Content createContent(String text) { return Content.builder().parts(Part.builder().text(text).build()).build(); } @@ -791,4 +804,141 @@ public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() { runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); assertThat(contextCaptor.getValue().isResumable()).isFalse(); } + + static class TestAgent extends BaseAgent { + private final Supplier> eventSupplier; + + public TestAgent(String name, String description, Supplier> eventSupplier) { + super(name, description, ImmutableList.of(), ImmutableList.of(), ImmutableList.of()); + this.eventSupplier = eventSupplier; + } + + @Override + public Flowable runAsyncImpl(InvocationContext context) { + return eventSupplier.get(); + } + + @Override + public Flowable runLiveImpl(InvocationContext context) { + throw new UnsupportedOperationException("runLiveImpl not supported in this test"); + } + } + + static class FakeSessionService implements BaseSessionService { + private static final String SESSION_ID = "1234"; + private final Session session; + private final Consumer onAppendEventFn; + + FakeSessionService(Consumer onAppendEventFn) { + this.session = Session.builder(SESSION_ID).build(); + this.onAppendEventFn = onAppendEventFn; + } + + @Override + public Single appendEvent(Session session, Event event) { + return Flowable.defer( + () -> { + this.onAppendEventFn.accept(event); + + synchronized (this) { + session.events().add(event); + } + return Flowable.just(event); + }) + .firstElement() + .toSingle() + // Run this in a separate thread, to unblock the main thread that processes the events. + .subscribeOn(Schedulers.io()); + } + + @Override + public Maybe getSession( + String appName, String userId, String sessionId, Optional configOpt) { + if (sessionId.equals(SESSION_ID)) { + return Maybe.just(session); + } + return Maybe.empty(); + } + + @Override + public Single listSessions(String appName, String userId) { + return Single.just( + ListSessionsResponse.builder().sessions(ImmutableList.of(this.session)).build()); + } + + @Override + public Single listEvents(String appName, String userId, String sessionId) { + return getSession(appName, userId, sessionId, Optional.empty()) + .map( + session -> + ListEventsResponse.builder() + .events(ImmutableList.copyOf(session.events())) + .build()) + .toSingle(); + } + + @Override + public Single createSession( + String appName, + String userId, + @Nullable ConcurrentMap state, + @Nullable String sessionId) { + throw new UnsupportedOperationException("createSession not supported in this test"); + } + + @Override + public Completable deleteSession(String appName, String userId, String sessionId) { + throw new UnsupportedOperationException("deleteSession not supported in this test"); + } + } + + @Test + public void runAsync_sessionService_appendsEventsInCorrectOrder() throws Exception { + // Arrange + String invocationId = UUID.randomUUID().toString(); + Event ev1 = + Event.builder() + .id("1") + .invocationId(invocationId) + .author("model") + .content(Optional.of(createContent("event 1"))) + .build(); + Event ev2 = + Event.builder() + .id("2") + .invocationId(invocationId) + .author("model") + .content(Optional.of(createContent("event 2"))) + .build(); + TestAgent testAgent = new TestAgent("test agent", "description", () -> Flowable.just(ev1, ev2)); + FakeSessionService laggingSessionService = + new FakeSessionService( + (event) -> { + if (event.id().equals(ev1.id())) { + try { + // Lags completion of appendEvent() on the first event (id=1) from an agent, while + // the second event (id=2) will be immediately appended. + Thread.sleep(2000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }); + Runner runner = + Runner.builder() + .agent(testAgent) + .appName("test") + .sessionService(laggingSessionService) + .build(); + Session session = laggingSessionService.session; + + // Act + var unused = + runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); + + // Assert that events are stored in the correct order. + assertThat(session.events()).hasSize(3); + assertThat(session.events().get(1).id()).isEqualTo(ev1.id()); + assertThat(session.events().get(2).id()).isEqualTo(ev2.id()); + } }