Skip to content

Commit 9e9dfa7

Browse files
selcukguncopybara-github
authored andcommitted
Prevent session state injection for provider supplied instructions
When the user provides instruction provider, we assume that they will inject the session state parameters if needed. This assumption allows users to return code snippets in the instruction provider without any template replacement. PiperOrigin-RevId: 759705471
1 parent 5ee17a3 commit 9e9dfa7

File tree

4 files changed

+74
-23
lines changed

4 files changed

+74
-23
lines changed

src/google/adk/agents/llm_agent.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,31 +307,53 @@ def canonical_model(self) -> BaseLlm:
307307
ancestor_agent = ancestor_agent.parent_agent
308308
raise ValueError(f'No model found for {self.name}.')
309309

310-
async def canonical_instruction(self, ctx: ReadonlyContext) -> str:
310+
async def canonical_instruction(
311+
self, ctx: ReadonlyContext
312+
) -> tuple[str, bool]:
311313
"""The resolved self.instruction field to construct instruction for this agent.
312314
313315
This method is only for use by Agent Development Kit.
316+
317+
Args:
318+
ctx: The context to retrieve the session state.
319+
320+
Returns:
321+
A tuple of (instruction, bypass_state_injection).
322+
instruction: The resolved self.instruction field.
323+
bypass_state_injection: Whether the instruction is based on
324+
InstructionProvider.
314325
"""
315326
if isinstance(self.instruction, str):
316-
return self.instruction
327+
return self.instruction, False
317328
else:
318329
instruction = self.instruction(ctx)
319330
if inspect.isawaitable(instruction):
320331
instruction = await instruction
321-
return instruction
332+
return instruction, True
322333

323-
async def canonical_global_instruction(self, ctx: ReadonlyContext) -> str:
334+
async def canonical_global_instruction(
335+
self, ctx: ReadonlyContext
336+
) -> tuple[str, bool]:
324337
"""The resolved self.instruction field to construct global instruction.
325338
326339
This method is only for use by Agent Development Kit.
340+
341+
Args:
342+
ctx: The context to retrieve the session state.
343+
344+
Returns:
345+
A tuple of (instruction, bypass_state_injection).
346+
instruction: The resolved self.global_instruction field.
347+
bypass_state_injection: Whether the instruction is based on
348+
InstructionProvider.
327349
"""
328350
if isinstance(self.global_instruction, str):
329-
return self.global_instruction
351+
return self.global_instruction, False
330352
else:
331353
global_instruction = self.global_instruction(ctx)
332354
if inspect.isawaitable(global_instruction):
333355
global_instruction = await global_instruction
334-
return global_instruction
356+
return global_instruction, True
335357

