Skip to content

Commit f286fc7

Browse files
Unshuredbschmigelski
authored andcommitted
fix: Add AgentInput TypeAlias (strands-agents#738)
1 parent fa2cd36 commit f286fc7

File tree

7 files changed

+28
-18
lines changed

7 files changed

+28
-18
lines changed

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as
4949

5050
Alternatively, install development dependencies in a manually created virtual environment:
5151
```bash
52-
pip install -e ".[dev]" && pip install -e ".[litellm]"
52+
pip install -e ".[all]"
5353
```
5454

5555

src/strands/agent/agent.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,19 @@
1414
import logging
1515
import random
1616
from concurrent.futures import ThreadPoolExecutor
17-
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Mapping, Optional, Type, TypeVar, Union, cast
17+
from typing import (
18+
Any,
19+
AsyncGenerator,
20+
AsyncIterator,
21+
Callable,
22+
Mapping,
23+
Optional,
24+
Type,
25+
TypeAlias,
26+
TypeVar,
27+
Union,
28+
cast,
29+
)
1830

1931
from opentelemetry import trace as trace_api
2032
from pydantic import BaseModel
@@ -55,6 +67,8 @@
5567
# TypeVar for generic structured output
5668
T = TypeVar("T", bound=BaseModel)
5769

70+
AgentInput: TypeAlias = str | list[ContentBlock] | Messages | None
71+
5872

5973
# Sentinel class and object to distinguish between explicit None and default parameter value
6074
class _DefaultCallbackHandlerSentinel:
@@ -361,7 +375,7 @@ def tool_names(self) -> list[str]:
361375
all_tools = self.tool_registry.get_all_tools_config()
362376
return list(all_tools.keys())
363377

364-
def __call__(self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any) -> AgentResult:
378+
def __call__(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
365379
"""Process a natural language prompt through the agent's event loop.
366380
367381
This method implements the conversational interface with multiple input patterns:
@@ -394,9 +408,7 @@ def execute() -> AgentResult:
394408
future = executor.submit(execute)
395409
return future.result()
396410

397-
async def invoke_async(
398-
self, prompt: str | list[ContentBlock] | Messages | None = None, **kwargs: Any
399-
) -> AgentResult:
411+
async def invoke_async(self, prompt: AgentInput = None, **kwargs: Any) -> AgentResult:
400412
"""Process a natural language prompt through the agent's event loop.
401413
402414
This method implements the conversational interface with multiple input patterns:
@@ -427,7 +439,7 @@ async def invoke_async(
427439

428440
return cast(AgentResult, event["result"])
429441

430-
def structured_output(self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None) -> T:
442+
def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T:
431443
"""This method allows you to get structured output from the agent.
432444
433445
If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -456,9 +468,7 @@ def execute() -> T:
456468
future = executor.submit(execute)
457469
return future.result()
458470

459-
async def structured_output_async(
460-
self, output_model: Type[T], prompt: str | list[ContentBlock] | Messages | None = None
461-
) -> T:
471+
async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T:
462472
"""This method allows you to get structured output from the agent.
463473
464474
If you pass in a prompt, it will be used temporarily without adding it to the conversation history.
@@ -517,7 +527,7 @@ async def structured_output_async(
517527

518528
async def stream_async(
519529
self,
520-
prompt: str | list[ContentBlock] | Messages | None = None,
530+
prompt: AgentInput = None,
521531
**kwargs: Any,
522532
) -> AsyncIterator[Any]:
523533
"""Process a natural language prompt and yield events as an async iterator.
@@ -657,7 +667,7 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
657667
async for event in events:
658668
yield event
659669

660-
def _convert_prompt_to_messages(self, prompt: str | list[ContentBlock] | Messages | None) -> Messages:
670+
def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages:
661671
messages: Messages | None = None
662672
if prompt is not None:
663673
if isinstance(prompt, str):

src/strands/session/file_session_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) ->
9292
"""
9393
if not isinstance(message_id, int):
9494
raise ValueError(f"message_id=<{message_id}> | message id must be an integer")
95-
95+
9696
agent_path = self._get_agent_path(session_id, agent_id)
9797
return os.path.join(agent_path, "messages", f"{MESSAGE_PREFIX}{message_id}.json")
9898

src/strands/session/s3_session_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,13 @@ def _get_message_path(self, session_id: str, agent_id: str, message_id: int) ->
116116
117117
Returns:
118118
The key for the message
119-
119+
120120
Raises:
121121
ValueError: If message_id is not an integer.
122122
"""
123123
if not isinstance(message_id, int):
124124
raise ValueError(f"message_id=<{message_id}> | message id must be an integer")
125-
125+
126126
agent_path = self._get_agent_path(session_id, agent_id)
127127
return f"{agent_path}messages/{MESSAGE_PREFIX}{message_id}.json"
128128

tests/strands/agent/test_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1830,6 +1830,7 @@ def test_agent_with_list_of_message_and_content_block():
18301830
with pytest.raises(ValueError, match="Input prompt must be of type: `str | list[Contentblock] | Messages | None`."):
18311831
agent([{"role": "user", "content": [{"text": "hello"}]}, {"text", "hello"}])
18321832

1833+
18331834
def test_agent_tool_call_parameter_filtering_integration(mock_randint):
18341835
"""Test that tool calls properly filter parameters in message recording."""
18351836
mock_randint.return_value = 42
@@ -1861,4 +1862,3 @@ def test_tool(action: str) -> str:
18611862
assert '"action": "test_value"' in tool_call_text
18621863
assert '"agent"' not in tool_call_text
18631864
assert '"extra_param"' not in tool_call_text
1864-

tests/strands/session/test_file_session_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def test__get_agent_path_invalid_agent_id(agent_id, file_manager):
396396
"message_id",
397397
[
398398
"../../../secret",
399-
"../../attack",
399+
"../../attack",
400400
"../escape",
401401
"path/traversal",
402402
"not_an_int",

tests/strands/session/test_s3_session_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def test__get_agent_path_invalid_agent_id(agent_id, s3_manager):
362362
"message_id",
363363
[
364364
"../../../secret",
365-
"../../attack",
365+
"../../attack",
366366
"../escape",
367367
"path/traversal",
368368
"not_an_int",

0 commit comments

Comments
 (0)