diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java index 21239db3..d35bbcca 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java @@ -9,12 +9,16 @@ import com.google.common.base.Splitter; import com.google.common.collect.Iterables; import com.google.genai.types.HttpOptions; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import java.io.IOException; import java.io.UncheckedIOException; import java.util.List; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; import okhttp3.ResponseBody; import org.slf4j.Logger; @@ -46,8 +50,7 @@ final class VertexAiClient { new HttpApiClient(Optional.of(project), Optional.of(location), credentials, httpOptions); } - @Nullable - JsonNode createSession( + Maybe createSession( String reasoningEngineId, String userId, ConcurrentMap state) { ConcurrentHashMap sessionJsonMap = new ConcurrentHashMap<>(); sessionJsonMap.put("userId", userId); @@ -55,95 +58,116 @@ JsonNode createSession( sessionJsonMap.put("sessionState", state); } - String sessId; - String operationId; - try { - String sessionJson = objectMapper.writeValueAsString(sessionJsonMap); - try (ApiResponse apiResponse = - apiClient.request( - "POST", "reasoningEngines/" + reasoningEngineId + "/sessions", sessionJson)) { - logger.debug("Create Session response {}", apiResponse.getResponseBody()); - if (apiResponse == null || apiResponse.getResponseBody() == null) { - return null; - } - - JsonNode jsonResponse = getJsonResponse(apiResponse); - if (jsonResponse == null) { - return null; - } - String sessionName = jsonResponse.get("name").asText(); - List parts = Splitter.on('/').splitToList(sessionName); - sessId = parts.get(parts.size() - 3); - operationId = Iterables.getLast(parts); - } - } catch (IOException e) { - throw new UncheckedIOException(e); - } + return Single.fromCallable(() -> objectMapper.writeValueAsString(sessionJsonMap)) + .flatMap( + sessionJson -> + performApiRequest( + "POST", "reasoningEngines/" + reasoningEngineId + "/sessions", sessionJson)) + .flatMapMaybe( + apiResponse -> { + logger.debug("Create Session response {}", apiResponse.getResponseBody()); + return getJsonResponse(apiResponse); + }) + .flatMap( + jsonResponse -> { + String sessionName = jsonResponse.get("name").asText(); + List parts = Splitter.on('/').splitToList(sessionName); + String sessId = parts.get(parts.size() - 3); + String operationId = Iterables.getLast(parts); + + return pollOperation(operationId, 0).andThen(getSession(reasoningEngineId, sessId)); + }); + } - for (int i = 0; i < MAX_RETRY_ATTEMPTS; i++) { - try (ApiResponse lroResponse = apiClient.request("GET", "operations/" + operationId, "")) { - JsonNode lroJsonResponse = getJsonResponse(lroResponse); - if (lroJsonResponse != null && lroJsonResponse.get("done") != null) { - break; - } - } - try { - SECONDS.sleep(1); - } catch (InterruptedException e) { - logger.warn("Error during sleep", e); - Thread.currentThread().interrupt(); - } + /** + * Polls the status of a long-running operation. + * + * @param operationId The ID of the operation to poll. + * @param attempt The current retry attempt number (starting from 0). + * @return A Completable that completes when the operation is done, or errors with + * TimeoutException if max retries are exceeded. + */ + private Completable pollOperation(String operationId, int attempt) { + if (attempt >= MAX_RETRY_ATTEMPTS) { + return Completable.error( + new TimeoutException("Operation " + operationId + " did not complete in time.")); } - return getSession(reasoningEngineId, sessId); + return performApiRequest("GET", "operations/" + operationId, "") + .flatMapMaybe(VertexAiClient::getJsonResponse) + .flatMapCompletable( + lroJsonResponse -> { + if (lroJsonResponse != null && lroJsonResponse.get("done") != null) { + return Completable.complete(); // Operation is done + } else { + // Not done, retry after a delay + return Completable.timer(1, SECONDS) + .andThen(pollOperation(operationId, attempt + 1)); + } + }); } - JsonNode listSessions(String reasoningEngineId, String userId) { - try (ApiResponse apiResponse = - apiClient.request( + Maybe listSessions(String reasoningEngineId, String userId) { + return performApiRequest( "GET", "reasoningEngines/" + reasoningEngineId + "/sessions?filter=user_id=" + userId, - "")) { - return getJsonResponse(apiResponse); - } + "") + .flatMapMaybe(VertexAiClient::getJsonResponse); } - JsonNode listEvents(String reasoningEngineId, String sessionId) { - try (ApiResponse apiResponse = - apiClient.request( + Maybe listEvents(String reasoningEngineId, String sessionId) { + return performApiRequest( "GET", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + "/events", - "")) { - logger.debug("List events response {}", apiResponse); - return getJsonResponse(apiResponse); - } + "") + .doOnSuccess(apiResponse -> logger.debug("List events response {}", apiResponse)) + .flatMapMaybe(VertexAiClient::getJsonResponse); } - JsonNode getSession(String reasoningEngineId, String sessionId) { - try (ApiResponse apiResponse = - apiClient.request( - "GET", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")) { - return getJsonResponse(apiResponse); - } + Maybe getSession(String reasoningEngineId, String sessionId) { + return performApiRequest( + "GET", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "") + .flatMapMaybe(apiResponse -> getJsonResponse(apiResponse)); } - void deleteSession(String reasoningEngineId, String sessionId) { - try (ApiResponse response = - apiClient.request( - "DELETE", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")) {} + Completable deleteSession(String reasoningEngineId, String sessionId) { + return performApiRequest( + "DELETE", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "") + .doOnSuccess(ApiResponse::close) + .ignoreElement(); } - void appendEvent(String reasoningEngineId, String sessionId, String eventJson) { - try (ApiResponse response = - apiClient.request( + Completable appendEvent(String reasoningEngineId, String sessionId, String eventJson) { + return performApiRequest( "POST", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + ":appendEvent", - eventJson)) { - if (response.getResponseBody().string().contains("com.google.genai.errors.ClientException")) { - logger.warn("Failed to append event: {}", eventJson); - } - } catch (IOException e) { - throw new UncheckedIOException(e); - } + eventJson) + .flatMapCompletable( + response -> { + try (response) { + ResponseBody responseBody = response.getResponseBody(); + if (responseBody != null) { + String responseString = responseBody.string(); + if (responseString.contains("com.google.genai.errors.ClientException")) { + logger.warn("Failed to append event: {}", eventJson); + } + } + return Completable.complete(); + } catch (IOException e) { + return Completable.error(new UncheckedIOException(e)); + } + }); + } + + /** + * Performs an API request and returns a Single emitting the ApiResponse. + * + *

Note: The caller is responsible for closing the returned {@link ApiResponse}. + */ + private Single performApiRequest(String method, String path, String body) { + return Single.fromCallable( + () -> { + return apiClient.request(method, path, body); + }); } /** @@ -152,19 +176,23 @@ void appendEvent(String reasoningEngineId, String sessionId, String eventJson) { * @throws UncheckedIOException if parsing fails. */ @Nullable - private static JsonNode getJsonResponse(ApiResponse apiResponse) { - if (apiResponse == null || apiResponse.getResponseBody() == null) { - return null; - } + private static Maybe getJsonResponse(ApiResponse apiResponse) { try { - ResponseBody responseBody = apiResponse.getResponseBody(); - String responseString = responseBody.string(); - if (responseString.isEmpty()) { - return null; + if (apiResponse == null || apiResponse.getResponseBody() == null) { + return Maybe.empty(); + } + try { + ResponseBody responseBody = apiResponse.getResponseBody(); + String responseString = responseBody.string(); // Read body here + if (responseString.isEmpty()) { + return Maybe.empty(); + } + return Maybe.just(objectMapper.readTree(responseString)); + } catch (IOException e) { + return Maybe.error(new UncheckedIOException(e)); } - return objectMapper.readTree(responseString); - } catch (IOException e) { - throw new UncheckedIOException(e); + } finally { + apiResponse.close(); } } } diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 2e0934be..0321ef28 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -78,11 +78,20 @@ public Single createSession( @Nullable String sessionId) { String reasoningEngineId = parseReasoningEngineId(appName); - JsonNode getSessionResponseMap = client.createSession(reasoningEngineId, userId, state); + return client + .createSession(reasoningEngineId, userId, state) + .map( + getSessionResponseMap -> + parseSession(getSessionResponseMap, appName, userId, sessionId)) + .toSingle(); + } + + private static Session parseSession( + JsonNode getSessionResponseMap, String appName, String userId, String fallbackSessionId) { String sessId = Optional.ofNullable(getSessionResponseMap.get("name")) .map(name -> Iterables.getLast(Splitter.on('/').splitToList(name.asText()))) - .orElse(sessionId); + .orElse(fallbackSessionId); Instant updateTimestamp = Instant.parse(getSessionResponseMap.get("updateTime").asText()); ConcurrentMap sessionState = null; if (getSessionResponseMap != null && getSessionResponseMap.has("sessionState")) { @@ -93,25 +102,28 @@ public Single createSession( sessionStateNode, new TypeReference>() {}); } } - return Single.just( - Session.builder(sessId) - .appName(appName) - .userId(userId) - .lastUpdateTime(updateTimestamp) - .state(sessionState == null ? new ConcurrentHashMap<>() : sessionState) - .build()); + return Session.builder(sessId) + .appName(appName) + .userId(userId) + .lastUpdateTime(updateTimestamp) + .state(sessionState == null ? new ConcurrentHashMap<>() : sessionState) + .build(); } @Override public Single listSessions(String appName, String userId) { String reasoningEngineId = parseReasoningEngineId(appName); - JsonNode listSessionsResponseMap = client.listSessions(reasoningEngineId, userId); + return client + .listSessions(reasoningEngineId, userId) + .map( + listSessionsResponseMap -> + parseListSessionsResponse(listSessionsResponseMap, appName, userId)) + .defaultIfEmpty(ListSessionsResponse.builder().build()); + } - // Handles empty response case - if (listSessionsResponseMap == null) { - return Single.just(ListSessionsResponse.builder().build()); - } + private ListSessionsResponse parseListSessionsResponse( + JsonNode listSessionsResponseMap, String appName, String userId) { List> apiSessions = objectMapper.convertValue( listSessionsResponseMap.get("sessions"), @@ -131,125 +143,132 @@ public Single listSessions(String appName, String userId) .build(); sessions.add(session); } - return Single.just(ListSessionsResponse.builder().sessions(sessions).build()); + return ListSessionsResponse.builder().sessions(sessions).build(); } @Override public Single listEvents(String appName, String userId, String sessionId) { String reasoningEngineId = parseReasoningEngineId(appName); - JsonNode listEventsResponse = client.listEvents(reasoningEngineId, sessionId); - - if (listEventsResponse == null) { - return Single.just(ListEventsResponse.builder().build()); - } + return client + .listEvents(reasoningEngineId, sessionId) + .map(this::parseListEventsResponse) + .defaultIfEmpty(ListEventsResponse.builder().build()); + } + private ListEventsResponse parseListEventsResponse(JsonNode listEventsResponse) { JsonNode sessionEventsNode = listEventsResponse.get("sessionEvents"); if (sessionEventsNode == null || sessionEventsNode.isEmpty()) { - return Single.just(ListEventsResponse.builder().events(new ArrayList<>()).build()); + return ListEventsResponse.builder().events(new ArrayList<>()).build(); } - return Single.just( - ListEventsResponse.builder() - .events( - objectMapper - .convertValue( - sessionEventsNode, - new TypeReference>>() {}) - .stream() - .map(SessionJsonConverter::fromApiEvent) - .collect(toCollection(ArrayList::new))) - .build()); + return ListEventsResponse.builder() + .events( + objectMapper + .convertValue( + sessionEventsNode, new TypeReference>>() {}) + .stream() + .map(SessionJsonConverter::fromApiEvent) + .collect(toCollection(ArrayList::new))) + .build(); } @Override public Maybe getSession( String appName, String userId, String sessionId, Optional config) { String reasoningEngineId = parseReasoningEngineId(appName); - JsonNode getSessionResponseMap = client.getSession(reasoningEngineId, sessionId); + return client + .getSession(reasoningEngineId, sessionId) + .flatMap( + getSessionResponseMap -> { + String sessId = + Optional.ofNullable(getSessionResponseMap.get("name")) + .map(name -> Iterables.getLast(Splitter.on('/').splitToList(name.asText()))) + .orElse(sessionId); + Instant updateTimestamp = + Optional.ofNullable(getSessionResponseMap.get("updateTime")) + .map(updateTime -> Instant.parse(updateTime.asText())) + .orElse(null); - if (getSessionResponseMap == null) { - return Maybe.empty(); - } - - String sessId = - Optional.ofNullable(getSessionResponseMap.get("name")) - .map(name -> Iterables.getLast(Splitter.on('/').splitToList(name.asText()))) - .orElse(sessionId); - Instant updateTimestamp = - Optional.ofNullable(getSessionResponseMap.get("updateTime")) - .map(updateTime -> Instant.parse(updateTime.asText())) - .orElse(null); + ConcurrentMap sessionState = new ConcurrentHashMap<>(); + if (getSessionResponseMap != null && getSessionResponseMap.has("sessionState")) { + sessionState.putAll( + objectMapper.convertValue( + getSessionResponseMap.get("sessionState"), + new TypeReference>() {})); + } - ConcurrentMap sessionState = new ConcurrentHashMap<>(); - if (getSessionResponseMap != null && getSessionResponseMap.has("sessionState")) { - sessionState.putAll( - objectMapper.convertValue( - getSessionResponseMap.get("sessionState"), - new TypeReference>() {})); - } + return listEvents(appName, userId, sessionId) + .map( + response -> { + Session.Builder sessionBuilder = + Session.builder(sessId) + .appName(appName) + .userId(userId) + .lastUpdateTime(updateTimestamp) + .state(sessionState); + List events = response.events(); + if (events.isEmpty()) { + return sessionBuilder.build(); + } + events = filterEvents(events, updateTimestamp, config); + return sessionBuilder.events(events).build(); + }) + .toMaybe(); + }); + } - return listEvents(appName, userId, sessionId) - .map( - response -> { - Session.Builder sessionBuilder = - Session.builder(sessId) - .appName(appName) - .userId(userId) - .lastUpdateTime(updateTimestamp) - .state(sessionState); - List events = response.events(); - if (events.isEmpty()) { - return sessionBuilder.build(); - } - events = - events.stream() - .filter( - event -> - updateTimestamp == null - || Instant.ofEpochMilli(event.timestamp()) - .isBefore(updateTimestamp)) - .sorted(Comparator.comparing(Event::timestamp)) - .collect(toCollection(ArrayList::new)); + private static List filterEvents( + List originalEvents, + @Nullable Instant updateTimestamp, + Optional config) { + List events = + originalEvents.stream() + .filter( + event -> + updateTimestamp == null + || Instant.ofEpochMilli(event.timestamp()).isBefore(updateTimestamp)) + .sorted(Comparator.comparing(Event::timestamp)) + .collect(toCollection(ArrayList::new)); - if (config.isPresent()) { - if (config.get().numRecentEvents().isPresent()) { - int numRecentEvents = config.get().numRecentEvents().get(); - if (events.size() > numRecentEvents) { - events = events.subList(events.size() - numRecentEvents, events.size()); - } - } else if (config.get().afterTimestamp().isPresent()) { - Instant afterTimestamp = config.get().afterTimestamp().get(); - int i = events.size() - 1; - while (i >= 0) { - if (Instant.ofEpochMilli(events.get(i).timestamp()).isBefore(afterTimestamp)) { - break; - } - i -= 1; - } - if (i >= 0) { - events = events.subList(i, events.size()); - } - } - } - return sessionBuilder.events(events).build(); - }) - .toMaybe(); + if (config.isPresent()) { + if (config.get().numRecentEvents().isPresent()) { + int numRecentEvents = config.get().numRecentEvents().get(); + if (events.size() > numRecentEvents) { + events = events.subList(events.size() - numRecentEvents, events.size()); + } + } else if (config.get().afterTimestamp().isPresent()) { + Instant afterTimestamp = config.get().afterTimestamp().get(); + int i = events.size() - 1; + while (i >= 0) { + if (Instant.ofEpochMilli(events.get(i).timestamp()).isBefore(afterTimestamp)) { + break; + } + i -= 1; + } + if (i >= 0) { + events = events.subList(i, events.size()); + } + } + } + return events; } @Override public Completable deleteSession(String appName, String userId, String sessionId) { String reasoningEngineId = parseReasoningEngineId(appName); - client.deleteSession(reasoningEngineId, sessionId); - return Completable.complete(); + return client.deleteSession(reasoningEngineId, sessionId); } @Override public Single appendEvent(Session session, Event event) { - BaseSessionService.super.appendEvent(session, event); - String reasoningEngineId = parseReasoningEngineId(session.appName()); - client.appendEvent( - reasoningEngineId, session.id(), SessionJsonConverter.convertEventToJson(event)); - return Single.just(event); + return BaseSessionService.super + .appendEvent(session, event) + .flatMap( + e -> + client + .appendEvent( + reasoningEngineId, session.id(), SessionJsonConverter.convertEventToJson(e)) + .toSingleDefault(e)); } /** diff --git a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java index 5e8f3d99..bd2dd2b9 100644 --- a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java +++ b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java @@ -8,6 +8,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Random; import java.util.concurrent.ConcurrentMap; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -32,6 +33,7 @@ class MockApiAnswer implements Answer { Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+)/events$"); private static final MediaType JSON_MEDIA_TYPE = MediaType.parse("application/json; charset=utf-8"); + private static final Random random = new Random(); private final Map sessionMap; private final Map eventMap; @@ -193,15 +195,18 @@ private ApiResponse handleGetEvents(String path) throws Exception { } private ApiResponse handleGetLro(String path) { + // Simulate LRO being done 50% of the time. + boolean done = random.nextBoolean(); + String doneStr = done ? ", \"done\": true" : ""; return responseWithBody( String.format( """ { - "name": "%s", - "done": true + "name": "%s" + %s } """, - path.replace("/operations/111", ""))); // Simulate LRO done + path.replace("/operations/111", ""), doneStr)); } private ApiResponse handleDeleteSession(String path) {