diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index fe0306fa..8d70f6c4 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -18,8 +18,10 @@ import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.exceptions.LlmCallsLimitExceededException; +import com.google.adk.memory.BaseMemoryService; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; +import com.google.errorprone.annotations.InlineMe; import com.google.genai.types.Content; import java.util.Map; import java.util.Objects; @@ -33,6 +35,7 @@ public class InvocationContext { private final BaseSessionService sessionService; private final BaseArtifactService artifactService; + private final BaseMemoryService memoryService; private final Optional liveRequestQueue; private final Map activeStreamingTools = new ConcurrentHashMap<>(); @@ -46,9 +49,10 @@ public class InvocationContext { private boolean endInvocation; private final InvocationCostManager invocationCostManager = new InvocationCostManager(); - private InvocationContext( + public InvocationContext( BaseSessionService sessionService, BaseArtifactService artifactService, + BaseMemoryService memoryService, Optional liveRequestQueue, Optional branch, String invocationId, @@ -59,6 +63,7 @@ private InvocationContext( boolean endInvocation) { this.sessionService = sessionService; this.artifactService = artifactService; + this.memoryService = memoryService; this.liveRequestQueue = liveRequestQueue; this.branch = branch; this.invocationId = invocationId; @@ -69,6 +74,16 @@ private InvocationContext( this.endInvocation = endInvocation; } + /** + * @deprecated Use the {@link #InvocationContext} constructor directly instead + */ + @InlineMe( + replacement = + "new InvocationContext(sessionService, artifactService, null, Optional.empty()," + + " Optional.empty(), invocationId, agent, session, Optional.ofNullable(userContent)," + + " runConfig, false)", + imports = {"com.google.adk.agents.InvocationContext", "java.util.Optional"}) + @Deprecated public static InvocationContext create( BaseSessionService sessionService, BaseArtifactService artifactService, @@ -80,7 +95,8 @@ public static InvocationContext create( return new InvocationContext( sessionService, artifactService, - Optional.empty(), + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), /* branch= */ Optional.empty(), invocationId, agent, @@ -90,6 +106,17 @@ public static InvocationContext create( false); } + /** + * @deprecated Use the {@link #InvocationContext} constructor directly instead + */ + @InlineMe( + replacement = + "new InvocationContext(sessionService, artifactService, null," + + " Optional.ofNullable(liveRequestQueue), Optional.empty()," + + " InvocationContext.newInvocationContextId(), agent, session, Optional.empty()," + + " runConfig, false)", + imports = {"com.google.adk.agents.InvocationContext", "java.util.Optional"}) + @Deprecated public static InvocationContext create( BaseSessionService sessionService, BaseArtifactService artifactService, @@ -100,6 +127,7 @@ public static InvocationContext create( return new InvocationContext( sessionService, artifactService, + /* memoryService= */ null, Optional.ofNullable(liveRequestQueue), /* branch= */ Optional.empty(), InvocationContext.newInvocationContextId(), @@ -115,6 +143,7 @@ public static InvocationContext copyOf(InvocationContext other) { new InvocationContext( other.sessionService, other.artifactService, + other.memoryService, other.liveRequestQueue, other.branch, other.invocationId, @@ -135,6 +164,10 @@ public BaseArtifactService artifactService() { return artifactService; } + public BaseMemoryService memoryService() { + return memoryService; + } + public Map activeStreamingTools() { return activeStreamingTools; } @@ -226,6 +259,7 @@ public boolean equals(Object o) { return endInvocation == that.endInvocation && Objects.equals(sessionService, that.sessionService) && Objects.equals(artifactService, that.artifactService) + && Objects.equals(memoryService, that.memoryService) && Objects.equals(liveRequestQueue, that.liveRequestQueue) && Objects.equals(activeStreamingTools, that.activeStreamingTools) && Objects.equals(branch, that.branch) @@ -241,6 +275,7 @@ public int hashCode() { return Objects.hash( sessionService, artifactService, + memoryService, liveRequestQueue, activeStreamingTools, branch, diff --git a/core/src/main/java/com/google/adk/memory/InMemoryMemoryService.java b/core/src/main/java/com/google/adk/memory/InMemoryMemoryService.java index 679d0359..caa5075f 100644 --- a/core/src/main/java/com/google/adk/memory/InMemoryMemoryService.java +++ b/core/src/main/java/com/google/adk/memory/InMemoryMemoryService.java @@ -118,9 +118,9 @@ public Single searchMemory(String appName, String userId, if (!Collections.disjoint(wordsInQuery, wordsInEvent)) { MemoryEntry memory = MemoryEntry.builder() - .setContent(event.content().get()) - .setAuthor(event.author()) - .setTimestamp(formatTimestamp(event.timestamp())) + .content(event.content().get()) + .author(event.author()) + .timestamp(formatTimestamp(event.timestamp())) .build(); matchingMemories.add(memory); } diff --git a/core/src/main/java/com/google/adk/memory/MemoryEntry.java b/core/src/main/java/com/google/adk/memory/MemoryEntry.java index e9afb0fb..ef310f62 100644 --- a/core/src/main/java/com/google/adk/memory/MemoryEntry.java +++ b/core/src/main/java/com/google/adk/memory/MemoryEntry.java @@ -16,6 +16,9 @@ package com.google.adk.memory; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.google.auto.value.AutoValue; import com.google.genai.types.Content; import java.time.Instant; @@ -23,13 +26,16 @@ /** Represents one memory entry. */ @AutoValue +@JsonDeserialize(builder = MemoryEntry.Builder.class) public abstract class MemoryEntry { /** Returns the main content of the memory. */ + @JsonProperty("content") public abstract Content content(); /** Returns the author of the memory, or null if not set. */ @Nullable + @JsonProperty("author") public abstract String author(); /** @@ -56,18 +62,26 @@ public static Builder builder() { @AutoValue.Builder public abstract static class Builder { + @JsonCreator + static Builder create() { + return new AutoValue_MemoryEntry.Builder(); + } + /** * Sets the main content of the memory. * *

This is a required field. */ - public abstract Builder setContent(Content content); + @JsonProperty("content") + public abstract Builder content(Content content); /** Sets the author of the memory. */ - public abstract Builder setAuthor(@Nullable String author); + @JsonProperty("author") + public abstract Builder author(@Nullable String author); /** Sets the timestamp when the original content of this memory happened. */ - public abstract Builder setTimestamp(@Nullable String timestamp); + @JsonProperty("timestamp") + public abstract Builder timestamp(@Nullable String timestamp); /** * A convenience method to set the timestamp from an {@link Instant} object, formatted as an ISO @@ -75,8 +89,8 @@ public abstract static class Builder { * * @param instant The timestamp as an Instant object. */ - public Builder setTimestamp(Instant instant) { - return setTimestamp(instant.toString()); + public Builder timestamp(Instant instant) { + return timestamp(instant.toString()); } /** Builds the immutable {@link MemoryEntry} object. */ diff --git a/core/src/main/java/com/google/adk/runner/InMemoryRunner.java b/core/src/main/java/com/google/adk/runner/InMemoryRunner.java index e7cb7b65..38f21a9f 100644 --- a/core/src/main/java/com/google/adk/runner/InMemoryRunner.java +++ b/core/src/main/java/com/google/adk/runner/InMemoryRunner.java @@ -18,6 +18,7 @@ import com.google.adk.agents.BaseAgent; import com.google.adk.artifacts.InMemoryArtifactService; +import com.google.adk.memory.InMemoryMemoryService; import com.google.adk.sessions.InMemorySessionService; /** The class for the in-memory GenAi runner, using in-memory artifact and session services. */ @@ -30,6 +31,11 @@ public InMemoryRunner(BaseAgent agent) { } public InMemoryRunner(BaseAgent agent, String appName) { - super(agent, appName, new InMemoryArtifactService(), new InMemorySessionService()); + super( + agent, + appName, + new InMemoryArtifactService(), + new InMemorySessionService(), + new InMemoryMemoryService()); } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 8308fb50..fae1a57f 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -25,12 +25,14 @@ import com.google.adk.agents.RunConfig; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; +import com.google.adk.memory.BaseMemoryService; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.adk.tools.BaseTool; import com.google.adk.tools.FunctionTool; import com.google.adk.utils.CollectionUtils; import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.InlineMe; import com.google.genai.types.AudioTranscriptionConfig; import com.google.genai.types.Content; import com.google.genai.types.Modality; @@ -53,17 +55,36 @@ public class Runner { private final String appName; private final BaseArtifactService artifactService; private final BaseSessionService sessionService; + private final BaseMemoryService memoryService; /** Creates a new {@code Runner}. */ public Runner( BaseAgent agent, String appName, BaseArtifactService artifactService, - BaseSessionService sessionService) { + BaseSessionService sessionService, + BaseMemoryService memoryService) { this.agent = agent; this.appName = appName; this.artifactService = artifactService; this.sessionService = sessionService; + this.memoryService = memoryService; + } + + /** + * Creates a new {@code Runner}. + * + * @deprecated Use the constructor with {@code BaseMemoryService} instead even if with a null if + * you don't need the memory service. + */ + @InlineMe(replacement = "this(agent, appName, artifactService, sessionService, null)") + @Deprecated + public Runner( + BaseAgent agent, + String appName, + BaseArtifactService artifactService, + BaseSessionService sessionService) { + this(agent, appName, artifactService, sessionService, null); } public BaseAgent agent() { @@ -82,6 +103,10 @@ public BaseSessionService sessionService() { return this.sessionService; } + public BaseMemoryService memoryService() { + return this.memoryService; + } + /** * Appends a new user message to the session history. * @@ -185,13 +210,10 @@ public Flowable runAsync(Session session, Content newMessage, RunConfig r sess -> { BaseAgent rootAgent = this.agent; InvocationContext invocationContext = - InvocationContext.create( - this.sessionService, - this.artifactService, - InvocationContext.newInvocationContextId(), - rootAgent, + newInvocationContext( sess, - newMessage, + Optional.of(newMessage), + /* liveRequestQueue= */ Optional.empty(), runConfig); if (newMessage != null) { @@ -240,7 +262,8 @@ private InvocationContext newInvocationContextForLive( } } } - return newInvocationContext(session, liveRequestQueue, runConfigBuilder.build()); + return newInvocationContext( + session, /* newMessage= */ Optional.empty(), liveRequestQueue, runConfigBuilder.build()); } /** @@ -249,16 +272,24 @@ private InvocationContext newInvocationContextForLive( * @return a new {@link InvocationContext}. */ private InvocationContext newInvocationContext( - Session session, Optional liveRequestQueue, RunConfig runConfig) { + Session session, + Optional newMessage, + Optional liveRequestQueue, + RunConfig runConfig) { BaseAgent rootAgent = this.agent; InvocationContext invocationContext = - InvocationContext.create( + new InvocationContext( this.sessionService, this.artifactService, + this.memoryService, + liveRequestQueue, + /* branch= */ Optional.empty(), + InvocationContext.newInvocationContextId(), rootAgent, session, - liveRequestQueue.orElse(null), - runConfig); + newMessage, + runConfig, + /* endInvocation= */ false); invocationContext.agent(this.findAgentToRun(session, rootAgent)); return invocationContext; } diff --git a/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java b/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java index 3e012ec1..45fbf80b 100644 --- a/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java +++ b/core/src/main/java/com/google/adk/tools/FunctionCallingUtils.java @@ -19,8 +19,11 @@ import static com.google.common.base.Preconditions.checkArgument; import com.fasterxml.jackson.databind.BeanDescription; +import com.fasterxml.jackson.databind.JavaType; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.introspect.AnnotatedMember; import com.fasterxml.jackson.databind.introspect.BeanPropertyDefinition; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; import com.google.adk.JsonBaseModel; import com.google.common.base.Strings; import com.google.genai.types.FunctionDeclaration; @@ -36,39 +39,39 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** Utility class for function calling. */ public final class FunctionCallingUtils { - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final ObjectMapper OBJECT_MAPPER = + new ObjectMapper().registerModule(new Jdk8Module()); private static final Logger logger = LoggerFactory.getLogger(FunctionCallingUtils.class); /** Holds the state during a single schema generation process to handle caching and recursion. */ private static class SchemaGenerationContext { - private final Map definitions = new LinkedHashMap<>(); - private final Set processingStack = new HashSet<>(); + private final Map definitions = new LinkedHashMap<>(); + private final Set processingStack = new HashSet<>(); - boolean isProcessing(Type type) { + boolean isProcessing(JavaType type) { return processingStack.contains(type); } - void startProcessing(Type type) { + void startProcessing(JavaType type) { processingStack.add(type); } - void finishProcessing(Type type) { + void finishProcessing(JavaType type) { processingStack.remove(type); } - Optional getDefinition(String name) { - return Optional.ofNullable(definitions.get(name)); + Optional getDefinition(JavaType type) { + return Optional.ofNullable(definitions.get(type)); } - void addDefinition(String name, Schema schema) { - definitions.put(name, schema); + void addDefinition(JavaType type, Schema schema) { + definitions.put(type, schema); } } @@ -156,130 +159,95 @@ private static Schema buildSchemaFromParameter(Parameter param) { * @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson. */ public static Schema buildSchemaFromType(Type type) { - return buildSchemaRecursive(type, new SchemaGenerationContext()); + return buildSchemaRecursive(OBJECT_MAPPER.constructType(type), new SchemaGenerationContext()); } /** * Recursively builds a Schema from a Java Type using a context to manage recursion and caching. * - * @param type The Java {@link Type} to convert. + * @param javaType The Java {@link JavaType} to convert. * @param context The {@link SchemaGenerationContext} for this generation task. * @return The generated {@link Schema}. * @throws IllegalArgumentException if a type is encountered that cannot be serialized by Jackson. */ - @SuppressWarnings("deprecation") // We don't have actual instances of the type - private static Schema buildSchemaRecursive(Type type, SchemaGenerationContext context) { - String definitionName = getTypeKey(type); - - if (definitionName != null) { - if (context.isProcessing(type)) { - logger.warn("Type {} is recursive. Omitting from schema.", type); - return Schema.builder() - .type("OBJECT") - .description("Recursive reference to " + definitionName + " omitted.") - .build(); - } - Optional cachedSchema = context.getDefinition(definitionName); - if (cachedSchema.isPresent()) { - return cachedSchema.get(); - } + private static Schema buildSchemaRecursive(JavaType javaType, SchemaGenerationContext context) { + if (context.isProcessing(javaType)) { + logger.warn("Type {} is recursive. Omitting from schema.", javaType.toCanonical()); + return Schema.builder() + .type("OBJECT") + .description("Recursive reference to " + javaType.toCanonical() + " omitted.") + .build(); + } + Optional cachedSchema = context.getDefinition(javaType); + if (cachedSchema.isPresent()) { + return cachedSchema.get(); } - context.startProcessing(type); + context.startProcessing(javaType); Schema resultSchema; try { Schema.Builder builder = Schema.builder(); - if (type instanceof ParameterizedType parameterizedType) { - Class rawClass = (Class) parameterizedType.getRawType(); - if (List.class.isAssignableFrom(rawClass)) { - Schema itemSchema = - buildSchemaRecursive(parameterizedType.getActualTypeArguments()[0], context); - builder.type("ARRAY").items(itemSchema); - } else if (Map.class.isAssignableFrom(rawClass)) { - builder.type("OBJECT"); - } else { - // Fallback for other parameterized types (e.g., custom generics) is to inspect the - // raw type. - return buildSchemaRecursive(rawClass, context); + Class rawClass = javaType.getRawClass(); + + if (javaType.isCollectionLikeType() && List.class.isAssignableFrom(rawClass)) { + builder.type("ARRAY").items(buildSchemaRecursive(javaType.getContentType(), context)); + } else if (javaType.isMapLikeType()) { + builder.type("OBJECT"); + } else if (String.class.equals(rawClass)) { + builder.type("STRING"); + } else if (Boolean.class.equals(rawClass) || boolean.class.equals(rawClass)) { + builder.type("BOOLEAN"); + } else if (Integer.class.equals(rawClass) || int.class.equals(rawClass)) { + builder.type("INTEGER"); + } else if (Double.class.equals(rawClass) + || double.class.equals(rawClass) + || Float.class.equals(rawClass) + || float.class.equals(rawClass) + || Long.class.equals(rawClass) + || long.class.equals(rawClass)) { + builder.type("NUMBER"); + } else if (rawClass.isEnum()) { + List enumValues = new ArrayList<>(); + for (Object enumConstant : rawClass.getEnumConstants()) { + enumValues.add(enumConstant.toString()); } - } else if (type instanceof Class clazz) { - if (clazz.isEnum()) { - builder.type("STRING"); - List enumValues = new ArrayList<>(); - for (Object enumConstant : clazz.getEnumConstants()) { - enumValues.add(enumConstant.toString()); - } - builder.enum_(enumValues); - } else if (String.class.equals(clazz)) { - builder.type("STRING"); - } else if (Boolean.class.equals(clazz) || boolean.class.equals(clazz)) { - builder.type("BOOLEAN"); - } else if (Integer.class.equals(clazz) || int.class.equals(clazz)) { - builder.type("INTEGER"); - } else if (Double.class.equals(clazz) - || double.class.equals(clazz) - || Float.class.equals(clazz) - || float.class.equals(clazz) - || Long.class.equals(clazz) - || long.class.equals(clazz)) { - builder.type("NUMBER"); - } else if (Map.class.isAssignableFrom(clazz)) { - builder.type("OBJECT"); - } else { - // Default to treating as a POJO. - if (!OBJECT_MAPPER.canSerialize(clazz)) { - throw new IllegalArgumentException( - "Unsupported type: " - + clazz.getName() - + ". The type must be a Jackson-serializable POJO or a registered" - + " primitive. Opaque types like Protobuf models are not supported" - + " directly."); - } - BeanDescription beanDescription = - OBJECT_MAPPER.getSerializationConfig().introspect(OBJECT_MAPPER.constructType(type)); - Map properties = new LinkedHashMap<>(); - for (BeanPropertyDefinition property : beanDescription.findProperties()) { - Type propertyType = property.getRawPrimaryType(); - if (propertyType == null) { - continue; + builder.enum_(enumValues).type("STRING").format("enum"); + } else { // POJO + if (!OBJECT_MAPPER.canSerialize(rawClass)) { + throw new IllegalArgumentException( + "Unsupported type: " + + rawClass.getName() + + ". The type must be a Jackson-serializable POJO or a registered" + + " primitive. Opaque types like Protobuf models are not supported" + + " directly."); + } + BeanDescription beanDescription = + OBJECT_MAPPER.getSerializationConfig().introspect(javaType); + Map properties = new LinkedHashMap<>(); + List required = new ArrayList<>(); + for (BeanPropertyDefinition property : beanDescription.findProperties()) { + AnnotatedMember member = property.getPrimaryMember(); + if (member != null) { + properties.put(property.getName(), buildSchemaRecursive(member.getType(), context)); + if (property.isRequired()) { + required.add(property.getName()); } - properties.put(property.getName(), buildSchemaRecursive(propertyType, context)); } - builder.type("OBJECT").properties(properties); + } + builder.type("OBJECT").properties(properties); + if (!required.isEmpty()) { + builder.required(required); } } resultSchema = builder.build(); } finally { - context.finishProcessing(type); + context.finishProcessing(javaType); } - if (definitionName != null) { - context.addDefinition(definitionName, resultSchema); - } + context.addDefinition(javaType, resultSchema); return resultSchema; } - /** - * Gets a stable, canonical name for a type to use as a key for caching and recursion tracking. - * - * @param type The type to name. - * @return The canonical name of the type, or null if the type should not be tracked (e.g., - * primitives). - */ - @Nullable - private static String getTypeKey(Type type) { - if (type instanceof Class clazz) { - if (clazz.isPrimitive() || clazz.isEnum() || clazz.getName().startsWith("java.")) { - return null; - } - return clazz.getCanonicalName(); - } - if (type instanceof ParameterizedType pType) { - return getTypeKey(pType.getRawType()); - } - return null; - } - private FunctionCallingUtils() {} } diff --git a/core/src/main/java/com/google/adk/tools/FunctionTool.java b/core/src/main/java/com/google/adk/tools/FunctionTool.java index 48b3b7e3..90166187 100644 --- a/core/src/main/java/com/google/adk/tools/FunctionTool.java +++ b/core/src/main/java/com/google/adk/tools/FunctionTool.java @@ -18,6 +18,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; import com.google.adk.agents.InvocationContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -41,7 +42,8 @@ /** FunctionTool implements a customized function calling tool. */ public class FunctionTool extends BaseTool { - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private static final ObjectMapper OBJECT_MAPPER = + new ObjectMapper().registerModule(new Jdk8Module()); private static final Logger logger = LoggerFactory.getLogger(FunctionTool.class); @Nullable private final Object instance; diff --git a/core/src/main/java/com/google/adk/tools/LoadMemoryResponse.java b/core/src/main/java/com/google/adk/tools/LoadMemoryResponse.java new file mode 100644 index 00000000..f91dae77 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/LoadMemoryResponse.java @@ -0,0 +1,24 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.adk.memory.MemoryEntry; +import java.util.List; + +/** The response from a load memory tool invocation. */ +public record LoadMemoryResponse(@JsonProperty("memories") List memories) {} diff --git a/core/src/main/java/com/google/adk/tools/LoadMemoryTool.java b/core/src/main/java/com/google/adk/tools/LoadMemoryTool.java new file mode 100644 index 00000000..638cd67c --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/LoadMemoryTool.java @@ -0,0 +1,54 @@ +package com.google.adk.tools; + +import com.google.adk.models.LlmRequest; +import com.google.common.collect.ImmutableList; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Single; +import java.lang.reflect.Method; + +/** + * A tool that loads memory for the current user. + * + *

NOTE: Currently this tool only uses text part from the memory. + */ +public class LoadMemoryTool extends FunctionTool { + + private static Method getLoadMemoryMethod() { + try { + return LoadMemoryTool.class.getMethod("loadMemory", String.class, ToolContext.class); + } catch (NoSuchMethodException e) { + throw new IllegalStateException("Failed to load memory method.", e); + } + } + + public LoadMemoryTool() { + super(null, getLoadMemoryMethod(), false); + } + + /** + * Loads the memory for the current user. + * + * @param query The query to load memory for. + * @return A list of memory results. + */ + public static Single loadMemory( + @Annotations.Schema(name = "query") String query, ToolContext toolContext) { + return toolContext + .searchMemory(query) + .map(searchMemoryResponse -> new LoadMemoryResponse(searchMemoryResponse.memories())); + } + + @Override + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return super.processLlmRequest(llmRequestBuilder, toolContext) + .doOnComplete( + () -> + llmRequestBuilder.appendInstructions( + ImmutableList.of( +""" +You have memory. You can use it to answer questions. If any questions need +you to look up the memory, you should call loadMemory function with a query. +"""))); + } +} diff --git a/core/src/main/java/com/google/adk/tools/ToolContext.java b/core/src/main/java/com/google/adk/tools/ToolContext.java index c9241582..f6a55431 100644 --- a/core/src/main/java/com/google/adk/tools/ToolContext.java +++ b/core/src/main/java/com/google/adk/tools/ToolContext.java @@ -19,7 +19,9 @@ import com.google.adk.agents.CallbackContext; import com.google.adk.agents.InvocationContext; import com.google.adk.events.EventActions; +import com.google.adk.memory.SearchMemoryResponse; import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.reactivex.rxjava3.core.Single; import java.util.Optional; /** ToolContext object provides a structured context for executing tools or functions. */ @@ -62,10 +64,15 @@ private void getAuthResponse() { throw new UnsupportedOperationException("Auth response retrieval not implemented yet."); } - @SuppressWarnings("unused") - private void searchMemory() { - // TODO: b/414680316 - Implement search memory logic. Make this public. - throw new UnsupportedOperationException("Search memory not implemented yet."); + /** Searches the memory of the current user. */ + public Single searchMemory(String query) { + if (invocationContext.memoryService() == null) { + throw new IllegalStateException("Memory service is not initialized."); + } + return invocationContext + .memoryService() + .searchMemory( + invocationContext.session().appName(), invocationContext.session().userId(), query); } public static Builder builder(InvocationContext invocationContext) { diff --git a/core/src/test/java/com/google/adk/agents/AgentWithMemoryTest.java b/core/src/test/java/com/google/adk/agents/AgentWithMemoryTest.java new file mode 100644 index 00000000..d943e0aa --- /dev/null +++ b/core/src/test/java/com/google/adk/agents/AgentWithMemoryTest.java @@ -0,0 +1,112 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.agents; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.runner.InMemoryRunner; +import com.google.adk.sessions.Session; +import com.google.adk.testing.TestLlm; +import com.google.adk.tools.LoadMemoryTool; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.Part; +import java.util.Optional; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class AgentWithMemoryTest { + @Test + public void agentRemembersUserNameWithMemoryTool() throws Exception { + Part functionCall = + Part.builder() + .functionCall( + FunctionCall.builder() + .name("loadMemory") + .args(ImmutableMap.of("query", "what is my name?")) + .build()) + .build(); + + TestLlm testLlm = + new TestLlm( + ImmutableList.of( + LlmResponse.builder() + .content( + Content.builder() + .parts(Part.fromText("OK, I'll remember that.")) + .role("test-agent") + .build()) + .build(), + LlmResponse.builder() + .content( + Content.builder() + .role("test-agent") + .parts(ImmutableList.of(functionCall)) + .build()) + .build(), + LlmResponse.builder() + .content( + Content.builder() + // we won't actually read the name from here since that'd be + // cheating. + .parts(Part.fromText("Your name is James.")) + .role("test-agent") + .build()) + .build())); + + LlmAgent agent = + LlmAgent.builder() + .name("test-agent") + .model(testLlm) + .tools(ImmutableList.of(new LoadMemoryTool())) + .build(); + + InMemoryRunner runner = new InMemoryRunner(agent); + Session session = runner.sessionService().createSession("test-app", "test-user").blockingGet(); + + Content firstMessage = Content.fromParts(Part.fromText("My name is James")); + + var unused = + runner.runAsync(session, firstMessage, RunConfig.builder().build()).toList().blockingGet(); + // Save the session so we can bring it up on the next request. + runner.memoryService().addSessionToMemory(session).blockingAwait(); + + Content secondMessage = Content.fromParts(Part.fromText("what is my name?")); + unused = + runner.runAsync(session, secondMessage, RunConfig.builder().build()).toList().blockingGet(); + + // Verify that the tool's response was included in the next LLM call. + LlmRequest lastRequest = testLlm.getLastRequest(); + Content functionResponseContent = Iterables.getLast(lastRequest.contents()); + Optional functionResponsePart = + functionResponseContent.parts().get().stream() + .filter(p -> p.functionResponse().isPresent()) + .findFirst(); + assertThat(functionResponsePart).isPresent(); + FunctionResponse functionResponse = functionResponsePart.get().functionResponse().get(); + assertThat(functionResponse.name()).hasValue("loadMemory"); + assertThat(functionResponse.response().get().toString()).contains("My name is James"); + } +} diff --git a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java index 30376e04..881dad24 100644 --- a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java +++ b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java @@ -20,11 +20,13 @@ import static org.mockito.Mockito.mock; import com.google.adk.artifacts.BaseArtifactService; +import com.google.adk.memory.BaseMemoryService; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; import com.google.genai.types.Content; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -37,6 +39,7 @@ public final class InvocationContextTest { @Mock private BaseSessionService mockSessionService; @Mock private BaseArtifactService mockArtifactService; + @Mock private BaseMemoryService mockMemoryService; @Mock private BaseAgent mockAgent; private Session session; private Content userContent; @@ -60,18 +63,23 @@ public void setUp() { @Test public void testCreateWithUserContent() { InvocationContext context = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); assertThat(context).isNotNull(); assertThat(context.sessionService()).isEqualTo(mockSessionService); assertThat(context.artifactService()).isEqualTo(mockArtifactService); + assertThat(context.memoryService()).isEqualTo(mockMemoryService); assertThat(context.liveRequestQueue()).isEmpty(); assertThat(context.invocationId()).isEqualTo(testInvocationId); assertThat(context.agent()).isEqualTo(mockAgent); @@ -84,14 +92,18 @@ public void testCreateWithUserContent() { @Test public void testCreateWithNullUserContent() { InvocationContext context = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - null, // Pass null for userContent - runConfig); + /* userContent= */ Optional.empty(), + runConfig, + /* endInvocation= */ false); assertThat(context).isNotNull(); assertThat(context.userContent()).isEmpty(); @@ -100,17 +112,23 @@ public void testCreateWithNullUserContent() { @Test public void testCreateWithLiveRequestQueue() { InvocationContext context = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + Optional.of(liveRequestQueue), + /* branch= */ Optional.empty(), + InvocationContext.newInvocationContextId(), mockAgent, session, - liveRequestQueue, - runConfig); + /* userContent= */ Optional.empty(), + runConfig, + /* endInvocation= */ false); assertThat(context).isNotNull(); assertThat(context.sessionService()).isEqualTo(mockSessionService); assertThat(context.artifactService()).isEqualTo(mockArtifactService); + assertThat(context.memoryService()).isEqualTo(mockMemoryService); assertThat(context.liveRequestQueue()).hasValue(liveRequestQueue); assertThat(context.invocationId()).startsWith("e-"); // Check format of generated ID assertThat(context.agent()).isEqualTo(mockAgent); @@ -123,14 +141,18 @@ public void testCreateWithLiveRequestQueue() { @Test public void testCopyOf() { InvocationContext originalContext = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); originalContext.activeStreamingTools().putAll(activeStreamingTools); InvocationContext copiedContext = InvocationContext.copyOf(originalContext); @@ -140,6 +162,7 @@ public void testCopyOf() { assertThat(copiedContext.sessionService()).isEqualTo(originalContext.sessionService()); assertThat(copiedContext.artifactService()).isEqualTo(originalContext.artifactService()); + assertThat(copiedContext.memoryService()).isEqualTo(originalContext.memoryService()); assertThat(copiedContext.liveRequestQueue()).isEqualTo(originalContext.liveRequestQueue()); assertThat(copiedContext.invocationId()).isEqualTo(originalContext.invocationId()); assertThat(copiedContext.agent()).isEqualTo(originalContext.agent()); @@ -154,17 +177,22 @@ public void testCopyOf() { @Test public void testGetters() { InvocationContext context = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); assertThat(context.sessionService()).isEqualTo(mockSessionService); assertThat(context.artifactService()).isEqualTo(mockArtifactService); + assertThat(context.memoryService()).isEqualTo(mockMemoryService); assertThat(context.liveRequestQueue()).isEmpty(); assertThat(context.invocationId()).isEqualTo(testInvocationId); assertThat(context.agent()).isEqualTo(mockAgent); @@ -177,14 +205,18 @@ public void testGetters() { @Test public void testSetAgent() { InvocationContext context = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); BaseAgent newMockAgent = mock(BaseAgent.class); context.agent(newMockAgent); @@ -207,14 +239,18 @@ public void testNewInvocationContextId() { @Test public void testEquals_sameObject() { InvocationContext context = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); assertThat(context.equals(context)).isTrue(); } @@ -222,14 +258,18 @@ public void testEquals_sameObject() { @Test public void testEquals_null() { InvocationContext context = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); assertThat(context.equals(null)).isFalse(); } @@ -237,25 +277,33 @@ public void testEquals_null() { @Test public void testEquals_sameValues() { InvocationContext context1 = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); // Create another context with the same parameters InvocationContext context2 = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); assertThat(context1.equals(context2)).isTrue(); assertThat(context2.equals(context1)).isTrue(); // Check symmetry @@ -264,64 +312,89 @@ public void testEquals_sameValues() { @Test public void testEquals_differentValues() { InvocationContext context = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); // Create contexts with one field different InvocationContext contextWithDiffSessionService = - InvocationContext.create( + new InvocationContext( mock(BaseSessionService.class), // Different mock mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); InvocationContext contextWithDiffInvocationId = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), "another-id", // Different ID mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); InvocationContext contextWithDiffAgent = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mock(BaseAgent.class), // Different mock session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); InvocationContext contextWithUserContentEmpty = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - null, // User content is null (Optional.empty) - runConfig); + /* userContent= */ Optional.empty(), + runConfig, + /* endInvocation= */ false); InvocationContext contextWithLiveQueuePresent = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + Optional.of(liveRequestQueue), + /* branch= */ Optional.empty(), + InvocationContext.newInvocationContextId(), mockAgent, session, - liveRequestQueue, // Live queue is present (Optional.of) - runConfig); + /* userContent= */ Optional.empty(), + runConfig, + /* endInvocation= */ false); assertThat(context.equals(contextWithDiffSessionService)).isFalse(); assertThat(context.equals(contextWithDiffInvocationId)).isFalse(); @@ -333,35 +406,47 @@ public void testEquals_differentValues() { @Test public void testHashCode_differentValues() { InvocationContext context = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); // Create contexts with one field different InvocationContext contextWithDiffSessionService = - InvocationContext.create( + new InvocationContext( mock(BaseSessionService.class), // Different mock mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), testInvocationId, mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); InvocationContext contextWithDiffInvocationId = - InvocationContext.create( + new InvocationContext( mockSessionService, mockArtifactService, + mockMemoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), "another-id", // Different ID mockAgent, session, - userContent, - runConfig); + Optional.of(userContent), + runConfig, + /* endInvocation= */ false); assertThat(context).isNotEqualTo(contextWithDiffSessionService); assertThat(context).isNotEqualTo(contextWithDiffInvocationId); diff --git a/core/src/test/java/com/google/adk/testing/TestUtils.java b/core/src/test/java/com/google/adk/testing/TestUtils.java index 0873aa9c..fe76eb96 100644 --- a/core/src/test/java/com/google/adk/testing/TestUtils.java +++ b/core/src/test/java/com/google/adk/testing/TestUtils.java @@ -27,6 +27,7 @@ import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.events.Event; import com.google.adk.events.EventActions; +import com.google.adk.memory.InMemoryMemoryService; import com.google.adk.models.BaseLlm; import com.google.adk.models.LlmResponse; import com.google.adk.sessions.InMemorySessionService; @@ -52,14 +53,18 @@ public final class TestUtils { public static InvocationContext createInvocationContext(BaseAgent agent, RunConfig runConfig) { InMemorySessionService sessionService = new InMemorySessionService(); - return InvocationContext.create( + return new InvocationContext( sessionService, new InMemoryArtifactService(), + new InMemoryMemoryService(), + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), "invocationId", agent, sessionService.createSession("test-app", "test-user").blockingGet(), - Content.fromParts(Part.fromText("user content")), - runConfig); + Optional.of(Content.fromParts(Part.fromText("user content"))), + runConfig, + /* endInvocation= */ false); } public static InvocationContext createInvocationContext(BaseAgent agent) { diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index 10fd512e..366983ba 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -34,6 +34,7 @@ import com.google.genai.types.Schema; import io.reactivex.rxjava3.core.Flowable; import java.util.Map; +import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -321,13 +322,18 @@ public void call_withoutInputSchema_requestIsSentToAgent() throws Exception { private static ToolContext createToolContext(LlmAgent agent) { return ToolContext.builder( - InvocationContext.create( + new InvocationContext( /* sessionService= */ null, /* artifactService= */ null, + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ InvocationContext.newInvocationContextId(), agent, Session.builder("123").build(), - /* liveRequestQueue= */ null, - /* runConfig= */ null)) + /* userContent= */ Optional.empty(), + /* runConfig= */ null, + /* endInvocation= */ false)) .build(); } } diff --git a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java index 30037849..7ef67db4 100644 --- a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java +++ b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java @@ -20,6 +20,7 @@ import static org.junit.Assert.assertThrows; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.RunConfig; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -32,6 +33,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -208,8 +210,18 @@ public void call_withAllSupportedParameterTypes() throws Exception { FunctionTool tool = FunctionTool.create(Functions.class, "returnAllSupportedParametersAsMap"); ToolContext toolContext = ToolContext.builder( - InvocationContext.create( - null, null, null, Session.builder("123").build(), null, null)) + new InvocationContext( + /* sessionService= */ null, + /* artifactService= */ null, + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ null, + /* agent= */ null, + /* session= */ Session.builder("123").build(), + /* userContent= */ Optional.empty(), + /* runConfig= */ RunConfig.builder().build(), + /* endInvocation= */ false)) .functionCallId("functionCallId") .build(); @@ -400,7 +412,7 @@ public void create_withRecursiveParam_avoidsInfiniteRecursion() { Schema.builder() .type("OBJECT") .description( - "Recursive reference to com.google.adk.tools.FunctionToolTest.Node" + "Recursive reference to com.google.adk.tools.FunctionToolTest$Node" + " omitted.") .build())) .build(); @@ -480,8 +492,18 @@ public void call_nonStaticWithAllSupportedParameterTypes() throws Exception { FunctionTool.create(functions, "nonStaticReturnAllSupportedParametersAsMap"); ToolContext toolContext = ToolContext.builder( - InvocationContext.create( - null, null, null, Session.builder("123").build(), null, null)) + new InvocationContext( + /* sessionService= */ null, + /* artifactService= */ null, + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ null, + /* agent= */ null, + /* session= */ Session.builder("123").build(), + /* userContent= */ Optional.empty(), + /* runConfig= */ null, + /* endInvocation= */ false)) .functionCallId("functionCallId") .build(); @@ -711,7 +733,7 @@ private record DiamondRight(DiamondBottom bottom) {} private record DiamondTop(DiamondLeft left, DiamondRight right) {} - private record ParametrizedCustomType(T value) {} + private record ParametrizedCustomType(T value) {} private record Node(String value, Node next) {} } diff --git a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java index 7e137c5b..c6d84bee 100644 --- a/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java +++ b/core/src/test/java/com/google/adk/tools/retrieval/VertexAiRagRetrievalTest.java @@ -25,6 +25,7 @@ import com.google.genai.types.VertexRagStore; import com.google.genai.types.VertexRagStoreRagResource; import java.util.Map; +import java.util.Optional; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -55,8 +56,18 @@ public void runAsync_withResults_returnsContexts() throws Exception { String query = "test query"; ToolContext toolContext = ToolContext.builder( - InvocationContext.create( - null, null, null, Session.builder("123").build(), null, null)) + new InvocationContext( + /* sessionService= */ null, + /* artifactService= */ null, + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ null, + /* agent= */ null, + Session.builder("123").build(), + /* userContent= */ null, + /* runConfig= */ null, + /* endInvocation= */ false)) .functionCallId("functionCallId") .build(); RetrieveContextsRequest expectedRequest = @@ -100,8 +111,18 @@ public void runAsync_noResults_returnsNoResultFoundMessage() throws Exception { String query = "test query"; ToolContext toolContext = ToolContext.builder( - InvocationContext.create( - null, null, null, Session.builder("123").build(), null, null)) + new InvocationContext( + /* sessionService= */ null, + /* artifactService= */ null, + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ null, + /* agent= */ null, + Session.builder("123").build(), + /* userContent= */ Optional.empty(), + /* runConfig= */ null, + /* endInvocation= */ false)) .functionCallId("functionCallId") .build(); RetrieveContextsRequest expectedRequest = @@ -147,8 +168,18 @@ public void processLlmRequest_gemini2Model_addVertexRagStoreToConfig() { LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-2-pro"); ToolContext toolContext = ToolContext.builder( - InvocationContext.create( - null, null, null, Session.builder("123").build(), null, null)) + new InvocationContext( + /* sessionService= */ null, + /* artifactService= */ null, + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ null, + /* agent= */ null, + Session.builder("123").build(), + /* userContent= */ Optional.empty(), + /* runConfig= */ null, + /* endInvocation= */ false)) .functionCallId("functionCallId") .build(); @@ -214,8 +245,18 @@ public void processLlmRequest_otherModel_doNotAddVertexRagStoreToConfig() { LlmRequest.Builder llmRequestBuilder = LlmRequest.builder().model("gemini-1-pro"); ToolContext toolContext = ToolContext.builder( - InvocationContext.create( - null, null, null, Session.builder("123").build(), null, null)) + new InvocationContext( + /* sessionService= */ null, + /* artifactService= */ null, + /* memoryService= */ null, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), + /* invocationId= */ null, + /* agent= */ null, + Session.builder("123").build(), + /* userContent= */ Optional.empty(), + /* runConfig= */ null, + /* endInvocation= */ false)) .functionCallId("functionCallId") .build(); GenerateContentConfig initialConfig = GenerateContentConfig.builder().build(); diff --git a/core/src/test/java/com/google/adk/utils/InstructionUtilsTest.java b/core/src/test/java/com/google/adk/utils/InstructionUtilsTest.java index b0a3b032..70a42e6d 100644 --- a/core/src/test/java/com/google/adk/utils/InstructionUtilsTest.java +++ b/core/src/test/java/com/google/adk/utils/InstructionUtilsTest.java @@ -7,11 +7,13 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.agents.RunConfig; import com.google.adk.artifacts.InMemoryArtifactService; +import com.google.adk.memory.InMemoryMemoryService; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.State; import com.google.genai.types.Content; import com.google.genai.types.Part; +import java.util.Optional; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -23,20 +25,26 @@ public final class InstructionUtilsTest { private InvocationContext templateContext; private InMemorySessionService sessionService; private InMemoryArtifactService artifactService; + private InMemoryMemoryService memoryService; @Before public void setUp() { sessionService = new InMemorySessionService(); artifactService = new InMemoryArtifactService(); + memoryService = new InMemoryMemoryService(); templateContext = - InvocationContext.create( + new InvocationContext( sessionService, artifactService, + memoryService, + /* liveRequestQueue= */ Optional.empty(), + /* branch= */ Optional.empty(), "invocationId", createRootAgent(), sessionService.createSession("test-app", "test-user").blockingGet(), - Content.fromParts(), - RunConfig.builder().build()); + Optional.of(Content.fromParts()), + RunConfig.builder().build(), + /* endInvocation= */ false); } @Test diff --git a/dev/src/main/java/com/google/adk/web/AdkWebServer.java b/dev/src/main/java/com/google/adk/web/AdkWebServer.java index f8e1775b..679729c2 100644 --- a/dev/src/main/java/com/google/adk/web/AdkWebServer.java +++ b/dev/src/main/java/com/google/adk/web/AdkWebServer.java @@ -30,6 +30,8 @@ import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.artifacts.ListArtifactsResponse; import com.google.adk.events.Event; +import com.google.adk.memory.BaseMemoryService; +import com.google.adk.memory.InMemoryMemoryService; import com.google.adk.runner.Runner; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.InMemorySessionService; @@ -150,6 +152,18 @@ public BaseArtifactService artifactService() { return new InMemoryArtifactService(); } + /** + * Provides the singleton instance of the MemoryService (InMemory). Will be configurable once the + * Vertex MemoryService is available. + * + * @return An instance of BaseMemoryService (currently InMemoryMemoryService). + */ + @Bean + public BaseMemoryService memoryService() { + log.info("Using InMemoryMemoryService"); + return new InMemoryMemoryService(); + } + @Bean("loadedAgentRegistry") public Map loadedAgentRegistry( AgentLoadingProperties props, RunnerService runnerService) { @@ -200,16 +214,19 @@ public static class RunnerService { private final Map agentRegistry; private final BaseArtifactService artifactService; private final BaseSessionService sessionService; + private final BaseMemoryService memoryService; private final Map runnerCache = new ConcurrentHashMap<>(); @Autowired public RunnerService( @Lazy @Qualifier("loadedAgentRegistry") Map agentRegistry, BaseArtifactService artifactService, - BaseSessionService sessionService) { + BaseSessionService sessionService, + BaseMemoryService memoryService) { this.agentRegistry = agentRegistry; this.artifactService = artifactService; this.sessionService = sessionService; + this.memoryService = memoryService; } /** @@ -236,7 +253,8 @@ public Runner getRunner(String appName) { "RunnerService: Creating Runner for appName: {}, using agent" + " definition: {}", appName, agent.name()); - return new Runner(agent, appName, this.artifactService, this.sessionService); + return new Runner( + agent, appName, this.artifactService, this.sessionService, this.memoryService); }); } diff --git a/maven_plugin/src/main/java/com/google/adk/maven/web/AdkWebServer.java b/maven_plugin/src/main/java/com/google/adk/maven/web/AdkWebServer.java index b96e6087..9fda91e0 100644 --- a/maven_plugin/src/main/java/com/google/adk/maven/web/AdkWebServer.java +++ b/maven_plugin/src/main/java/com/google/adk/maven/web/AdkWebServer.java @@ -33,6 +33,8 @@ import com.google.adk.artifacts.ListArtifactsResponse; import com.google.adk.events.Event; import com.google.adk.maven.AgentLoader; +import com.google.adk.memory.BaseMemoryService; +import com.google.adk.memory.InMemoryMemoryService; import com.google.adk.runner.Runner; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.InMemorySessionService; @@ -148,6 +150,18 @@ public BaseArtifactService artifactService() { return new InMemoryArtifactService(); } + /** + * Provides the singleton instance of the MemoryService (InMemory). Will be made configurable once + * we have the Vertex MemoryService. + * + * @return An instance of BaseMemoryService (currently InMemoryMemoryService). + */ + @Bean + public BaseMemoryService memoryService() { + log.info("Using InMemoryMemoryService"); + return new InMemoryMemoryService(); + } + @Bean public ObjectMapper objectMapper() { return JsonBaseModel.getMapper(); @@ -161,16 +175,19 @@ public static class RunnerService { private final AgentLoader agentProvider; private final BaseArtifactService artifactService; private final BaseSessionService sessionService; + private final BaseMemoryService memoryService; private final Map runnerCache = new ConcurrentHashMap<>(); @Autowired public RunnerService( @Qualifier("agentLoader") AgentLoader agentProvider, BaseArtifactService artifactService, - BaseSessionService sessionService) { + BaseSessionService sessionService, + BaseMemoryService memoryService) { this.agentProvider = agentProvider; this.artifactService = artifactService; this.sessionService = sessionService; + this.memoryService = memoryService; } /** @@ -190,7 +207,8 @@ public Runner getRunner(String appName) { "RunnerService: Creating Runner for appName: {}, using agent" + " definition: {}", appName, agent.name()); - return new Runner(agent, appName, this.artifactService, this.sessionService); + return new Runner( + agent, appName, this.artifactService, this.sessionService, this.memoryService); } catch (java.util.NoSuchElementException e) { log.error( "Agent/App named '{}' not found in registry. Available apps: {}",