Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions core/src/main/java/com/google/adk/agents/BaseAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,13 @@ public Optional<List<? extends AfterAgentCallback>> afterAgentCallback() {
* @return new context with updated branch name.
*/
private InvocationContext createInvocationContext(InvocationContext parentContext) {
InvocationContext invocationContext = InvocationContext.copyOf(parentContext);
invocationContext.agent(this);
InvocationContext.Builder builder = parentContext.toBuilder();
builder.agent(this);
// Check for branch to be truthy (not None, not empty string),
if (parentContext.branch().filter(s -> !s.isEmpty()).isPresent()) {
invocationContext.branch(parentContext.branch().get() + "." + name());
builder.branch(parentContext.branch().get() + "." + name());
}
return invocationContext;
return builder.build();
}

/**
Expand Down Expand Up @@ -303,17 +303,16 @@ private Single<Optional<Event>> callCallback(
return maybeContent
.map(
content -> {
Event.Builder eventBuilder =
invocationContext.setEndInvocation(true);
return Optional.of(
Event.builder()
.id(Event.generateEventId())
.invocationId(invocationContext.invocationId())
.author(name())
.branch(invocationContext.branch())
.actions(callbackContext.eventActions());

eventBuilder.content(Optional.of(content));
invocationContext.setEndInvocation(true);
return Optional.of(eventBuilder.build());
.actions(callbackContext.eventActions())
.content(content)
.build());
})
.toFlowable();
})
Expand Down
100 changes: 72 additions & 28 deletions core/src/main/java/com/google/adk/agents/InvocationContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,25 @@ public class InvocationContext {
private final BaseMemoryService memoryService;
private final Plugin pluginManager;
private final Optional<LiveRequestQueue> liveRequestQueue;
private final Map<String, ActiveStreamingTool> activeStreamingTools = new ConcurrentHashMap<>();
private final Map<String, ActiveStreamingTool> activeStreamingTools;
private final String invocationId;
private final Session session;
private final Optional<Content> userContent;
private final RunConfig runConfig;
private final ResumabilityConfig resumabilityConfig;
private final InvocationCostManager invocationCostManager = new InvocationCostManager();
private final InvocationCostManager invocationCostManager;

private Optional<String> branch;
private BaseAgent agent;
private boolean endInvocation;

private InvocationContext(Builder builder) {
protected InvocationContext(Builder builder) {
this.sessionService = builder.sessionService;
this.artifactService = builder.artifactService;
this.memoryService = builder.memoryService;
this.pluginManager = builder.pluginManager;
this.liveRequestQueue = builder.liveRequestQueue;
this.activeStreamingTools = builder.activeStreamingTools;
this.branch = builder.branch;
this.invocationId = builder.invocationId;
this.agent = builder.agent;
Expand All @@ -71,6 +72,7 @@ private InvocationContext(Builder builder) {
this.runConfig = builder.runConfig;
this.endInvocation = builder.endInvocation;
this.resumabilityConfig = builder.resumabilityConfig;
this.invocationCostManager = builder.invocationCostManager;
}

/**
Expand Down Expand Up @@ -188,7 +190,7 @@ public static InvocationContext create(
.artifactService(artifactService)
.agent(agent)
.session(session)
.liveRequestQueue(Optional.ofNullable(liveRequestQueue))
.liveRequestQueue(liveRequestQueue)
.runConfig(runConfig)
.build();
}
Expand All @@ -198,26 +200,19 @@ public static Builder builder() {
return new Builder();
}

/** Creates a shallow copy of the given {@link InvocationContext}. */
/** Returns a {@link Builder} initialized with the values of this instance. */
public Builder toBuilder() {
return new Builder(this);
}

/**
* Creates a shallow copy of the given {@link InvocationContext}.
*
* @deprecated Use {@code other.toBuilder().build()} instead.
*/
@Deprecated(forRemoval = true)
public static InvocationContext copyOf(InvocationContext other) {
InvocationContext newContext =
builder()
.sessionService(other.sessionService)
.artifactService(other.artifactService)
.memoryService(other.memoryService)
.pluginManager(other.pluginManager)
.liveRequestQueue(other.liveRequestQueue)
.branch(other.branch)
.invocationId(other.invocationId)
.agent(other.agent)
.session(other.session)
.userContent(other.userContent)
.runConfig(other.runConfig)
.endInvocation(other.endInvocation)
.resumabilityConfig(other.resumabilityConfig)
.build();
newContext.activeStreamingTools.putAll(other.activeStreamingTools);
return newContext;
return other.toBuilder().build();
}

/** Returns the session service for managing session state. */
Expand Down Expand Up @@ -258,7 +253,10 @@ public String invocationId() {
/**
* Sets the [branch] ID for the current invocation. A branch represents a fork in the conversation
* history.
*
* @deprecated Use {@link #toBuilder()} and {@link Builder#branch(String)} instead.
*/
@Deprecated(forRemoval = true)
public void branch(@Nullable String branch) {
this.branch = Optional.ofNullable(branch);
}
Expand All @@ -276,7 +274,12 @@ public BaseAgent agent() {
return agent;
}

/** Sets the [agent] being invoked. This is useful when delegating to a sub-agent. */
/**
* Sets the [agent] being invoked. This is useful when delegating to a sub-agent.
*
* @deprecated Use {@link #toBuilder()} and {@link Builder#agent(BaseAgent)} instead.
*/
@Deprecated(forRemoval = true)
public void agent(BaseAgent agent) {
this.agent = agent;
}
Expand Down Expand Up @@ -370,15 +373,53 @@ void incrementAndEnforceLlmCallsLimit(RunConfig runConfig)
"Max number of llm calls limit of " + runConfig.maxLlmCalls() + " exceeded");
}
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof InvocationCostManager that)) {
return false;
}
return numberOfLlmCalls == that.numberOfLlmCalls;
}