336358
async def canonical_tools(
337359
self, ctx: ReadonlyContext = None

src/google/adk/flows/llm_flows/instructions.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,24 @@ async def run_async(
5353
if (
5454
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
5555
): # not empty str
56-
raw_si = await root_agent.canonical_global_instruction(
57-
ReadonlyContext(invocation_context)
56+
raw_si, bypass_state_injection = (
57+
await root_agent.canonical_global_instruction(
58+
ReadonlyContext(invocation_context)
59+
)
5860
)
59-
si = await _populate_values(raw_si, invocation_context)
61+
si = raw_si
62+
if not bypass_state_injection:
63+
si = await _populate_values(raw_si, invocation_context)
6064
llm_request.append_instructions([si])
6165

6266
# Appends agent instructions if set.
6367
if agent.instruction: # not empty str
64-
raw_si = await agent.canonical_instruction(
68+
raw_si, bypass_state_injection = await agent.canonical_instruction(
6569
ReadonlyContext(invocation_context)
6670
)
67-
si = await _populate_values(raw_si, invocation_context)
71+
si = raw_si
72+
if not bypass_state_injection:
73+
si = await _populate_values(raw_si, invocation_context)
6874
llm_request.append_instructions([si])
6975

7076
# Maintain async generator behavior

tests/unittests/agents/test_llm_agent_fields.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,11 @@ async def test_canonical_instruction_str():
7979
agent = LlmAgent(name='test_agent', instruction='instruction')
8080
ctx = await _create_readonly_context(agent)
8181

82-
canonical_instruction = await agent.canonical_instruction(ctx)
82+
canonical_instruction, bypass_state_injection = (
83+
await agent.canonical_instruction(ctx)
84+
)
8385
assert canonical_instruction == 'instruction'
86+
assert not bypass_state_injection
8487

8588

8689
async def test_canonical_instruction():
@@ -92,8 +95,11 @@ def _instruction_provider(ctx: ReadonlyContext) -> str:
9295
agent, state={'state_var': 'state_value'}
9396
)
9497

95-
canonical_instruction = await agent.canonical_instruction(ctx)
98+
canonical_instruction, bypass_state_injection = (
99+
await agent.canonical_instruction(ctx)
100+
)
96101
assert canonical_instruction == 'instruction: state_value'
102+
assert bypass_state_injection
97103

98104

99105
async def test_async_canonical_instruction():
@@ -105,16 +111,22 @@ async def _instruction_provider(ctx: ReadonlyContext) -> str:
105111
agent, state={'state_var': 'state_value'}
106112
)
107113

108-
canonical_instruction = await agent.canonical_instruction(ctx)
114+
canonical_instruction, bypass_state_injection = (
115+
await agent.canonical_instruction(ctx)
116+
)
109117
assert canonical_instruction == 'instruction: state_value'
118+
assert bypass_state_injection
110119

111120

112121
async def test_canonical_global_instruction_str():
113122
agent = LlmAgent(name='test_agent', global_instruction='global instruction')
114123
ctx = await _create_readonly_context(agent)
115124

116-
canonical_instruction = await agent.canonical_global_instruction(ctx)
125+
canonical_instruction, bypass_state_injection = (
126+
await agent.canonical_global_instruction(ctx)
127+
)
117128
assert canonical_instruction == 'global instruction'
129+
assert not bypass_state_injection
118130

119131

120132
async def test_canonical_global_instruction():
@@ -128,9 +140,11 @@ def _global_instruction_provider(ctx: ReadonlyContext) -> str:
128140
agent, state={'state_var': 'state_value'}
129141
)
130142

131-
canonical_global_instruction = await agent.canonical_global_instruction(ctx)
143+
canonical_global_instruction, bypass_state_injection = (
144+
await agent.canonical_global_instruction(ctx)
145+
)
132146
assert canonical_global_instruction == 'global instruction: state_value'
133-
147+
assert bypass_state_injection
134148

135149
async def test_async_canonical_global_instruction():
136150
async def _global_instruction_provider(ctx: ReadonlyContext) -> str:
@@ -142,11 +156,11 @@ async def _global_instruction_provider(ctx: ReadonlyContext) -> str:
142156
ctx = await _create_readonly_context(
143157
agent, state={'state_var': 'state_value'}
144158
)
145-
146-
assert (
159+
canonical_global_instruction, bypass_state_injection = (
147160
await agent.canonical_global_instruction(ctx)
148-
== 'global instruction: state_value'
149161
)
162+
assert canonical_global_instruction == 'global instruction: state_value'
163+
assert bypass_state_injection
150164

151165

152166
def test_output_schema_will_disable_transfer(caplog: pytest.LogCaptureFixture):

tests/unittests/flows/llm_flows/test_instructions.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ async def test_function_system_instruction():
6161
def build_function_instruction(readonly_context: ReadonlyContext) -> str:
6262
return (
6363
"This is the function agent instruction for invocation:"
64+
" provider template intact { customerId }"
65+
" provider template intact { customer_int }"
6466
f" {readonly_context.invocation_id}."
6567
)
6668

@@ -88,17 +90,21 @@ def build_function_instruction(readonly_context: ReadonlyContext) -> str:
8890
pass
8991

9092
assert request.config.system_instruction == (
91-
"This is the function agent instruction for invocation: test_id."
93+
"This is the function agent instruction for invocation:"
94+
" provider template intact { customerId }"
95+
" provider template intact { customer_int }"
96+
" test_id."
9297
)
9398

94-
9599
@pytest.mark.asyncio
96100
async def test_async_function_system_instruction():
97101
async def build_function_instruction(
98102
readonly_context: ReadonlyContext,
99103
) -> str:
100104
return (
101105
"This is the function agent instruction for invocation:"
106+
" provider template intact { customerId }"
107+
" provider template intact { customer_int }"
102108
f" {readonly_context.invocation_id}."
103109
)
104110

@@ -126,7 +132,10 @@ async def build_function_instruction(
126132
pass
127133

128134
assert request.config.system_instruction == (
129-
"This is the function agent instruction for invocation: test_id."
135+
"This is the function agent instruction for invocation:"
136+
" provider template intact { customerId }"
137+
" provider template intact { customer_int }"
138+
" test_id."
130139
)
131140

132141

0 commit comments

Comments
 (0)