Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions core/src/main/java/com/google/adk/agents/InvocationContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import com.google.adk.artifacts.BaseArtifactService;
import com.google.adk.exceptions.LlmCallsLimitExceededException;
import com.google.adk.memory.BaseMemoryService;
import com.google.adk.sessions.BaseSessionService;
import com.google.adk.sessions.Session;
import com.google.errorprone.annotations.InlineMe;
import com.google.genai.types.Content;
import java.util.Map;
import java.util.Objects;
Expand All @@ -33,6 +35,7 @@ public class InvocationContext {

private final BaseSessionService sessionService;
private final BaseArtifactService artifactService;
private final BaseMemoryService memoryService;
private final Optional<LiveRequestQueue> liveRequestQueue;
private final Map<String, ActiveStreamingTool> activeStreamingTools = new ConcurrentHashMap<>();

Expand All @@ -46,9 +49,10 @@ public class InvocationContext {
private boolean endInvocation;
private final InvocationCostManager invocationCostManager = new InvocationCostManager();

private InvocationContext(
public InvocationContext(
BaseSessionService sessionService,
BaseArtifactService artifactService,
BaseMemoryService memoryService,
Optional<LiveRequestQueue> liveRequestQueue,
Optional<String> branch,
String invocationId,
Expand All @@ -59,6 +63,7 @@ private InvocationContext(
boolean endInvocation) {
this.sessionService = sessionService;
this.artifactService = artifactService;
this.memoryService = memoryService;
this.liveRequestQueue = liveRequestQueue;
this.branch = branch;
this.invocationId = invocationId;
Expand All @@ -69,6 +74,16 @@ private InvocationContext(
this.endInvocation = endInvocation;
}

/**
* @deprecated Use the {@link #InvocationContext} constructor directly instead
*/
@InlineMe(
replacement =
"new InvocationContext(sessionService, artifactService, null, Optional.empty(),"
+ " Optional.empty(), invocationId, agent, session, Optional.ofNullable(userContent),"
+ " runConfig, false)",
imports = {"com.google.adk.agents.InvocationContext", "java.util.Optional"})
@Deprecated
public static InvocationContext create(
BaseSessionService sessionService,
BaseArtifactService artifactService,
Expand All @@ -80,7 +95,8 @@ public static InvocationContext create(
return new InvocationContext(
sessionService,
artifactService,
Optional.empty(),
/* memoryService= */ null,
/* liveRequestQueue= */ Optional.empty(),
/* branch= */ Optional.empty(),
invocationId,
agent,
Expand All @@ -90,6 +106,17 @@ public static InvocationContext create(
false);
}

/**
* @deprecated Use the {@link #InvocationContext} constructor directly instead
*/
@InlineMe(
replacement =
"new InvocationContext(sessionService, artifactService, null,"
+ " Optional.ofNullable(liveRequestQueue), Optional.empty(),"
+ " InvocationContext.newInvocationContextId(), agent, session, Optional.empty(),"
+ " runConfig, false)",
imports = {"com.google.adk.agents.InvocationContext", "java.util.Optional"})
@Deprecated
public static InvocationContext create(
BaseSessionService sessionService,
BaseArtifactService artifactService,
Expand All @@ -100,6 +127,7 @@ public static InvocationContext create(
return new InvocationContext(
sessionService,
artifactService,
/* memoryService= */ null,
Optional.ofNullable(liveRequestQueue),
/* branch= */ Optional.empty(),
InvocationContext.newInvocationContextId(),
Expand All @@ -115,6 +143,7 @@ public static InvocationContext copyOf(InvocationContext other) {
new InvocationContext(
other.sessionService,
other.artifactService,
other.memoryService,
other.liveRequestQueue,
other.branch,
other.invocationId,
Expand All @@ -135,6 +164,10 @@ public BaseArtifactService artifactService() {
return artifactService;
}

public BaseMemoryService memoryService() {
return memoryService;
}

public Map<String, ActiveStreamingTool> activeStreamingTools() {
return activeStreamingTools;
}
Expand Down Expand Up @@ -226,6 +259,7 @@ public boolean equals(Object o) {
return endInvocation == that.endInvocation
&& Objects.equals(sessionService, that.sessionService)
&& Objects.equals(artifactService, that.artifactService)
&& Objects.equals(memoryService, that.memoryService)
&& Objects.equals(liveRequestQueue, that.liveRequestQueue)
&& Objects.equals(activeStreamingTools, that.activeStreamingTools)
&& Objects.equals(branch, that.branch)
Expand All @@ -241,6 +275,7 @@ public int hashCode() {
return Objects.hash(
sessionService,
artifactService,
memoryService,
liveRequestQueue,
activeStreamingTools,
branch,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ public Single<SearchMemoryResponse> searchMemory(String appName, String userId,
if (!Collections.disjoint(wordsInQuery, wordsInEvent)) {
MemoryEntry memory =
MemoryEntry.builder()
.setContent(event.content().get())
.setAuthor(event.author())
.setTimestamp(formatTimestamp(event.timestamp()))
.content(event.content().get())
.author(event.author())
.timestamp(formatTimestamp(event.timestamp()))
.build();
matchingMemories.add(memory);
}
Expand Down
24 changes: 19 additions & 5 deletions core/src/main/java/com/google/adk/memory/MemoryEntry.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,26 @@

package com.google.adk.memory;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.google.auto.value.AutoValue;
import com.google.genai.types.Content;
import java.time.Instant;
import javax.annotation.Nullable;

/** Represents one memory entry. */
@AutoValue
@JsonDeserialize(builder = MemoryEntry.Builder.class)
public abstract class MemoryEntry {

/** Returns the main content of the memory. */
@JsonProperty("content")
public abstract Content content();

/** Returns the author of the memory, or null if not set. */
@Nullable
@JsonProperty("author")
public abstract String author();

/**
Expand All @@ -56,27 +62,35 @@ public static Builder builder() {
@AutoValue.Builder
public abstract static class Builder {

@JsonCreator
static Builder create() {
return new AutoValue_MemoryEntry.Builder();
}

/**
* Sets the main content of the memory.
*
* <p>This is a required field.
*/
public abstract Builder setContent(Content content);
@JsonProperty("content")
public abstract Builder content(Content content);

/** Sets the author of the memory. */
public abstract Builder setAuthor(@Nullable String author);
@JsonProperty("author")
public abstract Builder author(@Nullable String author);

/** Sets the timestamp when the original content of this memory happened. */
public abstract Builder setTimestamp(@Nullable String timestamp);
@JsonProperty("timestamp")
public abstract Builder timestamp(@Nullable String timestamp);

/**
* A convenience method to set the timestamp from an {@link Instant} object, formatted as an ISO
* 8601 string.
*
* @param instant The timestamp as an Instant object.
*/
public Builder setTimestamp(Instant instant) {
return setTimestamp(instant.toString());
public Builder timestamp(Instant instant) {
return timestamp(instant.toString());
}

/** Builds the immutable {@link MemoryEntry} object. */
Expand Down
8 changes: 7 additions & 1 deletion core/src/main/java/com/google/adk/runner/InMemoryRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.adk.agents.BaseAgent;
import com.google.adk.artifacts.InMemoryArtifactService;
import com.google.adk.memory.InMemoryMemoryService;
import com.google.adk.sessions.InMemorySessionService;

/** The class for the in-memory GenAi runner, using in-memory artifact and session services. */
Expand All @@ -30,6 +31,11 @@ public InMemoryRunner(BaseAgent agent) {
}

public InMemoryRunner(BaseAgent agent, String appName) {
super(agent, appName, new InMemoryArtifactService(), new InMemorySessionService());
super(
agent,
appName,
new InMemoryArtifactService(),
new InMemorySessionService(),
new InMemoryMemoryService());
}
}
55 changes: 43 additions & 12 deletions core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
import com.google.adk.agents.RunConfig;
import com.google.adk.artifacts.BaseArtifactService;
import com.google.adk.events.Event;
import com.google.adk.memory.BaseMemoryService;
import com.google.adk.sessions.BaseSessionService;
import com.google.adk.sessions.Session;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.FunctionTool;
import com.google.adk.utils.CollectionUtils;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.InlineMe;
import com.google.genai.types.AudioTranscriptionConfig;
import com.google.genai.types.Content;
import com.google.genai.types.Modality;
Expand All @@ -53,17 +55,36 @@ public class Runner {
private final String appName;
private final BaseArtifactService artifactService;
private final BaseSessionService sessionService;
private final BaseMemoryService memoryService;

/** Creates a new {@code Runner}. */
public Runner(
BaseAgent agent,
String appName,
BaseArtifactService artifactService,
BaseSessionService sessionService) {
BaseSessionService sessionService,
BaseMemoryService memoryService) {
this.agent = agent;
this.appName = appName;
this.artifactService = artifactService;
this.sessionService = sessionService;
this.memoryService = memoryService;
}

/**
* Creates a new {@code Runner}.
*
* @deprecated Use the constructor with {@code BaseMemoryService} instead even if with a null if
* you don't need the memory service.
*/
@InlineMe(replacement = "this(agent, appName, artifactService, sessionService, null)")
@Deprecated
public Runner(
BaseAgent agent,
String appName,
BaseArtifactService artifactService,
BaseSessionService sessionService) {
this(agent, appName, artifactService, sessionService, null);
}

public BaseAgent agent() {
Expand All @@ -82,6 +103,10 @@ public BaseSessionService sessionService() {
return this.sessionService;
}

public BaseMemoryService memoryService() {
return this.memoryService;
}

/**
* Appends a new user message to the session history.
*
Expand Down Expand Up @@ -185,13 +210,10 @@ public Flowable<Event> runAsync(Session session, Content newMessage, RunConfig r
sess -> {
BaseAgent rootAgent = this.agent;
InvocationContext invocationContext =
InvocationContext.create(
this.sessionService,
this.artifactService,
InvocationContext.newInvocationContextId(),
rootAgent,
newInvocationContext(
sess,
newMessage,
Optional.of(newMessage),
/* liveRequestQueue= */ Optional.empty(),
runConfig);

if (newMessage != null) {
Expand Down Expand Up @@ -240,7 +262,8 @@ private InvocationContext newInvocationContextForLive(
}
}
}
return newInvocationContext(session, liveRequestQueue, runConfigBuilder.build());
return newInvocationContext(
session, /* newMessage= */ Optional.empty(), liveRequestQueue, runConfigBuilder.build());
}

/**
Expand All @@ -249,16 +272,24 @@ private InvocationContext newInvocationContextForLive(
* @return a new {@link InvocationContext}.
*/
private InvocationContext newInvocationContext(
Session session, Optional<LiveRequestQueue> liveRequestQueue, RunConfig runConfig) {
Session session,
Optional<Content> newMessage,
Optional<LiveRequestQueue> liveRequestQueue,
RunConfig runConfig) {
BaseAgent rootAgent = this.agent;
InvocationContext invocationContext =
InvocationContext.create(
new InvocationContext(
this.sessionService,
this.artifactService,
this.memoryService,
liveRequestQueue,
/* branch= */ Optional.empty(),
InvocationContext.newInvocationContextId(),
rootAgent,
session,
liveRequestQueue.orElse(null),
runConfig);
newMessage,
runConfig,
/* endInvocation= */ false);
invocationContext.agent(this.findAgentToRun(session, rootAgent));
return invocationContext;
}
Expand Down
Loading