@Override
public int hashCode() {
return Integer.hashCode(numberOfLlmCalls);
}
}

/** Builder for {@link InvocationContext}. */
public static class Builder {

private Builder() {}

private Builder(InvocationContext context) {
this.sessionService = context.sessionService;
this.artifactService = context.artifactService;
this.memoryService = context.memoryService;
this.pluginManager = context.pluginManager;
this.liveRequestQueue = context.liveRequestQueue;
this.activeStreamingTools = new ConcurrentHashMap<>(context.activeStreamingTools);
this.branch = context.branch;
this.invocationId = context.invocationId;
this.agent = context.agent;
this.session = context.session;
this.userContent = context.userContent;
this.runConfig = context.runConfig;
this.endInvocation = context.endInvocation;
this.resumabilityConfig = context.resumabilityConfig;
this.invocationCostManager = context.invocationCostManager;
}

private BaseSessionService sessionService;
private BaseArtifactService artifactService;
private BaseMemoryService memoryService;
private Plugin pluginManager = new PluginManager();
private Optional<LiveRequestQueue> liveRequestQueue = Optional.empty();
private Map<String, ActiveStreamingTool> activeStreamingTools = new ConcurrentHashMap<>();
private Optional<String> branch = Optional.empty();
private String invocationId = newInvocationContextId();
private BaseAgent agent;
Expand All @@ -387,6 +428,7 @@ public static class Builder {
private RunConfig runConfig = RunConfig.builder().build();
private boolean endInvocation = false;
private ResumabilityConfig resumabilityConfig = new ResumabilityConfig();
private InvocationCostManager invocationCostManager = new InvocationCostManager();

/**
* Sets the session service for managing session state.
Expand Down Expand Up @@ -458,8 +500,8 @@ public Builder liveRequestQueue(Optional<LiveRequestQueue> liveRequestQueue) {
* @return this builder instance for chaining.
*/
@CanIgnoreReturnValue
public Builder liveRequestQueue(LiveRequestQueue liveRequestQueue) {
this.liveRequestQueue = Optional.of(liveRequestQueue);
public Builder liveRequestQueue(@Nullable LiveRequestQueue liveRequestQueue) {
this.liveRequestQueue = Optional.ofNullable(liveRequestQueue);
return this;
}

Expand Down Expand Up @@ -618,7 +660,8 @@ public boolean equals(Object o) {
&& Objects.equals(session, that.session)
&& Objects.equals(userContent, that.userContent)
&& Objects.equals(runConfig, that.runConfig)
&& Objects.equals(resumabilityConfig, that.resumabilityConfig);
&& Objects.equals(resumabilityConfig, that.resumabilityConfig)
&& Objects.equals(invocationCostManager, that.invocationCostManager);
}

@Override
Expand All @@ -637,6 +680,7 @@ public int hashCode() {
userContent,
runConfig,
endInvocation,
resumabilityConfig);
resumabilityConfig,
invocationCostManager);
}
}
21 changes: 10 additions & 11 deletions core/src/main/java/com/google/adk/agents/ParallelAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
package com.google.adk.agents;

