Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions core/src/main/java/com/google/adk/agents/LlmAgent.java
Original file line number Diff line number Diff line change
Expand Up @@ -655,38 +655,45 @@ protected Flowable<Event> runLiveImpl(InvocationContext invocationContext) {
}

/**
* Constructs the text instruction for this agent based on the {@link #instruction} field.
* Constructs the text instruction for this agent based on the {@link #instruction} field. Also
* returns a boolean indicating that state injection should be bypassed when the instruction is
* constructed with an {@link Instruction.Provider}.
*
* <p>This method is only for use by Agent Development Kit.
*
* @param context The context to retrieve the session state.
* @return The resolved instruction as a {@link Single} wrapped string.
* @return The resolved instruction as a {@link Single} wrapped Map.Entry. The key is the
* instruction string and the value is a boolean indicating if state injection should be
* bypassed.
*/
public Single<String> canonicalInstruction(ReadonlyContext context) {
public Single<Map.Entry<String, Boolean>> canonicalInstruction(ReadonlyContext context) {
if (instruction instanceof Instruction.Static staticInstr) {
return Single.just(staticInstr.instruction());
return Single.just(Map.entry(staticInstr.instruction(), false));
} else if (instruction instanceof Instruction.Provider provider) {
return provider.getInstruction().apply(context);
return provider.getInstruction().apply(context).map(instr -> Map.entry(instr, true));
}
throw new IllegalStateException("Unknown Instruction subtype: " + instruction.getClass());
}

/**
* Constructs the text global instruction for this agent based on the {@link #globalInstruction}
* field.
* field. Also returns a boolean indicating that state injection should be bypassed when the
* instruction is constructed with an {@link Instruction.Provider}.
*
* <p>This method is only for use by Agent Development Kit.
*
* @param context The context to retrieve the session state.
* @return The resolved global instruction as a {@link Single} wrapped string.
* @return The resolved global instruction as a {@link Single} wrapped Map.Entry. The key is the
* instruction string and the value is a boolean indicating if state injection should be
* bypassed.
*/
public Single<String> canonicalGlobalInstruction(ReadonlyContext context) {
public Single<Map.Entry<String, Boolean>> canonicalGlobalInstruction(ReadonlyContext context) {
if (globalInstruction instanceof Instruction.Static staticInstr) {
return Single.just(staticInstr.instruction());
return Single.just(Map.entry(staticInstr.instruction(), false));
} else if (globalInstruction instanceof Instruction.Provider provider) {
return provider.getInstruction().apply(context);
return provider.getInstruction().apply(context).map(instr -> Map.entry(instr, true));
}
throw new IllegalStateException("Unknown Instruction subtype: " + instruction.getClass());
throw new IllegalStateException("Unknown Instruction subtype: " + globalInstruction.getClass());
}

/**
Expand Down
16 changes: 14 additions & 2 deletions core/src/main/java/com/google/adk/flows/llmflows/Instructions.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,14 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
rootAgent
.canonicalGlobalInstruction(readonlyContext)
.flatMap(
globalInstr -> {
instructionEntry -> {
String globalInstr = instructionEntry.getKey();
boolean bypassStateInjection = instructionEntry.getValue();
if (!globalInstr.isEmpty()) {
if (bypassStateInjection) {
return Single.just(
builder.appendInstructions(ImmutableList.of(globalInstr)));
}
return InstructionUtils.injectSessionState(context, globalInstr)
.map(
resolvedGlobalInstr ->
Expand All @@ -65,8 +71,14 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
agent
.canonicalInstruction(readonlyContext)
.flatMap(
agentInstr -> {
instructionEntry -> {
String agentInstr = instructionEntry.getKey();
boolean bypassStateInjection = instructionEntry.getValue();
if (!agentInstr.isEmpty()) {
if (bypassStateInjection) {
return Single.just(
builder.appendInstructions(ImmutableList.of(agentInstr)));
}
return InstructionUtils.injectSessionState(context, agentInstr)
.map(
resolvedAgentInstr ->
Expand Down
12 changes: 8 additions & 4 deletions core/src/test/java/com/google/adk/agents/InstructionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ public void testCanonicalInstruction_staticInstruction() {
.build();
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));

String canonicalInstruction = agent.canonicalInstruction(invocationContext).blockingGet();
String canonicalInstruction =
agent.canonicalInstruction(invocationContext).blockingGet().getKey();

assertThat(canonicalInstruction).isEqualTo(instruction);
}
Expand All @@ -39,7 +40,8 @@ public void testCanonicalInstruction_providerInstructionInjectsContext() {
.build();
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));

String canonicalInstruction = agent.canonicalInstruction(invocationContext).blockingGet();
String canonicalInstruction =
agent.canonicalInstruction(invocationContext).blockingGet().getKey();

assertThat(canonicalInstruction).isEqualTo(instruction + invocationContext.invocationId());
}
Expand All @@ -53,7 +55,8 @@ public void testCanonicalGlobalInstruction_staticInstruction() {
.build();
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));

String canonicalInstruction = agent.canonicalGlobalInstruction(invocationContext).blockingGet();
String canonicalInstruction =
agent.canonicalGlobalInstruction(invocationContext).blockingGet().getKey();

assertThat(canonicalInstruction).isEqualTo(instruction);
}
Expand All @@ -69,7 +72,8 @@ public void testCanonicalGlobalInstruction_providerInstructionInjectsContext() {
.build();
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));

String canonicalInstruction = agent.canonicalGlobalInstruction(invocationContext).blockingGet();
String canonicalInstruction =
agent.canonicalGlobalInstruction(invocationContext).blockingGet().getKey();

assertThat(canonicalInstruction).isEqualTo(instruction + invocationContext.invocationId());
}
Expand Down
12 changes: 8 additions & 4 deletions core/src/test/java/com/google/adk/agents/LlmAgentTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ public void testCanonicalInstruction_acceptsPlainString() {
.build();
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));

String canonicalInstruction = agent.canonicalInstruction(invocationContext).blockingGet();
String canonicalInstruction =
agent.canonicalInstruction(invocationContext).blockingGet().getKey();

assertThat(canonicalInstruction).isEqualTo(instruction);
}
Expand All @@ -241,7 +242,8 @@ public void testCanonicalInstruction_providerInstructionInjectsContext() {
.build();
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));

String canonicalInstruction = agent.canonicalInstruction(invocationContext).blockingGet();
String canonicalInstruction =
agent.canonicalInstruction(invocationContext).blockingGet().getKey();

assertThat(canonicalInstruction).isEqualTo(instruction + invocationContext.invocationId());
}
Expand All @@ -265,7 +267,8 @@ public void testCanonicalGlobalInstruction_acceptsPlainString() {
.build();
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));

String canonicalInstruction = agent.canonicalGlobalInstruction(invocationContext).blockingGet();
String canonicalInstruction =
agent.canonicalGlobalInstruction(invocationContext).blockingGet().getKey();

assertThat(canonicalInstruction).isEqualTo(instruction);
}
Expand All @@ -281,7 +284,8 @@ public void testCanonicalGlobalInstruction_providerInstructionInjectsContext() {
.build();
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));

String canonicalInstruction = agent.canonicalGlobalInstruction(invocationContext).blockingGet();
String canonicalInstruction =
agent.canonicalGlobalInstruction(invocationContext).blockingGet().getKey();

assertThat(canonicalInstruction).isEqualTo(instruction + invocationContext.invocationId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,24 @@ public void processRequest_agentInstructionProvider_appendsInstruction() {
.containsExactly(instructionFromProvider);
}

@Test
public void processRequest_agentInstructionProvider_bypassesStateInjection() {
Session session = createSession();
session.state().put("name", "TestBot");
// This would throw an error if state injection was attempted.
String instructionFromProvider = "My name is {name}. But my friend is {friend_name}.";
Instruction provider = new Instruction.Provider(ctx -> Single.just(instructionFromProvider));

LlmAgent agent = LlmAgent.builder().name("agent").instruction(provider).build();
InvocationContext context = createContext(agent, createSession());

RequestProcessor.RequestProcessingResult result =
instructionsProcessor.processRequest(context, initialRequest).blockingGet();

assertThat(result.updatedRequest().getSystemInstructions())
.containsExactly(instructionFromProvider);
}

@Test
public void
processRequest_agentInstructionString_withInvalidPlaceholderSyntax_appendsInstructionWithLiteral() {
Expand Down