diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index adc973529..53a978974 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -187,13 +187,13 @@ public Optional> afterAgentCallback() { * @return new context with updated branch name. */ private InvocationContext createInvocationContext(InvocationContext parentContext) { - InvocationContext invocationContext = InvocationContext.copyOf(parentContext); - invocationContext.agent(this); + InvocationContext.Builder builder = parentContext.toBuilder(); + builder.agent(this); // Check for branch to be truthy (not None, not empty string), if (parentContext.branch().filter(s -> !s.isEmpty()).isPresent()) { - invocationContext.branch(parentContext.branch().get() + "." + name()); + builder.branch(parentContext.branch().get() + "." + name()); } - return invocationContext; + return builder.build(); } /** @@ -303,17 +303,16 @@ private Single> callCallback( return maybeContent .map( content -> { - Event.Builder eventBuilder = + invocationContext.setEndInvocation(true); + return Optional.of( Event.builder() .id(Event.generateEventId()) .invocationId(invocationContext.invocationId()) .author(name()) .branch(invocationContext.branch()) - .actions(callbackContext.eventActions()); - - eventBuilder.content(Optional.of(content)); - invocationContext.setEndInvocation(true); - return Optional.of(eventBuilder.build()); + .actions(callbackContext.eventActions()) + .content(content) + .build()); }) .toFlowable(); }) diff --git a/core/src/main/java/com/google/adk/agents/InvocationContext.java b/core/src/main/java/com/google/adk/agents/InvocationContext.java index 9396403bb..532bc92fd 100644 --- a/core/src/main/java/com/google/adk/agents/InvocationContext.java +++ b/core/src/main/java/com/google/adk/agents/InvocationContext.java @@ -45,24 +45,25 @@ public class InvocationContext { private final BaseMemoryService memoryService; private final Plugin pluginManager; private final Optional liveRequestQueue; - private final Map activeStreamingTools = new ConcurrentHashMap<>(); + private final Map activeStreamingTools; private final String invocationId; private final Session session; private final Optional userContent; private final RunConfig runConfig; private final ResumabilityConfig resumabilityConfig; - private final InvocationCostManager invocationCostManager = new InvocationCostManager(); + private final InvocationCostManager invocationCostManager; private Optional branch; private BaseAgent agent; private boolean endInvocation; - private InvocationContext(Builder builder) { + protected InvocationContext(Builder builder) { this.sessionService = builder.sessionService; this.artifactService = builder.artifactService; this.memoryService = builder.memoryService; this.pluginManager = builder.pluginManager; this.liveRequestQueue = builder.liveRequestQueue; + this.activeStreamingTools = builder.activeStreamingTools; this.branch = builder.branch; this.invocationId = builder.invocationId; this.agent = builder.agent; @@ -71,6 +72,7 @@ private InvocationContext(Builder builder) { this.runConfig = builder.runConfig; this.endInvocation = builder.endInvocation; this.resumabilityConfig = builder.resumabilityConfig; + this.invocationCostManager = builder.invocationCostManager; } /** @@ -188,7 +190,7 @@ public static InvocationContext create( .artifactService(artifactService) .agent(agent) .session(session) - .liveRequestQueue(Optional.ofNullable(liveRequestQueue)) + .liveRequestQueue(liveRequestQueue) .runConfig(runConfig) .build(); } @@ -198,26 +200,19 @@ public static Builder builder() { return new Builder(); } - /** Creates a shallow copy of the given {@link InvocationContext}. */ + /** Returns a {@link Builder} initialized with the values of this instance. */ + public Builder toBuilder() { + return new Builder(this); + } + + /** + * Creates a shallow copy of the given {@link InvocationContext}. + * + * @deprecated Use {@code other.toBuilder().build()} instead. + */ + @Deprecated(forRemoval = true) public static InvocationContext copyOf(InvocationContext other) { - InvocationContext newContext = - builder() - .sessionService(other.sessionService) - .artifactService(other.artifactService) - .memoryService(other.memoryService) - .pluginManager(other.pluginManager) - .liveRequestQueue(other.liveRequestQueue) - .branch(other.branch) - .invocationId(other.invocationId) - .agent(other.agent) - .session(other.session) - .userContent(other.userContent) - .runConfig(other.runConfig) - .endInvocation(other.endInvocation) - .resumabilityConfig(other.resumabilityConfig) - .build(); - newContext.activeStreamingTools.putAll(other.activeStreamingTools); - return newContext; + return other.toBuilder().build(); } /** Returns the session service for managing session state. */ @@ -258,7 +253,10 @@ public String invocationId() { /** * Sets the [branch] ID for the current invocation. A branch represents a fork in the conversation * history. + * + * @deprecated Use {@link #toBuilder()} and {@link Builder#branch(String)} instead. */ + @Deprecated(forRemoval = true) public void branch(@Nullable String branch) { this.branch = Optional.ofNullable(branch); } @@ -276,7 +274,12 @@ public BaseAgent agent() { return agent; } - /** Sets the [agent] being invoked. This is useful when delegating to a sub-agent. */ + /** + * Sets the [agent] being invoked. This is useful when delegating to a sub-agent. + * + * @deprecated Use {@link #toBuilder()} and {@link Builder#agent(BaseAgent)} instead. + */ + @Deprecated(forRemoval = true) public void agent(BaseAgent agent) { this.agent = agent; } @@ -370,15 +373,53 @@ void incrementAndEnforceLlmCallsLimit(RunConfig runConfig) "Max number of llm calls limit of " + runConfig.maxLlmCalls() + " exceeded"); } } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof InvocationCostManager that)) { + return false; + } + return numberOfLlmCalls == that.numberOfLlmCalls; + } + + @Override + public int hashCode() { + return Integer.hashCode(numberOfLlmCalls); + } } /** Builder for {@link InvocationContext}. */ public static class Builder { + + private Builder() {} + + private Builder(InvocationContext context) { + this.sessionService = context.sessionService; + this.artifactService = context.artifactService; + this.memoryService = context.memoryService; + this.pluginManager = context.pluginManager; + this.liveRequestQueue = context.liveRequestQueue; + this.activeStreamingTools = new ConcurrentHashMap<>(context.activeStreamingTools); + this.branch = context.branch; + this.invocationId = context.invocationId; + this.agent = context.agent; + this.session = context.session; + this.userContent = context.userContent; + this.runConfig = context.runConfig; + this.endInvocation = context.endInvocation; + this.resumabilityConfig = context.resumabilityConfig; + this.invocationCostManager = context.invocationCostManager; + } + private BaseSessionService sessionService; private BaseArtifactService artifactService; private BaseMemoryService memoryService; private Plugin pluginManager = new PluginManager(); private Optional liveRequestQueue = Optional.empty(); + private Map activeStreamingTools = new ConcurrentHashMap<>(); private Optional branch = Optional.empty(); private String invocationId = newInvocationContextId(); private BaseAgent agent; @@ -387,6 +428,7 @@ public static class Builder { private RunConfig runConfig = RunConfig.builder().build(); private boolean endInvocation = false; private ResumabilityConfig resumabilityConfig = new ResumabilityConfig(); + private InvocationCostManager invocationCostManager = new InvocationCostManager(); /** * Sets the session service for managing session state. @@ -458,8 +500,8 @@ public Builder liveRequestQueue(Optional liveRequestQueue) { * @return this builder instance for chaining. */ @CanIgnoreReturnValue - public Builder liveRequestQueue(LiveRequestQueue liveRequestQueue) { - this.liveRequestQueue = Optional.of(liveRequestQueue); + public Builder liveRequestQueue(@Nullable LiveRequestQueue liveRequestQueue) { + this.liveRequestQueue = Optional.ofNullable(liveRequestQueue); return this; } @@ -618,7 +660,8 @@ public boolean equals(Object o) { && Objects.equals(session, that.session) && Objects.equals(userContent, that.userContent) && Objects.equals(runConfig, that.runConfig) - && Objects.equals(resumabilityConfig, that.resumabilityConfig); + && Objects.equals(resumabilityConfig, that.resumabilityConfig) + && Objects.equals(invocationCostManager, that.invocationCostManager); } @Override @@ -637,6 +680,7 @@ public int hashCode() { userContent, runConfig, endInvocation, - resumabilityConfig); + resumabilityConfig, + invocationCostManager); } } diff --git a/core/src/main/java/com/google/adk/agents/ParallelAgent.java b/core/src/main/java/com/google/adk/agents/ParallelAgent.java index 4bfd0b255..f30d951aa 100644 --- a/core/src/main/java/com/google/adk/agents/ParallelAgent.java +++ b/core/src/main/java/com/google/adk/agents/ParallelAgent.java @@ -16,11 +16,11 @@ package com.google.adk.agents; import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.collect.ImmutableList.toImmutableList; import com.google.adk.agents.ConfigAgentUtils.ConfigurationException; import com.google.adk.events.Event; import io.reactivex.rxjava3.core.Flowable; -import java.util.ArrayList; import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -101,14 +101,15 @@ public static ParallelAgent fromConfig(ParallelAgentConfig config, String config * * @param currentAgent Current agent. * @param invocationContext Invocation context to update. + * @return A new invocation context with branch set. */ - private static void setBranchForCurrentAgent( + private static InvocationContext setBranchForCurrentAgent( BaseAgent currentAgent, InvocationContext invocationContext) { String branch = invocationContext.branch().orElse(null); if (isNullOrEmpty(branch)) { - invocationContext.branch(currentAgent.name()); + return invocationContext.toBuilder().branch(currentAgent.name()).build(); } else { - invocationContext.branch(branch + "." + currentAgent.name()); + return invocationContext.toBuilder().branch(branch + "." + currentAgent.name()).build(); } } @@ -122,18 +123,16 @@ private static void setBranchForCurrentAgent( */ @Override protected Flowable runAsyncImpl(InvocationContext invocationContext) { - setBranchForCurrentAgent(this, invocationContext); - List currentSubAgents = subAgents(); if (currentSubAgents == null || currentSubAgents.isEmpty()) { return Flowable.empty(); } - List> agentFlowables = new ArrayList<>(); - for (BaseAgent subAgent : currentSubAgents) { - agentFlowables.add(subAgent.runAsync(invocationContext)); - } - return Flowable.merge(agentFlowables); + var updatedInvocationContext = setBranchForCurrentAgent(this, invocationContext); + return Flowable.merge( + currentSubAgents.stream() + .map(subAgent -> subAgent.runAsync(updatedInvocationContext)) + .collect(toImmutableList())); } /** diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Examples.java b/core/src/main/java/com/google/adk/flows/llmflows/Examples.java index a85711b3e..d9cee5fa0 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Examples.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Examples.java @@ -21,6 +21,7 @@ import com.google.adk.examples.ExampleUtils; import com.google.adk.models.LlmRequest; import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; import io.reactivex.rxjava3.core.Single; /** {@link RequestProcessor} that populates examples in LLM request. */ @@ -38,9 +39,12 @@ public Single processRequest( LlmRequest.Builder builder = request.toBuilder(); String query = - context.userContent().isPresent() - ? context.userContent().get().parts().get().get(0).text().orElse("") - : ""; + context + .userContent() + .flatMap(Content::parts) + .filter(parts -> !parts.isEmpty()) + .map(parts -> parts.get(0).text().orElse("")) + .orElse(""); agent .exampleProvider() .ifPresent( diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index f4b17c41f..c774ff361 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -386,12 +386,13 @@ public Flowable runAsync( // Create initial context InvocationContext initialContext = - newInvocationContextWithId( - session, - Optional.of(newMessage), - /* liveRequestQueue= */ Optional.empty(), - runConfig, - invocationId); + newInvocationContextBuilder( + session, + Optional.of(newMessage), + /* liveRequestQueue= */ Optional.empty(), + runConfig) + .invocationId(invocationId) + .build(); return Telemetry.traceFlowable( spanContext, @@ -427,14 +428,16 @@ public Flowable runAsync( // Create context with updated session for // beforeRunCallback InvocationContext contextWithUpdatedSession = - newInvocationContextWithId( - updatedSession, - event.content(), - /* liveRequestQueue= */ Optional.empty(), - runConfig, - invocationId); - contextWithUpdatedSession.agent( - this.findAgentToRun(updatedSession, rootAgent)); + newInvocationContextBuilder( + updatedSession, + event.content(), + /* liveRequestQueue= */ Optional.empty(), + runConfig) + .invocationId(invocationId) + .agent( + this.findAgentToRun( + updatedSession, rootAgent)) + .build(); // Call beforeRunCallback with updated session Maybe beforeRunEvent = @@ -553,35 +556,14 @@ private InvocationContext newInvocationContext( Optional newMessage, Optional liveRequestQueue, RunConfig runConfig) { - BaseAgent rootAgent = this.agent; - var invocationContextBuilder = - InvocationContext.builder() - .sessionService(this.sessionService) - .artifactService(this.artifactService) - .memoryService(this.memoryService) - .pluginManager(this.pluginManager) - .agent(rootAgent) - .session(session) - .userContent(newMessage) - .runConfig(runConfig) - .resumabilityConfig(this.resumabilityConfig); - liveRequestQueue.ifPresent(invocationContextBuilder::liveRequestQueue); - var invocationContext = invocationContextBuilder.build(); - invocationContext.agent(this.findAgentToRun(session, rootAgent)); - return invocationContext; + return newInvocationContextBuilder(session, newMessage, liveRequestQueue, runConfig).build(); } - /** - * Creates a new InvocationContext with a specific invocation ID. - * - * @return a new {@link InvocationContext} with the specified invocation ID. - */ - private InvocationContext newInvocationContextWithId( + private InvocationContext.Builder newInvocationContextBuilder( Session session, Optional newMessage, Optional liveRequestQueue, - RunConfig runConfig, - String invocationId) { + RunConfig runConfig) { BaseAgent rootAgent = this.agent; var invocationContextBuilder = InvocationContext.builder() @@ -589,16 +571,14 @@ private InvocationContext newInvocationContextWithId( .artifactService(this.artifactService) .memoryService(this.memoryService) .pluginManager(this.pluginManager) - .invocationId(invocationId) .agent(rootAgent) .session(session) - .userContent(newMessage) + .userContent(newMessage.orElse(Content.fromParts())) .runConfig(runConfig) - .resumabilityConfig(this.resumabilityConfig); + .resumabilityConfig(this.resumabilityConfig) + .agent(this.findAgentToRun(session, rootAgent)); liveRequestQueue.ifPresent(invocationContextBuilder::liveRequestQueue); - var invocationContext = invocationContextBuilder.build(); - invocationContext.agent(this.findAgentToRun(session, rootAgent)); - return invocationContext; + return invocationContextBuilder; } /** diff --git a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java index d85f2eb71..7ceff6a75 100644 --- a/core/src/test/java/com/google/adk/agents/InvocationContextTest.java +++ b/core/src/test/java/com/google/adk/agents/InvocationContextTest.java @@ -69,7 +69,7 @@ public void setUp() { } @Test - public void testCreateWithUserContent() { + public void testBuildWithUserContent() { InvocationContext context = InvocationContext.builder() .sessionService(mockSessionService) @@ -79,7 +79,7 @@ public void testCreateWithUserContent() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); @@ -98,7 +98,7 @@ public void testCreateWithUserContent() { } @Test - public void testCreateWithNullUserContent() { + public void testBuildWithNullUserContent() { InvocationContext context = InvocationContext.builder() .sessionService(mockSessionService) @@ -108,7 +108,6 @@ public void testCreateWithNullUserContent() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.empty()) .runConfig(runConfig) .endInvocation(false) .build(); @@ -118,7 +117,7 @@ public void testCreateWithNullUserContent() { } @Test - public void testCreateWithLiveRequestQueue() { + public void testBuildWithLiveRequestQueue() { InvocationContext context = InvocationContext.builder() .sessionService(mockSessionService) @@ -128,7 +127,6 @@ public void testCreateWithLiveRequestQueue() { .liveRequestQueue(liveRequestQueue) .agent(mockAgent) .session(session) - .userContent(Optional.empty()) .runConfig(runConfig) .endInvocation(false) .build(); @@ -157,13 +155,13 @@ public void testCopyOf() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); originalContext.activeStreamingTools().putAll(activeStreamingTools); - InvocationContext copiedContext = InvocationContext.copyOf(originalContext); + InvocationContext copiedContext = originalContext.toBuilder().build(); assertThat(copiedContext).isNotNull(); assertThat(copiedContext).isNotSameInstanceAs(originalContext); @@ -193,7 +191,7 @@ public void testGetters() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); @@ -212,6 +210,7 @@ public void testGetters() { @Test public void testSetAgent() { + BaseAgent newMockAgent = mock(BaseAgent.class); InvocationContext context = InvocationContext.builder() .sessionService(mockSessionService) @@ -221,14 +220,12 @@ public void testSetAgent() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) + .agent(newMockAgent) .build(); - BaseAgent newMockAgent = mock(BaseAgent.class); - context.agent(newMockAgent); - assertThat(context.agent()).isEqualTo(newMockAgent); } @@ -255,7 +252,7 @@ public void testEquals_sameObject() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); @@ -274,7 +271,7 @@ public void testEquals_null() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); @@ -293,7 +290,7 @@ public void testEquals_sameValues() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); @@ -308,7 +305,7 @@ public void testEquals_sameValues() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); @@ -328,7 +325,7 @@ public void testEquals_differentValues() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); @@ -343,7 +340,7 @@ public void testEquals_differentValues() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); @@ -357,7 +354,7 @@ public void testEquals_differentValues() { .invocationId("another-id") // Different ID .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); @@ -371,7 +368,7 @@ public void testEquals_differentValues() { .invocationId(testInvocationId) .agent(mock(BaseAgent.class)) // Different mock .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); @@ -385,7 +382,6 @@ public void testEquals_differentValues() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.empty()) .runConfig(runConfig) .endInvocation(false) .build(); @@ -399,7 +395,6 @@ public void testEquals_differentValues() { .liveRequestQueue(liveRequestQueue) .agent(mockAgent) .session(session) - .userContent(Optional.empty()) .runConfig(runConfig) .endInvocation(false) .build(); @@ -422,7 +417,7 @@ public void testHashCode_differentValues() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); @@ -437,7 +432,7 @@ public void testHashCode_differentValues() { .invocationId(testInvocationId) .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); @@ -451,7 +446,7 @@ public void testHashCode_differentValues() { .invocationId("another-id") // Different ID .agent(mockAgent) .session(session) - .userContent(Optional.of(userContent)) + .userContent(userContent) .runConfig(runConfig) .endInvocation(false) .build(); diff --git a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java index e1756d09f..1cb0c8771 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/ContentsTest.java @@ -22,12 +22,9 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.RunConfig; -import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.Model; -import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -639,14 +636,11 @@ private List runContentsProcessorWithIncludeContents( .events(new ArrayList<>(events)) .build(); InvocationContext context = - InvocationContext.create( - new InMemorySessionService(), - new InMemoryArtifactService(), - "test-invocation", - agent, - session, - /* userContent= */ null, - RunConfig.builder().build()); + InvocationContext.builder() + .invocationId("test-invocation") + .agent(agent) + .session(session) + .build(); LlmRequest initialRequest = LlmRequest.builder().build(); RequestProcessor.RequestProcessingResult result = @@ -671,14 +665,11 @@ private List runContentsProcessorWithModelName(List events, Stri .events(new ArrayList<>(events)) .build(); InvocationContext context = - InvocationContext.create( - new InMemorySessionService(), - new InMemoryArtifactService(), - "test-invocation", - agent, - session, - /* userContent= */ null, - RunConfig.builder().build()); + InvocationContext.builder() + .invocationId("test-invocation") + .agent(agent) + .session(session) + .build(); LlmRequest initialRequest = LlmRequest.builder().build(); RequestProcessor.RequestProcessingResult result = diff --git a/core/src/test/java/com/google/adk/flows/llmflows/InstructionsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/InstructionsTest.java index 623caabb8..2ac9e454d 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/InstructionsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/InstructionsTest.java @@ -25,7 +25,6 @@ import com.google.adk.agents.Instruction; import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.RunConfig; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.models.LlmRequest; import com.google.adk.sessions.InMemorySessionService; @@ -62,14 +61,13 @@ public void setUp() { } private InvocationContext createContext(BaseAgent agent, Session session) { - return InvocationContext.create( - sessionService, - mockArtifactService, - "test-invocation-id", - agent, - session, - null, - RunConfig.builder().build()); + return InvocationContext.builder() + .sessionService(sessionService) + .artifactService(mockArtifactService) + .invocationId("test-invocation-id") + .agent(agent) + .session(session) + .build(); } private Session createSession() { diff --git a/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java b/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java index 971930907..815477c7e 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/RequestConfirmationLlmRequestProcessorTest.java @@ -24,7 +24,6 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; -import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.plugins.PluginManager; @@ -37,7 +36,6 @@ import com.google.genai.types.FunctionCall; import com.google.genai.types.FunctionResponse; import com.google.genai.types.Part; -import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -216,19 +214,12 @@ public void runAsync_noUserConfirmationEvent_empty() { } private static InvocationContext createInvocationContext(LlmAgent agent, Session session) { - return new InvocationContext( - /* sessionService= */ null, - /* artifactService= */ null, - /* memoryService= */ null, - /* pluginManager= */ new PluginManager(), - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - /* invocationId= */ InvocationContext.newInvocationContextId(), - /* agent= */ agent, - /* session= */ session, - /* userContent= */ Optional.empty(), - /* runConfig= */ RunConfig.builder().build(), - /* endInvocation= */ false); + return InvocationContext.builder() + .pluginManager(new PluginManager()) + .invocationId(InvocationContext.newInvocationContextId()) + .agent(agent) + .session(session) + .build(); } private static LlmAgent createAgentWithEchoTool() { diff --git a/core/src/test/java/com/google/adk/testing/TestBaseAgent.java b/core/src/test/java/com/google/adk/testing/TestBaseAgent.java index 46d9a43d5..e3e5a632c 100644 --- a/core/src/test/java/com/google/adk/testing/TestBaseAgent.java +++ b/core/src/test/java/com/google/adk/testing/TestBaseAgent.java @@ -58,14 +58,14 @@ public TestBaseAgent( @Override public Flowable runAsyncImpl(InvocationContext invocationContext) { - lastInvocationContext = InvocationContext.copyOf(invocationContext); + lastInvocationContext = invocationContext.toBuilder().build(); invocationCount++; return eventSupplier.get(); } @Override public Flowable runLiveImpl(InvocationContext invocationContext) { - lastInvocationContext = InvocationContext.copyOf(invocationContext); + lastInvocationContext = invocationContext.toBuilder().build(); invocationCount++; return eventSupplier.get(); } diff --git a/core/src/test/java/com/google/adk/testing/TestUtils.java b/core/src/test/java/com/google/adk/testing/TestUtils.java index 56cb9d07a..e4c2949eb 100644 --- a/core/src/test/java/com/google/adk/testing/TestUtils.java +++ b/core/src/test/java/com/google/adk/testing/TestUtils.java @@ -24,10 +24,8 @@ import com.google.adk.agents.InvocationContext; import com.google.adk.agents.LlmAgent; import com.google.adk.agents.RunConfig; -import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.events.Event; import com.google.adk.events.EventActions; -import com.google.adk.memory.InMemoryMemoryService; import com.google.adk.models.BaseLlm; import com.google.adk.models.LlmResponse; import com.google.adk.sessions.InMemorySessionService; @@ -53,18 +51,14 @@ public final class TestUtils { public static InvocationContext createInvocationContext(BaseAgent agent, RunConfig runConfig) { InMemorySessionService sessionService = new InMemorySessionService(); - return new InvocationContext( - sessionService, - new InMemoryArtifactService(), - new InMemoryMemoryService(), - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - "invocationId", - agent, - sessionService.createSession("test-app", "test-user").blockingGet(), - Optional.of(Content.fromParts(Part.fromText("user content"))), - runConfig, - /* endInvocation= */ false); + return InvocationContext.builder() + .sessionService(sessionService) + .invocationId("invocationId") + .agent(agent) + .session(sessionService.createSession("test-app", "test-user").blockingGet()) + .userContent(Content.fromParts(Part.fromText("user content"))) + .runConfig(runConfig) + .build(); } public static InvocationContext createInvocationContext(BaseAgent agent) { diff --git a/core/src/test/java/com/google/adk/tools/AgentToolTest.java b/core/src/test/java/com/google/adk/tools/AgentToolTest.java index e5318f87f..d43d9d03a 100644 --- a/core/src/test/java/com/google/adk/tools/AgentToolTest.java +++ b/core/src/test/java/com/google/adk/tools/AgentToolTest.java @@ -38,7 +38,6 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import java.util.Map; -import java.util.Optional; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -454,18 +453,11 @@ public void call_withStateDeltaInResponse_propagatesStateDelta() throws Exceptio private static ToolContext createToolContext(LlmAgent agent) { return ToolContext.builder( - new InvocationContext( - /* sessionService= */ null, - /* artifactService= */ null, - /* memoryService= */ null, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - /* invocationId= */ InvocationContext.newInvocationContextId(), - agent, - Session.builder("123").build(), - /* userContent= */ Optional.empty(), - /* runConfig= */ null, - /* endInvocation= */ false)) + InvocationContext.builder() + .invocationId(InvocationContext.newInvocationContextId()) + .agent(agent) + .session(Session.builder("123").build()) + .build()) .build(); } } diff --git a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java index 3e2826369..5816c427a 100644 --- a/core/src/test/java/com/google/adk/tools/FunctionToolTest.java +++ b/core/src/test/java/com/google/adk/tools/FunctionToolTest.java @@ -20,7 +20,6 @@ import static org.junit.Assert.assertThrows; import com.google.adk.agents.InvocationContext; -import com.google.adk.agents.RunConfig; import com.google.adk.events.ToolConfirmation; import com.google.adk.sessions.Session; import com.google.common.collect.ImmutableList; @@ -236,10 +235,7 @@ public void call_withAllSupportedParameterTypes() throws Exception { FunctionTool tool = FunctionTool.create(Functions.class, "returnAllSupportedParametersAsMap"); ToolContext toolContext = ToolContext.builder( - InvocationContext.builder() - .session(Session.builder("123").build()) - .runConfig(RunConfig.builder().build()) - .build()) + InvocationContext.builder().session(Session.builder("123").build()).build()) .functionCallId("functionCallId") .build(); diff --git a/core/src/test/java/com/google/adk/utils/InstructionUtilsTest.java b/core/src/test/java/com/google/adk/utils/InstructionUtilsTest.java index 70a42e6d0..48b2531d3 100644 --- a/core/src/test/java/com/google/adk/utils/InstructionUtilsTest.java +++ b/core/src/test/java/com/google/adk/utils/InstructionUtilsTest.java @@ -5,15 +5,12 @@ import static org.junit.Assert.assertThrows; import com.google.adk.agents.InvocationContext; -import com.google.adk.agents.RunConfig; import com.google.adk.artifacts.InMemoryArtifactService; import com.google.adk.memory.InMemoryMemoryService; import com.google.adk.sessions.InMemorySessionService; import com.google.adk.sessions.Session; import com.google.adk.sessions.State; -import com.google.genai.types.Content; import com.google.genai.types.Part; -import java.util.Optional; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -23,28 +20,19 @@ public final class InstructionUtilsTest { private InvocationContext templateContext; - private InMemorySessionService sessionService; - private InMemoryArtifactService artifactService; - private InMemoryMemoryService memoryService; @Before public void setUp() { - sessionService = new InMemorySessionService(); - artifactService = new InMemoryArtifactService(); - memoryService = new InMemoryMemoryService(); + InMemorySessionService sessionService = new InMemorySessionService(); templateContext = - new InvocationContext( - sessionService, - artifactService, - memoryService, - /* liveRequestQueue= */ Optional.empty(), - /* branch= */ Optional.empty(), - "invocationId", - createRootAgent(), - sessionService.createSession("test-app", "test-user").blockingGet(), - Optional.of(Content.fromParts()), - RunConfig.builder().build(), - /* endInvocation= */ false); + InvocationContext.builder() + .sessionService(sessionService) + .artifactService(new InMemoryArtifactService()) + .memoryService(new InMemoryMemoryService()) + .invocationId("invocationId") + .agent(createRootAgent()) + .session(sessionService.createSession("test-app", "test-user").blockingGet()) + .build(); } @Test @@ -67,7 +55,7 @@ public void injectSessionState_nullContext_throwsNullPointerException() { @Test public void injectSessionState_withMultipleStateVariables_replacesStatePlaceholders() { - var testContext = InvocationContext.copyOf(templateContext); + var testContext = templateContext.toBuilder().build(); testContext.session().state().put("greeting", "Hi"); testContext.session().state().put("user", "Alice"); String template = "Greet the user with: {greeting} {user}."; @@ -79,7 +67,7 @@ public void injectSessionState_withMultipleStateVariables_replacesStatePlacehold @Test public void injectSessionState_stateVariablePlaceholderWithSpaces_trimsAndReplacesVariable() { - var testContext = InvocationContext.copyOf(templateContext); + var testContext = templateContext.toBuilder().build(); testContext.session().state().put("name", "James"); String template = "The user you are helping is: { name }."; @@ -90,7 +78,7 @@ public void injectSessionState_stateVariablePlaceholderWithSpaces_trimsAndReplac @Test public void injectSessionState_stateVariablePlaceholderWithMultipleBraces_replacesVariable() { - var testContext = InvocationContext.copyOf(templateContext); + var testContext = templateContext.toBuilder().build(); testContext.session().state().put("user:name", "Charlie"); String template = "Use the user name: {{user:name}}."; @@ -101,7 +89,7 @@ public void injectSessionState_stateVariablePlaceholderWithMultipleBraces_replac @Test public void injectSessionState_stateVariableWithNonStringValue_convertsValueToString() { - InvocationContext testContext = InvocationContext.copyOf(templateContext); + InvocationContext testContext = templateContext.toBuilder().build(); testContext.session().state().put("app:count", 123); String template = "The current count is: {app:count}."; @@ -121,7 +109,7 @@ public void injectSessionState_missingNonOptionalStateVariable_throwsIllegalArgu @Test public void injectSessionState_missingOptionalStateVariable_replacesWithEmptyString() { - InvocationContext testContext = InvocationContext.copyOf(templateContext); + InvocationContext testContext = templateContext.toBuilder().build(); testContext.session().state().put("user:first_name", "John"); testContext.session().state().put("user:last_name", "Doe"); String template = @@ -134,10 +122,11 @@ public void injectSessionState_missingOptionalStateVariable_replacesWithEmptyStr @Test public void injectSessionState_withValidArtifact_replacesWithArtifactText() { - InvocationContext testContext = InvocationContext.copyOf(templateContext); + InvocationContext testContext = templateContext.toBuilder().build(); Session session = testContext.session(); var unused = - artifactService + testContext + .artifactService() .saveArtifact( session.appName(), session.userId(), @@ -155,7 +144,7 @@ public void injectSessionState_withValidArtifact_replacesWithArtifactText() { @Test public void injectSessionState_missingNonOptionalArtifact_throwsIllegalArgumentException() { - InvocationContext testContext = InvocationContext.copyOf(templateContext); + InvocationContext testContext = templateContext.toBuilder().build(); String template = "Include this knowledge: {artifact.missing_knowledge.txt}."; assertThrows( @@ -165,7 +154,7 @@ public void injectSessionState_missingNonOptionalArtifact_throwsIllegalArgumentE @Test public void injectSessionState_missingOptionalArtifact_replacesWithEmptyString() { - InvocationContext testContext = InvocationContext.copyOf(templateContext); + InvocationContext testContext = templateContext.toBuilder().build(); String template = "Include this additional info: {artifact.optional_info.txt?}."; String result = InstructionUtils.injectSessionState(testContext, template).blockingGet(); @@ -185,7 +174,7 @@ public void injectSessionState_invalidStateVariableNameSyntax_returnsPlaceholder @Test public void injectSessionState_stateVariableWithValidPrefix_replacesVariable() { - var testContext = InvocationContext.copyOf(templateContext); + var testContext = templateContext.toBuilder().build(); testContext.session().state().put("app:assistant_name", "Trippy"); String template = "Set the assistant name to: {app:assistant_name}.";