Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 17 additions & 19 deletions core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -459,29 +459,27 @@ public Flowable<Event> runAsync(
contextWithUpdatedSession
.agent()
.runAsync(contextWithUpdatedSession)
.flatMap(
.concatMap(
agentEvent ->
this.sessionService
.appendEvent(
updatedSession, agentEvent)
.flatMap(
registeredEvent -> {
// TODO: remove this hack
// after
// deprecating runAsync with
// Session.
copySessionStates(
updatedSession,
session);
return contextWithUpdatedSession
.pluginManager()
.onEventCallback(
contextWithUpdatedSession,
registeredEvent)
.defaultIfEmpty(
registeredEvent);
})
.toFlowable());
.toFlowable())
.concatMap(
registeredEvent -> {
// TODO: remove this hack after
// deprecating runAsync with
// Session.
copySessionStates(
updatedSession, session);
return contextWithUpdatedSession
.pluginManager()
.onEventCallback(
contextWithUpdatedSession,
registeredEvent)
.defaultIfEmpty(registeredEvent)
.toFlowable();
});

// If beforeRunCallback returns content, emit it and
// skip
Expand Down
152 changes: 151 additions & 1 deletion core/src/test/java/com/google/adk/runner/RunnerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import static org.mockito.Mockito.when;

import com.google.adk.Telemetry;
import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LiveRequestQueue;
import com.google.adk.agents.LlmAgent;
Expand All @@ -37,6 +38,10 @@
import com.google.adk.flows.llmflows.ResumabilityConfig;
import com.google.adk.models.LlmResponse;
import com.google.adk.plugins.BasePlugin;
import com.google.adk.sessions.BaseSessionService;
import com.google.adk.sessions.GetSessionConfig;
import com.google.adk.sessions.ListEventsResponse;
import com.google.adk.sessions.ListSessionsResponse;
import com.google.adk.sessions.Session;
import com.google.adk.testing.TestLlm;
import com.google.adk.testing.TestUtils;
Expand All @@ -53,11 +58,19 @@
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.schedulers.Schedulers;
import io.reactivex.rxjava3.subscribers.TestSubscriber;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import java.util.function.Consumer;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
Expand Down Expand Up @@ -639,7 +652,7 @@ public void beforeRunCallback_withStateDelta_seesMergedState() {
assertThat(sessionInCallback.state()).containsEntry("number", 123);
}

private Content createContent(String text) {
static Content createContent(String text) {
return Content.builder().parts(Part.builder().text(text).build()).build();
}

Expand Down Expand Up @@ -791,4 +804,141 @@ public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() {
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();
assertThat(contextCaptor.getValue().isResumable()).isFalse();
}

static class TestAgent extends BaseAgent {
private final Supplier<Flowable<Event>> eventSupplier;

public TestAgent(String name, String description, Supplier<Flowable<Event>> eventSupplier) {
super(name, description, ImmutableList.of(), ImmutableList.of(), ImmutableList.of());
this.eventSupplier = eventSupplier;
}

@Override
public Flowable<Event> runAsyncImpl(InvocationContext context) {
return eventSupplier.get();
}

@Override
public Flowable<Event> runLiveImpl(InvocationContext context) {
throw new UnsupportedOperationException("runLiveImpl not supported in this test");
}
}

static class FakeSessionService implements BaseSessionService {
private static final String SESSION_ID = "1234";
private final Session session;
private final Consumer<Event> onAppendEventFn;

FakeSessionService(Consumer<Event> onAppendEventFn) {
this.session = Session.builder(SESSION_ID).build();
this.onAppendEventFn = onAppendEventFn;
}

@Override
public Single<Event> appendEvent(Session session, Event event) {
return Flowable.defer(
() -> {
this.onAppendEventFn.accept(event);

synchronized (this) {
session.events().add(event);
}
return Flowable.just(event);
})
.firstElement()
.toSingle()
// Run this in a separate thread, to unblock the main thread that processes the events.
.subscribeOn(Schedulers.io());
}

@Override
public Maybe<Session> getSession(
String appName, String userId, String sessionId, Optional<GetSessionConfig> configOpt) {
if (sessionId.equals(SESSION_ID)) {
return Maybe.just(session);
}
return Maybe.empty();
}

@Override
public Single<ListSessionsResponse> listSessions(String appName, String userId) {
return Single.just(
ListSessionsResponse.builder().sessions(ImmutableList.of(this.session)).build());
}

@Override
public Single<ListEventsResponse> listEvents(String appName, String userId, String sessionId) {
return getSession(appName, userId, sessionId, Optional.empty())
.map(
session ->
ListEventsResponse.builder()
.events(ImmutableList.copyOf(session.events()))
.build())
.toSingle();
}

@Override
public Single<Session> createSession(
String appName,
String userId,
@Nullable ConcurrentMap<String, Object> state,
@Nullable String sessionId) {
throw new UnsupportedOperationException("createSession not supported in this test");
}

@Override
public Completable deleteSession(String appName, String userId, String sessionId) {
throw new UnsupportedOperationException("deleteSession not supported in this test");
}
}

@Test
public void runAsync_sessionService_appendsEventsInCorrectOrder() throws Exception {
// Arrange
String invocationId = UUID.randomUUID().toString();
Event ev1 =
Event.builder()
.id("1")
.invocationId(invocationId)
.author("model")
.content(Optional.of(createContent("event 1")))
.build();
Event ev2 =
Event.builder()
.id("2")
.invocationId(invocationId)
.author("model")
.content(Optional.of(createContent("event 2")))
.build();
TestAgent testAgent = new TestAgent("test agent", "description", () -> Flowable.just(ev1, ev2));
FakeSessionService laggingSessionService =
new FakeSessionService(
(event) -> {
if (event.id().equals(ev1.id())) {
try {
// Lags completion of appendEvent() on the first event (id=1) from an agent, while
// the second event (id=2) will be immediately appended.
Thread.sleep(2000);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
});
Runner runner =
Runner.builder()
.agent(testAgent)
.appName("test")
.sessionService(laggingSessionService)
.build();
Session session = laggingSessionService.session;

// Act
var unused =
runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet();

// Assert that events are stored in the correct order.
assertThat(session.events()).hasSize(3);
assertThat(session.events().get(1).id()).isEqualTo(ev1.id());
assertThat(session.events().get(2).id()).isEqualTo(ev2.id());
}
}
Loading