diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java
index 4d9713b8..2a4387e0 100644
--- a/core/src/main/java/com/google/adk/agents/LlmAgent.java
+++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java
@@ -52,6 +52,7 @@
 import com.google.adk.models.BaseLlm;
 import com.google.adk.models.LlmRegistry;
 import com.google.adk.models.Model;
+import com.google.adk.planners.BasePlanner;
 import com.google.adk.tools.BaseTool;
 import com.google.adk.tools.BaseTool.ToolArgsConfig;
 import com.google.adk.tools.BaseTool.ToolConfig;
@@ -105,6 +106,7 @@ public enum IncludeContents {
   private final IncludeContents includeContents;
 
   private final boolean planning;
+  private final Optional<BasePlanner> planner;
   private final Optional<Integer> maxSteps;
   private final boolean disallowTransferToParent;
   private final boolean disallowTransferToPeers;
@@ -138,6 +140,7 @@ protected LlmAgent(Builder builder) {
     this.includeContents =
         builder.includeContents != null ? builder.includeContents : IncludeContents.DEFAULT;
     this.planning = builder.planning != null && builder.planning;
+    this.planner = Optional.ofNullable(builder.planner);
     this.maxSteps = Optional.ofNullable(builder.maxSteps);
     this.disallowTransferToParent = builder.disallowTransferToParent;
     this.disallowTransferToPeers = builder.disallowTransferToPeers;
@@ -187,6 +190,7 @@ public static class Builder {
     private BaseExampleProvider exampleProvider;
     private IncludeContents includeContents;
     private Boolean planning;
+    private BasePlanner planner;
     private Integer maxSteps;
     private Boolean disallowTransferToParent;
     private Boolean disallowTransferToPeers;
@@ -325,6 +329,13 @@ public Builder planning(boolean planning) {
       return this;
     }
 
+    /** Sets the planner; when unset, the NL planning processors leave requests untouched. */
+    @CanIgnoreReturnValue
+    public Builder planner(BasePlanner planner) {
+      this.planner = planner;
+      return this;
+    }
+
     @CanIgnoreReturnValue
     public Builder maxSteps(int maxSteps) {
       this.maxSteps = maxSteps;
@@ -784,6 +795,11 @@ public boolean planning() {
     return planning;
   }
 
+  /** Returns the planner configured for this agent, or empty if none was set. */
+  public Optional<BasePlanner> planner() {
+    return planner;
+  }
+
   public Optional<Integer> maxSteps() {
     return maxSteps;
   }
diff --git a/core/src/main/java/com/google/adk/flows/llmflows/NLPlanning.java b/core/src/main/java/com/google/adk/flows/llmflows/NLPlanning.java
new file mode 100644
index 00000000..70678d69
--- /dev/null
+++ b/core/src/main/java/com/google/adk/flows/llmflows/NLPlanning.java
@@ -0,0 +1,170 @@
+package com.google.adk.flows.llmflows;
+
+import com.google.adk.agents.CallbackContext;
+import com.google.adk.agents.InvocationContext;
+import com.google.adk.agents.LlmAgent;
+import com.google.adk.agents.ReadonlyContext;
+import com.google.adk.events.Event;
+import com.google.adk.models.LlmRequest;
+import com.google.adk.models.LlmResponse;
+import com.google.adk.planners.BasePlanner;
+import com.google.adk.planners.BuiltInPlanner;
+import com.google.common.collect.ImmutableList;
+import com.google.genai.types.Content;
+import com.google.genai.types.Part;
+import io.reactivex.rxjava3.core.Single;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Request and response processors that plug a {@link BasePlanner} into the LLM flow.
+ *
+ * <p>Both processors leave their input untouched when the agent has no planner configured.
+ */
+public final class NLPlanning {
+
+  private NLPlanning() {}
+
+  /** Applies the planner's thinking config and planning instruction to the outgoing request. */
+  static class NlPlanningRequestProcessor implements RequestProcessor {
+
+    @Override
+    public Single<RequestProcessor.RequestProcessingResult> processRequest(
+        InvocationContext context, LlmRequest llmRequest) {
+      if (!(context.agent() instanceof LlmAgent)) {
+        throw new IllegalArgumentException(
+            "Agent in InvocationContext is not an instance of LlmAgent.");
+      }
+
+      Optional<BasePlanner> plannerOpt = getPlanner(context);
+      if (plannerOpt.isEmpty()) {
+        return Single.just(
+            RequestProcessor.RequestProcessingResult.create(llmRequest, ImmutableList.of()));
+      }
+      BasePlanner planner = plannerOpt.get();
+
+      // Built-in planners configure the model's native thinking on the request itself.
+      LlmRequest request = llmRequest;
+      if (planner instanceof BuiltInPlanner builtInPlanner) {
+        request = builtInPlanner.applyThinkingConfig(request);
+      }
+
+      // Append the planner's instruction, if it provides one.
+      Optional<String> planningInstruction =
+          planner.generatePlanningInstruction(new ReadonlyContext(context), request);
+      LlmRequest.Builder requestBuilder = request.toBuilder();
+      planningInstruction.ifPresent(
+          instruction -> requestBuilder.appendInstructions(ImmutableList.of(instruction)));
+
+      // Clear the thought flag on every request part before the request is sent.
+      LlmRequest processedRequest = removeThoughtFromRequest(requestBuilder.build());
+
+      return Single.just(
+          RequestProcessor.RequestProcessingResult.create(processedRequest, ImmutableList.of()));
+    }
+  }
+
+  /** Lets the planner rewrite response parts and reports planner state changes as an event. */
+  static class NlPlanningResponseProcessor implements ResponseProcessor {
+
+    @Override
+    public Single<ResponseProcessor.ResponseProcessingResult> processResponse(
+        InvocationContext context, LlmResponse llmResponse) {
+      if (!(context.agent() instanceof LlmAgent)) {
+        throw new IllegalArgumentException(
+            "Agent in InvocationContext is not an instance of LlmAgent.");
+      }
+
+      // Nothing to do without response content or without a planner.
+      if (llmResponse == null || llmResponse.content().isEmpty()) {
+        return unchanged(llmResponse);
+      }
+      Optional<BasePlanner> plannerOpt = getPlanner(context);
+      if (plannerOpt.isEmpty()) {
+        return unchanged(llmResponse);
+      }
+      BasePlanner planner = plannerOpt.get();
+      Content content = llmResponse.content().get();
+
+      CallbackContext callbackContext = new CallbackContext(context, null);
+      Optional<List<Part>> processedParts =
+          planner.processPlanningResponse(callbackContext, content.parts().orElse(List.of()));
+
+      // Replace the response parts only when the planner returned a processed list.
+      LlmResponse.Builder responseBuilder = llmResponse.toBuilder();
+      processedParts.ifPresent(
+          parts -> responseBuilder.content(content.toBuilder().parts(parts).build()));
+
+      // Emit a state update event if the planner changed state through the callback context.
+      ImmutableList.Builder<Event> eventsBuilder = ImmutableList.builder();
+      if (callbackContext.state().hasDelta()) {
+        eventsBuilder.add(
+            Event.builder()
+                .invocationId(context.invocationId())
+                .author(context.agent().name())
+                .branch(context.branch())
+                .actions(callbackContext.eventActions())
+                .build());
+      }
+
+      return Single.just(
+          ResponseProcessor.ResponseProcessingResult.create(
+              responseBuilder.build(), eventsBuilder.build(), Optional.empty()));
+    }
+
+    /** Returns a result that passes {@code llmResponse} through with no events. */
+    private static Single<ResponseProcessor.ResponseProcessingResult> unchanged(
+        LlmResponse llmResponse) {
+      return Single.just(
+          ResponseProcessor.ResponseProcessingResult.create(
+              llmResponse, ImmutableList.of(), Optional.empty()));
+    }
+  }
+
+  /**
+   * Retrieves the planner from the invocation context.
+   *
+   * @param invocationContext the current invocation context
+   * @return the agent's planner, or empty if the agent is not an {@link LlmAgent} or has none
+   */
+  private static Optional<BasePlanner> getPlanner(InvocationContext invocationContext) {
+    if (!(invocationContext.agent() instanceof LlmAgent agent)) {
+      return Optional.empty();
+    }
+    return agent.planner();
+  }
+
+  /**
+   * Removes thought annotations from all parts in the LLM request.
+   *
+   * <p>Every part of every content is rebuilt with its thought field set to false. Contents
+   * without parts are passed through as they are.
+   *
+   * @param llmRequest the LLM request to process
+   * @return the rebuilt request, or {@code llmRequest} itself if it has no contents
+   */
+  private static LlmRequest removeThoughtFromRequest(LlmRequest llmRequest) {
+    if (llmRequest.contents() == null || llmRequest.contents().isEmpty()) {
+      return llmRequest;
+    }
+
+    List<Content> updatedContents =
+        llmRequest.contents().stream()
+            .map(NLPlanning::withThoughtCleared)
+            .collect(Collectors.toList());
+    return llmRequest.toBuilder().contents(updatedContents).build();
+  }
+
+  /** Returns {@code content} with {@code thought(false)} applied to each part, if it has any. */
+  private static Content withThoughtCleared(Content content) {
+    if (content.parts().isEmpty()) {
+      return content;
+    }
+    List<Part> updatedParts =
+        content.parts().get().stream()
+            .map(part -> part.toBuilder().thought(false).build())
+            .collect(Collectors.toList());
+    return content.toBuilder().parts(updatedParts).build();
+  }
+}
diff --git a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java
index 2aeaf622..c563da1d 100644
--- a/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java
+++ b/core/src/main/java/com/google/adk/flows/llmflows/SingleFlow.java
@@ -31,10 +31,12 @@ public class SingleFlow extends BaseLlmFlow {
           new Identity(),
           new Contents(),
           new Examples(),
+          new NLPlanning.NlPlanningRequestProcessor(),
           CodeExecution.requestProcessor);
 
   protected static final ImmutableList<ResponseProcessor> RESPONSE_PROCESSORS =
-      ImmutableList.of(CodeExecution.responseProcessor);
+      ImmutableList.of(
+          new NLPlanning.NlPlanningResponseProcessor(), CodeExecution.responseProcessor);
 
   public SingleFlow() {
     this(/* maxSteps= */ Optional.empty());
diff --git a/core/src/main/java/com/google/adk/planners/BasePlanner.java b/core/src/main/java/com/google/adk/planners/BasePlanner.java
new file mode 100644
index 00000000..fbc7933c
--- /dev/null
+++ b/core/src/main/java/com/google/adk/planners/BasePlanner.java
@@ -0,0 +1,29 @@
+package com.google.adk.planners;
+
+import com.google.adk.agents.CallbackContext;
+import com.google.adk.agents.ReadonlyContext;
+import com.google.adk.models.LlmRequest;
+import com.google.genai.types.Part;
+import java.util.List;
+import java.util.Optional;
+
+/** Strategy that contributes planning instructions to requests and post-processes responses. */
+public interface BasePlanner {
+  /**
+   * Generates system instruction text for LLM planning requests.
+   *
+   * @param context readonly invocation context
+   * @param request the LLM request being prepared
+   * @return planning instruction text, or empty if no instruction needed
+   */
+  Optional<String> generatePlanningInstruction(ReadonlyContext context, LlmRequest request);
+
+  /**
+   * Processes and transforms LLM response parts for planning workflow.
+   *
+   * @param context callback context for the current invocation
+   * @param responseParts list of response parts from the LLM
+   * @return processed response parts, or empty if no processing required
+   */
+  Optional<List<Part>> processPlanningResponse(CallbackContext context, List<Part> responseParts);
+}
diff --git a/core/src/main/java/com/google/adk/planners/BuiltInPlanner.java b/core/src/main/java/com/google/adk/planners/BuiltInPlanner.java
new file mode 100644
index 00000000..19a5cd56
--- /dev/null
+++ b/core/src/main/java/com/google/adk/planners/BuiltInPlanner.java
@@ -0,0 +1,67 @@
+package com.google.adk.planners;
+
+import com.google.adk.agents.CallbackContext;
+import com.google.adk.agents.ReadonlyContext;
+import com.google.adk.models.LlmRequest;
+import com.google.genai.types.GenerateContentConfig;
+import com.google.genai.types.Part;
+import com.google.genai.types.ThinkingConfig;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Planner that relies on the model's native thinking instead of prompt-based planning.
+ *
+ * <p>It adds no planning instruction and returns no processed response parts; its only effect is
+ * applying a {@link ThinkingConfig} to the request.
+ */
+public final class BuiltInPlanner implements BasePlanner {
+  private final ThinkingConfig cognitiveConfig;
+
+  private BuiltInPlanner(ThinkingConfig cognitiveConfig) {
+    this.cognitiveConfig = cognitiveConfig;
+  }
+
+  /**
+   * Creates a planner that applies the given thinking configuration.
+   *
+   * @param cognitiveConfig the thinking configuration; if {@code null}, requests are left as is
+   * @return a new planner instance
+   */
+  public static BuiltInPlanner buildPlanner(ThinkingConfig cognitiveConfig) {
+    return new BuiltInPlanner(cognitiveConfig);
+  }
+
+  @Override
+  public Optional<String> generatePlanningInstruction(ReadonlyContext context, LlmRequest request) {
+    return Optional.empty();
+  }
+
+  @Override
+  public Optional<List<Part>> processPlanningResponse(
+      CallbackContext context, List<Part> responseParts) {
+    return Optional.empty();
+  }
+
+  /**
+   * Returns a request that carries this planner's thinking configuration.
+   *
+   * <p>The result is built from {@code request.toBuilder()}; an existing generation config is
+   * reused as the starting point and only its thinking configuration is set.
+   *
+   * @param request the LLM request to configure
+   * @return the configured request, or {@code request} itself if no thinking config was set
+   */
+  public LlmRequest applyThinkingConfig(LlmRequest request) {
+    if (cognitiveConfig == null) {
+      return request;
+    }
+    GenerateContentConfig.Builder configBuilder =
+        request.config()
+            .map(GenerateContentConfig::toBuilder)
+            .orElseGet(GenerateContentConfig::builder);
+    return request.toBuilder()
+        .config(configBuilder.thinkingConfig(cognitiveConfig).build())
+        .build();
+  }
+}
diff --git a/core/src/main/java/com/google/adk/planners/PlanReActPlanner.java b/core/src/main/java/com/google/adk/planners/PlanReActPlanner.java
new file mode 100644
index 00000000..e6ebfb00
--- /dev/null
+++ b/core/src/main/java/com/google/adk/planners/PlanReActPlanner.java
@@ -0,0 +1,202 @@
+package com.google.adk.planners;
+
+import com.google.adk.agents.CallbackContext;
+import com.google.adk.agents.ReadonlyContext;
+import com.google.adk.models.LlmRequest;
+import com.google.common.collect.ImmutableList;
+import com.google.genai.types.FunctionCall;
+import com.google.genai.types.Part;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Planner that asks the model to answer in a plan / reasoning / action / final-answer format.
+ *
+ * <p>On the response side, text parts that start with a planning, reasoning, action or
+ * replanning tag are marked as thoughts, and text containing the final-answer tag is split into
+ * a thought part and an answer part.
+ */
+public class PlanReActPlanner implements BasePlanner {
+  // ReAct structure tags
+  private static final String PLANNING_TAG = "/*PLANNING*/";
+  private static final String REPLANNING_TAG = "/*REPLANNING*/";
+  private static final String REASONING_TAG = "/*REASONING*/";
+  private static final String ACTION_TAG = "/*ACTION*/";
+  private static final String FINAL_ANSWER_TAG = "/*FINAL_ANSWER*/";
+
+  @Override
+  public Optional<String> generatePlanningInstruction(ReadonlyContext context, LlmRequest request) {
+    return Optional.of(buildNaturalLanguagePlannerInstruction());
+  }
+
+  /**
+   * Keeps the parts up to the first named function call plus the function calls right after it.
+   *
+   * <p>Function calls with a missing or blank name that precede it are dropped; every part after
+   * that run of function calls is discarded.
+   *
+   * @return the preserved parts, or empty if {@code responseParts} is null or empty
+   */
+  @Override
+  public Optional<List<Part>> processPlanningResponse(
+      CallbackContext context, List<Part> responseParts) {
+    if (responseParts == null || responseParts.isEmpty()) {
+      return Optional.empty();
+    }
+
+    List<Part> preservedParts = new ArrayList<>();
+    int firstFunctionCallIndex = -1;
+
+    for (int i = 0; i < responseParts.size(); i++) {
+      Part part = responseParts.get(i);
+      Optional<FunctionCall> functionCall = part.functionCall();
+      if (functionCall.isEmpty()) {
+        handleNonFunctionCallParts(part, preservedParts);
+        continue;
+      }
+      // Skip function calls with empty names.
+      if (functionCall.get().name().map(name -> name.trim().isEmpty()).orElse(true)) {
+        continue;
+      }
+      preservedParts.add(part);
+      firstFunctionCallIndex = i;
+      break;
+    }
+
+    // Keep the function calls that directly follow the first one; stop at any other part.
+    if (firstFunctionCallIndex != -1) {
+      for (int j = firstFunctionCallIndex + 1;
+          j < responseParts.size() && responseParts.get(j).functionCall().isPresent();
+          j++) {
+        preservedParts.add(responseParts.get(j));
+      }
+    }
+
+    return Optional.of(ImmutableList.copyOf(preservedParts));
+  }
+
+  /**
+   * Handles processing of non-function-call response parts.
+   *
+   * @param responsePart the part to process
+   * @param preservedParts the mutable list to add processed parts to
+   */
+  private static void handleNonFunctionCallParts(Part responsePart, List<Part> preservedParts) {
+    if (responsePart.text().isEmpty()) {
+      preservedParts.add(responsePart);
+      return;
+    }
+    String text = responsePart.text().get();
+
+    if (text.contains(FINAL_ANSWER_TAG)) {
+      // Text up to and including the last tag is reasoning; the rest is the final answer.
+      String[] split = splitByLastPattern(text, FINAL_ANSWER_TAG);
+      String reasoningText = split[0];
+      String finalAnswerText = split[1];
+      if (!reasoningText.trim().isEmpty()) {
+        preservedParts.add(Part.builder().text(reasoningText).thought(true).build());
+      }
+      if (!finalAnswerText.trim().isEmpty()) {
+        preservedParts.add(Part.builder().text(finalAnswerText).build());
+      }
+      return;
+    }
+
+    boolean isThought =
+        text.startsWith(PLANNING_TAG)
+            || text.startsWith(REASONING_TAG)
+            || text.startsWith(ACTION_TAG)
+            || text.startsWith(REPLANNING_TAG);
+    Part.Builder partBuilder = responsePart.toBuilder();
+    if (isThought) {
+      partBuilder.thought(true);
+    }
+    preservedParts.add(partBuilder.build());
+  }
+
+  /**
+   * Splits text by the last occurrence of a separator.
+   *
+   * @param text the text to split
+   * @param separator the separator to search for
+   * @return array containing [before_separator + separator, after_separator]; if the separator
+   *     does not occur, [text, ""]
+   */
+  private static String[] splitByLastPattern(String text, String separator) {
+    int index = text.lastIndexOf(separator);
+    if (index == -1) {
+      return new String[] {text, ""};
+    }
+    int end = index + separator.length();
+    return new String[] {text.substring(0, end), text.substring(end)};
+  }
+
+  /**
+   * Builds the comprehensive natural language planner instruction.
+   *
+   * @return the complete system instruction for Plan-ReAct methodology
+   */
+  private static String buildNaturalLanguagePlannerInstruction() {
+    String highLevelPreamble =
+        String.format(
+            """
+            When answering the question, try to leverage the available tools to gather the information instead of your memorized knowledge.
+
+            Follow this process when answering the question: (1) first come up with a plan in natural language text format; (2) Then use tools to execute the plan and provide reasoning between tool code snippets to make a summary of current state and next step. Tool code snippets and reasoning should be interleaved with each other. (3) In the end, return one final answer.
+
+            Follow this format when answering the question: (1) The planning part should be under %s. (2) The tool code snippets should be under %s, and the reasoning parts should be under %s. (3) The final answer part should be under %s.
+            """,
+            PLANNING_TAG, ACTION_TAG, REASONING_TAG, FINAL_ANSWER_TAG);
+
+    String planningPreamble =
+        String.format(
+            """
+            Below are the requirements for the planning:
+            The plan is made to answer the user query if following the plan. The plan is coherent and covers all aspects of information from user query, and only involves the tools that are accessible by the agent. The plan contains the decomposed steps as a numbered list where each step should use one or multiple available tools. By reading the plan, you can intuitively know which tools to trigger or what actions to take.
+            If the initial plan cannot be successfully executed, you should learn from previous execution results and revise your plan. The revised plan should be under %s. Then use tools to follow the new plan.
+            """,
+            REPLANNING_TAG);
+
+    String reasoningPreamble =
+        """
+        Below are the requirements for the reasoning:
+        The reasoning makes a summary of the current trajectory based on the user query and tool outputs. Based on the tool outputs and plan, the reasoning also comes up with instructions to the next steps, making the trajectory closer to the final answer.
+        """;
+
+    String finalAnswerPreamble =
+        """
+        Below are the requirements for the final answer:
+        The final answer should be precise and follow query formatting requirements. Some queries may not be answerable with the available tools and information. In those cases, inform the user why you cannot process their query and ask for more information.
+        """;
+
+    String toolCodePreamble =
+        """
+        Below are the requirements for the tool code:
+
+        **Custom Tools:** The available tools are described in the context and can be directly used.
+        - Code must be valid self-contained Python snippets with no imports and no references to tools or Python libraries that are not in the context.
+        - You cannot use any parameters or fields that are not explicitly defined in the APIs in the context.
+        - The code snippets should be readable, efficient, and directly relevant to the user query and reasoning steps.
+        - When using the tools, you should use the library name together with the function name, e.g., vertex_search.search().
+        - If Python libraries are not provided in the context, NEVER write your own code other than the function calls using the provided tools.
+        """;
+
+    String userInputPreamble =
+        """
+        VERY IMPORTANT instruction that you MUST follow in addition to the above instructions:
+
+        You should ask for clarification if you need more information to answer the question.
+        You should prefer using the information available in the context instead of repeated tool use.
+        """;
+
+    return String.join(
+        "\n\n",
+        highLevelPreamble,
+        planningPreamble,
+        reasoningPreamble,
+        finalAnswerPreamble,
+        toolCodePreamble,
+        userInputPreamble);
+  }
+}