Skip to content

Commit 43c1cc2

Browse files
committed
feat: add nested state access support with dot and bracket notation
- Add support for accessing nested dictionary structures in state variables - Support dot notation (state.key.subkey) and bracket notation (state['key']['subkey']) - Support mixed notation (state.key['subkey']) - Maintain backward compatibility with existing flat state access - Add comprehensive tests for all nested access patterns - Handle optional syntax for nested paths - Proper error handling for missing nested keys Functions added: - _parse_nested_path(): Parse variable paths with dot/bracket notation - _get_nested_value(): Safely traverse nested dictionary structures - _is_valid_state_name_or_nested(): Extended validation for nested paths All existing tests pass, ensuring backward compatibility is maintained.
1 parent c9e2655 commit 43c1cc2

File tree

2 files changed

+222
-5
lines changed

2 files changed

+222
-5
lines changed

src/google/adk/utils/instructions_utils.py

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,19 +93,97 @@ async def _replace_match(match) -> str:
9393
raise KeyError(f'Artifact {var_name} not found.')
9494
return str(artifact)
9595
else:
96-
if not _is_valid_state_name(var_name):
96+
if not _is_valid_state_name_or_nested(var_name):
9797
return match.group()
98-
if var_name in invocation_context.session.state:
99-
return str(invocation_context.session.state[var_name])
100-
else:
98+
99+
try:
100+
keys = _parse_nested_path(var_name)
101+
value = _get_nested_value(invocation_context.session.state, keys)
102+
return str(value)
103+
except KeyError as e:
101104
if optional:
102105
return ''
103106
else:
104-
raise KeyError(f'Context variable not found: `{var_name}`.')
107+
raise KeyError(f'Context variable not found: `{var_name}`. {str(e)}')
105108

106109
return await _async_sub(r'{+[^{}]*}+', _replace_match, template)
107110

108111

