|
29 | 29 | import static org.mockito.Mockito.when; |
30 | 30 |
|
31 | 31 | import com.google.adk.Telemetry; |
| 32 | +import com.google.adk.agents.BaseAgent; |
32 | 33 | import com.google.adk.agents.InvocationContext; |
33 | 34 | import com.google.adk.agents.LiveRequestQueue; |
34 | 35 | import com.google.adk.agents.LlmAgent; |
|
37 | 38 | import com.google.adk.flows.llmflows.ResumabilityConfig; |
38 | 39 | import com.google.adk.models.LlmResponse; |
39 | 40 | import com.google.adk.plugins.BasePlugin; |
| 41 | +import com.google.adk.sessions.BaseSessionService; |
| 42 | +import com.google.adk.sessions.GetSessionConfig; |
| 43 | +import com.google.adk.sessions.ListEventsResponse; |
| 44 | +import com.google.adk.sessions.ListSessionsResponse; |
40 | 45 | import com.google.adk.sessions.Session; |
41 | 46 | import com.google.adk.testing.TestLlm; |
42 | 47 | import com.google.adk.testing.TestUtils; |
|
53 | 58 | import io.reactivex.rxjava3.core.Completable; |
54 | 59 | import io.reactivex.rxjava3.core.Flowable; |
55 | 60 | import io.reactivex.rxjava3.core.Maybe; |
| 61 | +import io.reactivex.rxjava3.core.Single; |
| 62 | +import io.reactivex.rxjava3.schedulers.Schedulers; |
56 | 63 | import io.reactivex.rxjava3.subscribers.TestSubscriber; |
57 | 64 | import java.util.List; |
58 | 65 | import java.util.Objects; |
59 | 66 | import java.util.Optional; |
| 67 | +import java.util.UUID; |
60 | 68 | import java.util.concurrent.ConcurrentHashMap; |
| 69 | +import java.util.concurrent.ConcurrentMap; |
| 70 | + |
| 71 | +import java.util.function.Consumer; |
| 72 | +import java.util.function.Supplier; |
| 73 | +import javax.annotation.Nullable; |
61 | 74 | import org.junit.After; |
62 | 75 | import org.junit.Before; |
63 | 76 | import org.junit.Rule; |
@@ -639,7 +652,7 @@ public void beforeRunCallback_withStateDelta_seesMergedState() { |
639 | 652 | assertThat(sessionInCallback.state()).containsEntry("number", 123); |
640 | 653 | } |
641 | 654 |
|
642 | | - private Content createContent(String text) { |
| 655 | + static Content createContent(String text) { |
643 | 656 | return Content.builder().parts(Part.builder().text(text).build()).build(); |
644 | 657 | } |
645 | 658 |
|
@@ -791,4 +804,141 @@ public void resumabilityConfig_isNotResumable_isFalseInInvocationContext() { |
791 | 804 | runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); |
792 | 805 | assertThat(contextCaptor.getValue().isResumable()).isFalse(); |
793 | 806 | } |
| 807 | + |
| 808 | + static class TestAgent extends BaseAgent { |
| 809 | + private final Supplier<Flowable<Event>> eventSupplier; |
| 810 | + |
| 811 | + public TestAgent(String name, String description, Supplier<Flowable<Event>> eventSupplier) { |
| 812 | + super(name, description, ImmutableList.of(), ImmutableList.of(), ImmutableList.of()); |
| 813 | + this.eventSupplier = eventSupplier; |
| 814 | + } |
| 815 | + |
| 816 | + @Override |
| 817 | + public Flowable<Event> runAsyncImpl(InvocationContext context) { |
| 818 | + return eventSupplier.get(); |
| 819 | + } |
| 820 | + |
| 821 | + @Override |
| 822 | + public Flowable<Event> runLiveImpl(InvocationContext context) { |
| 823 | + throw new UnsupportedOperationException("runLiveImpl not supported in this test"); |
| 824 | + } |
| 825 | + } |
| 826 | + |
| 827 | + static class FakeSessionService implements BaseSessionService { |
| 828 | + private static final String SESSION_ID = "1234"; |
| 829 | + private final Session session; |
| 830 | + private final Consumer<Event> onAppendEventFn; |
| 831 | + |
| 832 | + FakeSessionService(Consumer<Event> onAppendEventFn) { |
| 833 | + this.session = Session.builder(SESSION_ID).build(); |
| 834 | + this.onAppendEventFn = onAppendEventFn; |
| 835 | + } |
| 836 | + |
| 837 | + @Override |
| 838 | + public Single<Event> appendEvent(Session session, Event event) { |
| 839 | + return Flowable.defer( |
| 840 | + () -> { |
| 841 | + this.onAppendEventFn.accept(event); |
| 842 | + |
| 843 | + synchronized (this) { |
| 844 | + session.events().add(event); |
| 845 | + } |
| 846 | + return Flowable.just(event); |
| 847 | + }) |
| 848 | + .firstElement() |
| 849 | + .toSingle() |
| 850 | + // Run this in a separate thread, to unblock the main thread that processes the events. |
| 851 | + .subscribeOn(Schedulers.io()); |
| 852 | + } |
| 853 | + |
| 854 | + @Override |
| 855 | + public Maybe<Session> getSession( |
| 856 | + String appName, String userId, String sessionId, Optional<GetSessionConfig> configOpt) { |
| 857 | + if (sessionId.equals(SESSION_ID)) { |
| 858 | + return Maybe.just(session); |
| 859 | + } |
| 860 | + return Maybe.empty(); |
| 861 | + } |
| 862 | + |
| 863 | + @Override |
| 864 | + public Single<ListSessionsResponse> listSessions(String appName, String userId) { |
| 865 | + return Single.just( |
| 866 | + ListSessionsResponse.builder().sessions(ImmutableList.of(this.session)).build()); |
| 867 | + } |
| 868 | + |
| 869 | + @Override |
| 870 | + public Single<ListEventsResponse> listEvents(String appName, String userId, String sessionId) { |
| 871 | + return getSession(appName, userId, sessionId, Optional.empty()) |
| 872 | + .map( |
| 873 | + session -> |
| 874 | + ListEventsResponse.builder() |
| 875 | + .events(ImmutableList.copyOf(session.events())) |
| 876 | + .build()) |
| 877 | + .toSingle(); |
| 878 | + } |
| 879 | + |
| 880 | + @Override |
| 881 | + public Single<Session> createSession( |
| 882 | + String appName, |
| 883 | + String userId, |
| 884 | + @Nullable ConcurrentMap<String, Object> state, |
| 885 | + @Nullable String sessionId) { |
| 886 | + throw new UnsupportedOperationException("createSession not supported in this test"); |
| 887 | + } |
| 888 | + |
| 889 | + @Override |
| 890 | + public Completable deleteSession(String appName, String userId, String sessionId) { |
| 891 | + throw new UnsupportedOperationException("deleteSession not supported in this test"); |
| 892 | + } |
| 893 | + } |
| 894 | + |
| 895 | + @Test |
| 896 | + public void runAsync_sessionService_appendsEventsInCorrectOrder() throws Exception { |
| 897 | + // Arrange |
| 898 | + String invocationId = UUID.randomUUID().toString(); |
| 899 | + Event ev1 = |
| 900 | + Event.builder() |
| 901 | + .id("1") |
| 902 | + .invocationId(invocationId) |
| 903 | + .author("model") |
| 904 | + .content(Optional.of(createContent("event 1"))) |
| 905 | + .build(); |
| 906 | + Event ev2 = |
| 907 | + Event.builder() |
| 908 | + .id("2") |
| 909 | + .invocationId(invocationId) |
| 910 | + .author("model") |
| 911 | + .content(Optional.of(createContent("event 2"))) |
| 912 | + .build(); |
| 913 | + TestAgent testAgent = new TestAgent("test agent", "description", () -> Flowable.just(ev1, ev2)); |
| 914 | + FakeSessionService laggingSessionService = |
| 915 | + new FakeSessionService( |
| 916 | + (event) -> { |
| 917 | + if (event.id().equals(ev1.id())) { |
| 918 | + try { |
| 919 | + // Lags completion of appendEvent() on the first event (id=1) from an agent, while |
| 920 | + // the second event (id=2) will be immediately appended. |
| 921 | + Thread.sleep(2000); |
| 922 | + } catch (InterruptedException e) { |
| 923 | + throw new RuntimeException(e); |
| 924 | + } |
| 925 | + } |
| 926 | + }); |
| 927 | + Runner runner = |
| 928 | + Runner.builder() |
| 929 | + .agent(testAgent) |
| 930 | + .appName("test") |
| 931 | + .sessionService(laggingSessionService) |
| 932 | + .build(); |
| 933 | + Session session = laggingSessionService.session; |
| 934 | + |
| 935 | + // Act |
| 936 | + var unused = |
| 937 | + runner.runAsync("user", session.id(), createContent("from user")).toList().blockingGet(); |
| 938 | + |
| 939 | + // Assert that events are stored in the correct order. |
| 940 | + assertThat(session.events()).hasSize(3); |
| 941 | + assertThat(session.events().get(1).id()).isEqualTo(ev1.id()); |
| 942 | + assertThat(session.events().get(2).id()).isEqualTo(ev2.id()); |
| 943 | + } |
794 | 944 | } |
0 commit comments