Skip to content

Commit 71440c8

Browse files
zastrowmratish
authored andcommitted
Add invocation_state to ToolContext (strands-agents#761)
Addresses issue strands-agents#579, strands-agents#750 --------- Co-authored-by: Mackenzie Zastrow <[email protected]>
1 parent 0db9b3c commit 71440c8

File tree

3 files changed

+25
-6
lines changed

3 files changed

+25
-6
lines changed

src/strands/tools/decorator.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -265,10 +265,13 @@ def inject_special_parameters(
265265
Args:
266266
validated_input: The validated input parameters (modified in place).
267267
tool_use: The tool use request containing tool invocation details.
268-
invocation_state: Context for the tool invocation, including agent state.
268+
invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(),
269+
agent.invoke_async(), etc.).
269270
"""
270271
if self._context_param and self._context_param in self.signature.parameters:
271-
tool_context = ToolContext(tool_use=tool_use, agent=invocation_state["agent"])
272+
tool_context = ToolContext(
273+
tool_use=tool_use, agent=invocation_state["agent"], invocation_state=invocation_state
274+
)
272275
validated_input[self._context_param] = tool_context
273276

274277
# Inject agent if requested (backward compatibility)
@@ -433,7 +436,8 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
433436
434437
Args:
435438
tool_use: The tool use specification from the Agent.
436-
invocation_state: Context for the tool invocation, including agent state.
439+
invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(),
440+
agent.invoke_async(), etc.).
437441
**kwargs: Additional keyword arguments for future extensibility.
438442
439443
Yields:

src/strands/types/tools.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ class ToolContext:
132132
tool_use: The complete ToolUse object containing tool invocation details.
133133
agent: The Agent instance executing this tool, providing access to conversation history,
134134
model configuration, and other agent state.
135+
invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(),
136+
agent.invoke_async(), etc.).
135137
136138
Note:
137139
This class is intended to be instantiated by the SDK. Direct construction by users
@@ -140,6 +142,7 @@ class ToolContext:
140142

141143
tool_use: ToolUse
142144
agent: "Agent"
145+
invocation_state: dict[str, Any]
143146

144147

145148
ToolChoice = Union[
@@ -246,7 +249,8 @@ def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kwargs:
246249
247250
Args:
248251
tool_use: The tool use request containing tool ID and parameters.
249-
invocation_state: Context for the tool invocation, including agent state.
252+
invocation_state: Caller-provided kwargs that were passed to the agent when it was invoked (agent(),
253+
agent.invoke_async(), etc.).
250254
**kwargs: Additional keyword arguments for future extensibility.
251255
252256
Yields:

tests/strands/tools/test_decorator.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Tests for the function-based tool decorator pattern.
33
"""
44

5+
from asyncio import Queue
56
from typing import Any, Dict, Optional, Union
67
from unittest.mock import MagicMock
78

@@ -1039,7 +1040,7 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]
10391040
assert "NoneType: None" in result["content"][0]["text"]
10401041

10411042

1042-
async def _run_context_injection_test(context_tool: AgentTool):
1043+
async def _run_context_injection_test(context_tool: AgentTool, additional_context=None):
10431044
"""Common test logic for context injection tests."""
10441045
tool: AgentTool = context_tool
10451046
generator = tool.stream(
@@ -1052,6 +1053,7 @@ async def _run_context_injection_test(context_tool: AgentTool):
10521053
},
10531054
invocation_state={
10541055
"agent": Agent(name="test_agent"),
1056+
**(additional_context or {}),
10551057
},
10561058
)
10571059
tool_results = [value async for value in generator]
@@ -1074,13 +1076,17 @@ async def _run_context_injection_test(context_tool: AgentTool):
10741076
async def test_tool_context_injection_default():
10751077
"""Test that ToolContext is properly injected with default parameter name (tool_context)."""
10761078

1079+
value_to_pass = Queue() # a complex value that is not serializable
1080+
10771081
@strands.tool(context=True)
10781082
def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict:
10791083
"""Tool that uses ToolContext to access tool_use_id."""
10801084
tool_use_id = tool_context.tool_use["toolUseId"]
10811085
tool_name = tool_context.tool_use["name"]
10821086
agent_from_tool_context = tool_context.agent
10831087

1088+
assert tool_context.invocation_state["test_reference"] is value_to_pass
1089+
10841090
return {
10851091
"status": "success",
10861092
"content": [
@@ -1090,7 +1096,12 @@ def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict:
10901096
],
10911097
}
10921098

1093-
await _run_context_injection_test(context_tool)
1099+
await _run_context_injection_test(
1100+
context_tool,
1101+
{
1102+
"test_reference": value_to_pass,
1103+
},
1104+
)
10941105

10951106

10961107
@pytest.mark.asyncio

0 commit comments

Comments
 (0)