112+
def _parse_nested_path(var_name: str) -> list[str]:
113+
"""Parse a nested variable path into individual keys.
114+
115+
Supports both dot notation (key.subkey) and bracket notation (key['subkey']).
116+
Mixed notation is also supported (key.subkey['nested']).
117+
118+
Args:
119+
var_name: The variable name to parse (e.g., "user.profile.name" or "user['profile']['name']")
120+
121+
Returns:
122+
List of keys to traverse the nested structure.
123+
"""
124+
if '.' not in var_name and '[' not in var_name:
125+
return [var_name]
126+
127+
keys = []
128+
current_key = ""
129+
i = 0
130+
131+
while i < len(var_name):
132+
char = var_name[i]
133+
134+
if char == '.':
135+
if current_key:
136+
keys.append(current_key)
137+
current_key = ""
138+
elif char == '[':
139+
if current_key:
140+
keys.append(current_key)
141+
current_key = ""
142+
bracket_end = var_name.find(']', i)
143+
if bracket_end == -1:
144+
raise ValueError(f"Unclosed bracket in variable name: {var_name}")
145+
146+
bracket_content = var_name[i+1:bracket_end]
147+
if (bracket_content.startswith('"') and bracket_content.endswith('"')) or \
148+
(bracket_content.startswith("'") and bracket_content.endswith("'")):
149+
bracket_content = bracket_content[1:-1]
150+
151+
keys.append(bracket_content)
152+
i = bracket_end
153+
else:
154+
current_key += char
155+
156+
i += 1
157+
158+
if current_key:
159+
keys.append(current_key)
160+
161+
return keys
162+
163+
164+
def _get_nested_value(data: dict, keys: list[str]):
165+
"""Get a value from nested dictionary structure using a list of keys.
166+
167+
Args:
168+
data: The dictionary to traverse
169+
keys: List of keys to traverse the nested structure
170+
171+
Returns:
172+
The value at the nested path
173+
174+
Raises:
175+
KeyError: If any key in the path doesn't exist
176+
"""
177+
current = data
178+
for key in keys:
179+
if not isinstance(current, dict):
180+
raise KeyError(f"Cannot access key '{key}' on non-dict value")
181+
if key not in current:
182+
raise KeyError(f"Key '{key}' not found")
183+
current = current[key]
184+
return current
185+
186+
109187
def _is_valid_state_name(var_name):
110188
"""Checks if the variable name is a valid state name.
111189
@@ -129,3 +207,32 @@ def _is_valid_state_name(var_name):
129207
if (parts[0] + ':') in prefixes:
130208
return parts[1].isidentifier()
131209
return False
210+
211+
212+
def _is_valid_state_name_or_nested(var_name: str) -> bool:
213+
"""Checks if the variable name is a valid state name or nested path.
214+
215+
Valid state is either:
216+
- Valid identifier (existing behavior)
217+
- <Valid prefix>:<Valid identifier> (existing behavior)
218+
- Nested path with dot notation (key.subkey.nested)
219+
- Nested path with bracket notation (key['subkey']['nested'])
220+
- Mixed notation (key.subkey['nested'])
221+
222+
Args:
223+
var_name: The variable name to check.
224+
225+
Returns:
226+
True if the variable name is valid, False otherwise.
227+
"""
228+
if _is_valid_state_name(var_name):
229+
return True
230+
231+
try:
232+
keys = _parse_nested_path(var_name)
233+
for key in keys:
234+
if not _is_valid_state_name(key):
235+
return False
236+
return len(keys) > 1
237+
except ValueError:
238+
return False

tests/unittests/utils/test_instructions_utils.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,113 @@ async def test_inject_session_state_artifact_service_not_initialized_raises_valu
214214
await instructions_utils.inject_session_state(
215215
instruction_template, invocation_context
216216
)
217+
218+
219+
@pytest.mark.asyncio
220+
async def test_inject_session_state_with_nested_dot_notation():
221+
instruction_template = "User name: {user.profile.name}, Age: {user.profile.age}"
222+
invocation_context = await _create_test_readonly_context(
223+
state={
224+
"user": {
225+
"profile": {
226+
"name": "John Doe",
227+
"age": 30
228+
}
229+
}
230+
}
231+
)
232+
233+
populated_instruction = await instructions_utils.inject_session_state(
234+
instruction_template, invocation_context
235+
)
236+
assert populated_instruction == "User name: John Doe, Age: 30"
237+
238+
239+
@pytest.mark.asyncio
240+
async def test_inject_session_state_with_nested_bracket_notation():
241+
instruction_template = "User name: {user['profile']['name']}, City: {user['address']['city']}"
242+
invocation_context = await _create_test_readonly_context(
243+
state={
244+
"user": {
245+
"profile": {"name": "Jane Smith"},
246+
"address": {"city": "New York"}
247+
}
248+
}
249+
)
250+
251+
populated_instruction = await instructions_utils.inject_session_state(
252+
instruction_template, invocation_context
253+
)
254+
assert populated_instruction == "User name: Jane Smith, City: New York"
255+
256+
257+
@pytest.mark.asyncio
258+
async def test_inject_session_state_with_mixed_notation():
259+
instruction_template = "Data: {config.database['host']}, Port: {config['database'].port}"
260+
invocation_context = await _create_test_readonly_context(
261+
state={
262+
"config": {
263+
"database": {
264+
"host": "localhost",
265+
"port": 5432
266+
}
267+
}
268+
}
269+
)
270+
271+
populated_instruction = await instructions_utils.inject_session_state(
272+
instruction_template, invocation_context
273+
)
274+
assert populated_instruction == "Data: localhost, Port: 5432"
275+
276+
277+
@pytest.mark.asyncio
278+
async def test_inject_session_state_with_nested_optional():
279+
instruction_template = "Optional nested: {user.profile.nickname?}"
280+
invocation_context = await _create_test_readonly_context(
281+
state={"user": {"profile": {"name": "John"}}}
282+
)
283+
284+
populated_instruction = await instructions_utils.inject_session_state(
285+
instruction_template, invocation_context
286+
)
287+
assert populated_instruction == "Optional nested: "
288+
289+
290+
@pytest.mark.asyncio
291+
async def test_inject_session_state_with_nested_missing_raises_error():
292+
instruction_template = "Missing: {user.profile.missing}"
293+
invocation_context = await _create_test_readonly_context(
294+
state={"user": {"profile": {"name": "John"}}}
295+
)
296+
297+
with pytest.raises(KeyError, match="Context variable not found: `user.profile.missing`"):
298+
await instructions_utils.inject_session_state(
299+
instruction_template, invocation_context
300+
)
301+
302+
303+
@pytest.mark.asyncio
304+
async def test_inject_session_state_with_invalid_nested_path():
305+
instruction_template = "Invalid: {user.profile[unclosed}"
306+
invocation_context = await _create_test_readonly_context(
307+
state={"user": {"profile": {"name": "John"}}}
308+
)
309+
310+
populated_instruction = await instructions_utils.inject_session_state(
311+
instruction_template, invocation_context
312+
)
313+
assert populated_instruction == "Invalid: {user.profile[unclosed}"
314+
315+
316+
@pytest.mark.asyncio
317+
async def test_inject_session_state_backward_compatibility():
318+
instruction_template = "Hello {user_name}, status: {app_status}"
319+
invocation_context = await _create_test_readonly_context(
320+
state={"user_name": "Alice", "app_status": "active"}
321+
)
322+
323+
populated_instruction = await instructions_utils.inject_session_state(
324+
instruction_template, invocation_context
325+
)
326+
assert populated_instruction == "Hello Alice, status: active"

0 commit comments

Comments
 (0)