import static com.google.common.base.Strings.isNullOrEmpty;
import static com.google.common.collect.ImmutableList.toImmutableList;

import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
import com.google.adk.events.Event;
import io.reactivex.rxjava3.core.Flowable;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -101,14 +101,15 @@ public static ParallelAgent fromConfig(ParallelAgentConfig config, String config
*
* @param currentAgent Current agent.
* @param invocationContext Invocation context to update.
* @return A new invocation context with branch set.
*/
private static void setBranchForCurrentAgent(
private static InvocationContext setBranchForCurrentAgent(
BaseAgent currentAgent, InvocationContext invocationContext) {
String branch = invocationContext.branch().orElse(null);
if (isNullOrEmpty(branch)) {
invocationContext.branch(currentAgent.name());
return invocationContext.toBuilder().branch(currentAgent.name()).build();
} else {
invocationContext.branch(branch + "." + currentAgent.name());
return invocationContext.toBuilder().branch(branch + "." + currentAgent.name()).build();
}
}

Expand All @@ -122,18 +123,16 @@ private static void setBranchForCurrentAgent(
*/
@Override
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
setBranchForCurrentAgent(this, invocationContext);

List<? extends BaseAgent> currentSubAgents = subAgents();
if (currentSubAgents == null || currentSubAgents.isEmpty()) {
return Flowable.empty();
}

List<Flowable<Event>> agentFlowables = new ArrayList<>();
for (BaseAgent subAgent : currentSubAgents) {
agentFlowables.add(subAgent.runAsync(invocationContext));
}
return Flowable.merge(agentFlowables);
var updatedInvocationContext = setBranchForCurrentAgent(this, invocationContext);
return Flowable.merge(
currentSubAgents.stream()
.map(subAgent -> subAgent.runAsync(updatedInvocationContext))
.collect(toImmutableList()));
}

/**
Expand Down
10 changes: 7 additions & 3 deletions core/src/main/java/com/google/adk/flows/llmflows/Examples.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.google.adk.examples.ExampleUtils;
import com.google.adk.models.LlmRequest;
import com.google.common.collect.ImmutableList;
import com.google.genai.types.Content;
import io.reactivex.rxjava3.core.Single;

/** {@link RequestProcessor} that populates examples in LLM request. */
Expand All @@ -38,9 +39,12 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
LlmRequest.Builder builder = request.toBuilder();

String query =
context.userContent().isPresent()
? context.userContent().get().parts().get().get(0).text().orElse("")
: "";
context
.userContent()
.flatMap(Content::parts)
.filter(parts -> !parts.isEmpty())
.map(parts -> parts.get(0).text().orElse(""))
.orElse("");
agent
.exampleProvider()
.ifPresent(
Expand Down
Loading