Skip to content

Commit 925b534

Browse files
Poggeccicopybara-github
authored andcommitted
feat: add instruction state injection bypass
PiperOrigin-RevId: 788621515
1 parent f5b8fda commit 925b534

File tree

5 files changed

+66
-21
lines changed

5 files changed

+66
-21
lines changed

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -655,38 +655,45 @@ protected Flowable<Event> runLiveImpl(InvocationContext invocationContext) {
655655
}
656656

657657
/**
658-
* Constructs the text instruction for this agent based on the {@link #instruction} field.
658+
* Constructs the text instruction for this agent based on the {@link #instruction} field. Also
659+
* returns a boolean indicating that state injection should be bypassed when the instruction is
660+
* constructed with an {@link Instruction.Provider}.
659661
*
660662
* <p>This method is only for use by Agent Development Kit.
661663
*
662664
* @param context The context to retrieve the session state.
663-
* @return The resolved instruction as a {@link Single} wrapped string.
665+
* @return The resolved instruction as a {@link Single} wrapped Map.Entry. The key is the
666+
* instruction string and the value is a boolean indicating if state injection should be
667+
* bypassed.
664668
*/
665-
public Single<String> canonicalInstruction(ReadonlyContext context) {
669+
public Single<Map.Entry<String, Boolean>> canonicalInstruction(ReadonlyContext context) {
666670
if (instruction instanceof Instruction.Static staticInstr) {
667-
return Single.just(staticInstr.instruction());
671+
return Single.just(Map.entry(staticInstr.instruction(), false));
668672
} else if (instruction instanceof Instruction.Provider provider) {
669-
return provider.getInstruction().apply(context);
673+
return provider.getInstruction().apply(context).map(instr -> Map.entry(instr, true));
670674
}
671675
throw new IllegalStateException("Unknown Instruction subtype: " + instruction.getClass());
672676
}
673677

674678
/**
675679
* Constructs the text global instruction for this agent based on the {@link #globalInstruction}
676-
* field.
680+
* field. Also returns a boolean indicating that state injection should be bypassed when the
681+
* instruction is constructed with an {@link Instruction.Provider}.
677682
*
678683
* <p>This method is only for use by Agent Development Kit.
679684
*
680685
* @param context The context to retrieve the session state.
681-
* @return The resolved global instruction as a {@link Single} wrapped string.
686+
* @return The resolved global instruction as a {@link Single} wrapped Map.Entry. The key is the
687+
* instruction string and the value is a boolean indicating if state injection should be
688+
* bypassed.
682689
*/
683-
public Single<String> canonicalGlobalInstruction(ReadonlyContext context) {
690+
public Single<Map.Entry<String, Boolean>> canonicalGlobalInstruction(ReadonlyContext context) {
684691
if (globalInstruction instanceof Instruction.Static staticInstr) {
685-
return Single.just(staticInstr.instruction());
692+
return Single.just(Map.entry(staticInstr.instruction(), false));
686693
} else if (globalInstruction instanceof Instruction.Provider provider) {
687-
return provider.getInstruction().apply(context);
694+
return provider.getInstruction().apply(context).map(instr -> Map.entry(instr, true));
688695
}
689-
throw new IllegalStateException("Unknown Instruction subtype: " + instruction.getClass());
696+
throw new IllegalStateException("Unknown Instruction subtype: " + globalInstruction.getClass());
690697
}
691698

692699
/**

core/src/main/java/com/google/adk/flows/llmflows/Instructions.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,14 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
4747
rootAgent
4848
.canonicalGlobalInstruction(readonlyContext)
4949
.flatMap(
50-
globalInstr -> {
50+
instructionEntry -> {
51+
String globalInstr = instructionEntry.getKey();
52+
boolean bypassStateInjection = instructionEntry.getValue();
5153
if (!globalInstr.isEmpty()) {
54+
if (bypassStateInjection) {
55+
return Single.just(
56+
builder.appendInstructions(ImmutableList.of(globalInstr)));
57+
}
5258
return InstructionUtils.injectSessionState(context, globalInstr)
5359
.map(
5460
resolvedGlobalInstr ->
@@ -65,8 +71,14 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
6571
agent
6672
.canonicalInstruction(readonlyContext)
6773
.flatMap(
68-
agentInstr -> {
74+
instructionEntry -> {
75+
String agentInstr = instructionEntry.getKey();
76+
boolean bypassStateInjection = instructionEntry.getValue();
6977
if (!agentInstr.isEmpty()) {
78+
if (bypassStateInjection) {
79+
return Single.just(
80+
builder.appendInstructions(ImmutableList.of(agentInstr)));
81+
}
7082
return InstructionUtils.injectSessionState(context, agentInstr)
7183
.map(
7284
resolvedAgentInstr ->

core/src/test/java/com/google/adk/agents/InstructionTest.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ public void testCanonicalInstruction_staticInstruction() {
2323
.build();
2424
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));
2525

26-
String canonicalInstruction = agent.canonicalInstruction(invocationContext).blockingGet();
26+
String canonicalInstruction =
27+
agent.canonicalInstruction(invocationContext).blockingGet().getKey();
2728

2829
assertThat(canonicalInstruction).isEqualTo(instruction);
2930
}
@@ -39,7 +40,8 @@ public void testCanonicalInstruction_providerInstructionInjectsContext() {
3940
.build();
4041
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));
4142

42-
String canonicalInstruction = agent.canonicalInstruction(invocationContext).blockingGet();
43+
String canonicalInstruction =
44+
agent.canonicalInstruction(invocationContext).blockingGet().getKey();
4345

4446
assertThat(canonicalInstruction).isEqualTo(instruction + invocationContext.invocationId());
4547
}
@@ -53,7 +55,8 @@ public void testCanonicalGlobalInstruction_staticInstruction() {
5355
.build();
5456
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));
5557

56-
String canonicalInstruction = agent.canonicalGlobalInstruction(invocationContext).blockingGet();
58+
String canonicalInstruction =
59+
agent.canonicalGlobalInstruction(invocationContext).blockingGet().getKey();
5760

5861
assertThat(canonicalInstruction).isEqualTo(instruction);
5962
}
@@ -69,7 +72,8 @@ public void testCanonicalGlobalInstruction_providerInstructionInjectsContext() {
6972
.build();
7073
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));
7174

72-
String canonicalInstruction = agent.canonicalGlobalInstruction(invocationContext).blockingGet();
75+
String canonicalInstruction =
76+
agent.canonicalGlobalInstruction(invocationContext).blockingGet().getKey();
7377

7478
assertThat(canonicalInstruction).isEqualTo(instruction + invocationContext.invocationId());
7579
}

core/src/test/java/com/google/adk/agents/LlmAgentTest.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ public void testCanonicalInstruction_acceptsPlainString() {
225225
.build();
226226
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));
227227

228-
String canonicalInstruction = agent.canonicalInstruction(invocationContext).blockingGet();
228+
String canonicalInstruction =
229+
agent.canonicalInstruction(invocationContext).blockingGet().getKey();
229230

230231
assertThat(canonicalInstruction).isEqualTo(instruction);
231232
}
@@ -241,7 +242,8 @@ public void testCanonicalInstruction_providerInstructionInjectsContext() {
241242
.build();
242243
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));
243244

244-
String canonicalInstruction = agent.canonicalInstruction(invocationContext).blockingGet();
245+
String canonicalInstruction =
246+
agent.canonicalInstruction(invocationContext).blockingGet().getKey();
245247

246248
assertThat(canonicalInstruction).isEqualTo(instruction + invocationContext.invocationId());
247249
}
@@ -265,7 +267,8 @@ public void testCanonicalGlobalInstruction_acceptsPlainString() {
265267
.build();
266268
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));
267269

268-
String canonicalInstruction = agent.canonicalGlobalInstruction(invocationContext).blockingGet();
270+
String canonicalInstruction =
271+
agent.canonicalGlobalInstruction(invocationContext).blockingGet().getKey();
269272

270273
assertThat(canonicalInstruction).isEqualTo(instruction);
271274
}
@@ -281,7 +284,8 @@ public void testCanonicalGlobalInstruction_providerInstructionInjectsContext() {
281284
.build();
282285
ReadonlyContext invocationContext = new ReadonlyContext(createInvocationContext(agent));
283286

284-
String canonicalInstruction = agent.canonicalGlobalInstruction(invocationContext).blockingGet();
287+
String canonicalInstruction =
288+
agent.canonicalGlobalInstruction(invocationContext).blockingGet().getKey();
285289

286290
assertThat(canonicalInstruction).isEqualTo(instruction + invocationContext.invocationId());
287291
}

core/src/test/java/com/google/adk/flows/llmflows/InstructionsTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,24 @@ public void processRequest_agentInstructionProvider_appendsInstruction() {
181181
.containsExactly(instructionFromProvider);
182182
}
183183

184+
@Test
185+
public void processRequest_agentInstructionProvider_bypassesStateInjection() {
186+
Session session = createSession();
187+
session.state().put("name", "TestBot");
188+
// This would throw an error if state injection was attempted.
189+
String instructionFromProvider = "My name is {name}. But my friend is {friend_name}.";
190+
Instruction provider = new Instruction.Provider(ctx -> Single.just(instructionFromProvider));
191+
192+
LlmAgent agent = LlmAgent.builder().name("agent").instruction(provider).build();
193+
InvocationContext context = createContext(agent, createSession());
194+
195+
RequestProcessor.RequestProcessingResult result =
196+
instructionsProcessor.processRequest(context, initialRequest).blockingGet();
197+
198+
assertThat(result.updatedRequest().getSystemInstructions())
199+
.containsExactly(instructionFromProvider);
200+
}
201+
184202
@Test
185203
public void
186204
processRequest_agentInstructionString_withInvalidPlaceholderSyntax_appendsInstructionWithLiteral() {

0 commit comments

Comments
 (0)