Skip to content

Commit a7ce539

Browse files
zastrowmratish
authored andcommitted
feat: Use TypedEvent inheritance for callback behavior (strands-agents#755)
Move away from "callback" nested properties in the dict and explicitly passing invocation_state migrating to behaviors on the TypedEvent: - TypedEvent.is_callback_event for determining if an event should be yielded and or invoked in the callback - TypedEvent.prepare for taking in invocation_state Customers still only get dictionaries, as we decided that this will remain an implementation detail for the time being, but this makes the events typed all the way up until *just* before we yield events back to the caller --------- Co-authored-by: Mackenzie Zastrow <[email protected]>
1 parent 89c7431 commit a7ce539

File tree

12 files changed

+288
-215
lines changed

12 files changed

+288
-215
lines changed

src/strands/agent/agent.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from ..tools.executors._executor import ToolExecutor
5151
from ..tools.registry import ToolRegistry
5252
from ..tools.watcher import ToolWatcher
53-
from ..types._events import InitEventLoopEvent
53+
from ..types._events import AgentResultEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
5454
from ..types.agent import AgentInput
5555
from ..types.content import ContentBlock, Message, Messages
5656
from ..types.exceptions import ContextWindowOverflowException
@@ -576,23 +576,24 @@ async def stream_async(
576576
events = self._run_loop(messages, invocation_state=kwargs)
577577

578578
async for event in events:
579-
if "callback" in event:
580-
callback_handler(**event["callback"])
581-
yield event["callback"]
579+
event.prepare(invocation_state=kwargs)
580+
581+
if event.is_callback_event:
582+
as_dict = event.as_dict()
583+
callback_handler(**as_dict)
584+
yield as_dict
582585

583586
result = AgentResult(*event["stop"])
584587
callback_handler(result=result)
585-
yield {"result": result}
588+
yield AgentResultEvent(result=result).as_dict()
586589

587590
self._end_agent_trace_span(response=result)
588591

589592
except Exception as e:
590593
self._end_agent_trace_span(error=e)
591594
raise
592595

593-
async def _run_loop(
594-
self, messages: Messages, invocation_state: dict[str, Any]
595-
) -> AsyncGenerator[dict[str, Any], None]:
596+
async def _run_loop(self, messages: Messages, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
596597
"""Execute the agent's event loop with the given message and parameters.
597598
598599
Args:
@@ -605,7 +606,7 @@ async def _run_loop(
605606
self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self))
606607

607608
try:
608-
yield InitEventLoopEvent(invocation_state)
609+
yield InitEventLoopEvent()
609610

610611
for message in messages:
611612
self._append_message(message)
@@ -616,13 +617,13 @@ async def _run_loop(
616617
# Signal from the model provider that the message sent by the user should be redacted,
617618
# likely due to a guardrail.
618619
if (
619-
event.get("callback")
620-
and event["callback"].get("event")
621-
and event["callback"]["event"].get("redactContent")
622-
and event["callback"]["event"]["redactContent"].get("redactUserContentMessage")
620+
isinstance(event, ModelStreamChunkEvent)
621+
and event.chunk
622+
and event.chunk.get("redactContent")
623+
and event.chunk["redactContent"].get("redactUserContentMessage")
623624
):
624625
self.messages[-1]["content"] = [
625-
{"text": event["callback"]["event"]["redactContent"]["redactUserContentMessage"]}
626+
{"text": str(event.chunk["redactContent"]["redactUserContentMessage"])}
626627
]
627628
if self._session_manager:
628629
self._session_manager.redact_latest_message(self.messages[-1], self)
@@ -632,7 +633,7 @@ async def _run_loop(
632633
self.conversation_manager.apply_management(self)
633634
self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self))
634635

635-
async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
636+
async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
636637
"""Execute the event loop cycle with retry logic for context window limits.
637638
638639
This internal method handles the execution of the event loop cycle and implements

src/strands/event_loop/event_loop.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
EventLoopThrottleEvent,
3131
ForceStopEvent,
3232
ModelMessageEvent,
33+
ModelStopReason,
3334
StartEvent,
3435
StartEventLoopEvent,
3536
ToolResultMessageEvent,
37+
TypedEvent,
3638
)
3739
from ..types.content import Message
3840
from ..types.exceptions import (
@@ -56,7 +58,7 @@
5658
MAX_DELAY = 240 # 4 minutes
5759

5860

59-
async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
61+
async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
6062
"""Execute a single cycle of the event loop.
6163
6264
This core function processes a single conversation turn, handling model inference, tool execution, and error
@@ -139,17 +141,9 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
139141
)
140142

141143
try:
142-
# TODO: To maintain backwards compatibility, we need to combine the stream event with invocation_state
143-
# before yielding to the callback handler. This will be revisited when migrating to strongly
144-
# typed events.
145144
async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
146-
if "callback" in event:
147-
yield {
148-
"callback": {
149-
**event["callback"],
150-
**(invocation_state if "delta" in event["callback"] else {}),
151-
}
152-
}
145+
if not isinstance(event, ModelStopReason):
146+
yield event
153147

154148
stop_reason, message, usage, metrics = event["stop"]
155149
invocation_state.setdefault("request_state", {})
@@ -198,7 +192,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
198192
time.sleep(current_delay)
199193
current_delay = min(current_delay * 2, MAX_DELAY)
200194

201-
yield EventLoopThrottleEvent(delay=current_delay, invocation_state=invocation_state)
195+
yield EventLoopThrottleEvent(delay=current_delay)
202196
else:
203197
raise e
204198

@@ -280,7 +274,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
280274
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
281275

282276

283-
async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[dict[str, Any], None]:
277+
async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
284278
"""Make a recursive call to event_loop_cycle with the current state.
285279
286280
This function is used when the event loop needs to continue processing after tool execution.
@@ -321,7 +315,7 @@ async def _handle_tool_execution(
321315
cycle_span: Any,
322316
cycle_start_time: float,
323317
invocation_state: dict[str, Any],
324-
) -> AsyncGenerator[dict[str, Any], None]:
318+
) -> AsyncGenerator[TypedEvent, None]:
325319
"""Handles the execution of tools requested by the model during an event loop cycle.
326320
327321
Args:

src/strands/tools/executors/_executor.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@
77
import abc
88
import logging
99
import time
10-
from typing import TYPE_CHECKING, Any, cast
10+
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
1111

1212
from opentelemetry import trace as trace_api
1313

1414
from ...experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent
1515
from ...telemetry.metrics import Trace
1616
from ...telemetry.tracer import get_tracer
17+
from ...types._events import ToolResultEvent, ToolStreamEvent, TypedEvent
1718
from ...types.content import Message
18-
from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolGenerator, ToolResult, ToolUse
19+
from ...types.tools import ToolChoice, ToolChoiceAuto, ToolConfig, ToolResult, ToolUse
1920

2021
if TYPE_CHECKING: # pragma: no cover
2122
from ...agent import Agent
@@ -33,7 +34,7 @@ async def _stream(
3334
tool_results: list[ToolResult],
3435
invocation_state: dict[str, Any],
3536
**kwargs: Any,
36-
) -> ToolGenerator:
37+
) -> AsyncGenerator[TypedEvent, None]:
3738
"""Stream tool events.
3839
3940
This method adds additional logic to the stream invocation including:
@@ -113,12 +114,12 @@ async def _stream(
113114
result=result,
114115
)
115116
)
116-
yield after_event.result
117+
yield ToolResultEvent(after_event.result)
117118
tool_results.append(after_event.result)
118119
return
119120

120121
async for event in selected_tool.stream(tool_use, invocation_state, **kwargs):
121-
yield event
122+
yield ToolStreamEvent(tool_use, event)
122123

123124
result = cast(ToolResult, event)
124125

@@ -131,7 +132,8 @@ async def _stream(
131132
result=result,
132133
)
133134
)
134-
yield after_event.result
135+
136+
yield ToolResultEvent(after_event.result)
135137
tool_results.append(after_event.result)
136138

137139
except Exception as e:
@@ -151,7 +153,7 @@ async def _stream(
151153
exception=e,
152154
)
153155
)
154-
yield after_event.result
156+
yield ToolResultEvent(after_event.result)
155157
tool_results.append(after_event.result)
156158

157159
@staticmethod
@@ -163,7 +165,7 @@ async def _stream_with_trace(
163165
cycle_span: Any,
164166
invocation_state: dict[str, Any],
165167
**kwargs: Any,
166-
) -> ToolGenerator:
168+
) -> AsyncGenerator[TypedEvent, None]:
167169
"""Execute tool with tracing and metrics collection.
168170
169171
Args:
@@ -190,7 +192,8 @@ async def _stream_with_trace(
190192
async for event in ToolExecutor._stream(agent, tool_use, tool_results, invocation_state, **kwargs):
191193
yield event
192194

193-
result = cast(ToolResult, event)
195+
result_event = cast(ToolResultEvent, event)
196+
result = result_event.tool_result
194197

195198
tool_success = result.get("status") == "success"
196199
tool_duration = time.time() - tool_start_time
@@ -210,7 +213,7 @@ def _execute(
210213
cycle_trace: Trace,
211214
cycle_span: Any,
212215
invocation_state: dict[str, Any],
213-
) -> ToolGenerator:
216+
) -> AsyncGenerator[TypedEvent, None]:
214217
"""Execute the given tools according to this executor's strategy.
215218
216219
Args:

src/strands/tools/executors/concurrent.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""Concurrent tool executor implementation."""
22

33
import asyncio
4-
from typing import TYPE_CHECKING, Any
4+
from typing import TYPE_CHECKING, Any, AsyncGenerator
55

66
from typing_extensions import override
77

88
from ...telemetry.metrics import Trace
9-
from ...types.tools import ToolGenerator, ToolResult, ToolUse
9+
from ...types._events import TypedEvent
10+
from ...types.tools import ToolResult, ToolUse
1011
from ._executor import ToolExecutor
1112

1213
if TYPE_CHECKING: # pragma: no cover
@@ -25,7 +26,7 @@ async def _execute(
2526
cycle_trace: Trace,
2627
cycle_span: Any,
2728
invocation_state: dict[str, Any],
28-
) -> ToolGenerator:
29+
) -> AsyncGenerator[TypedEvent, None]:
2930
"""Execute tools concurrently.
3031
3132
Args:

src/strands/tools/executors/sequential.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""Sequential tool executor implementation."""
22

3-
from typing import TYPE_CHECKING, Any
3+
from typing import TYPE_CHECKING, Any, AsyncGenerator
44

55
from typing_extensions import override
66

77
from ...telemetry.metrics import Trace
8-
from ...types.tools import ToolGenerator, ToolResult, ToolUse
8+
from ...types._events import TypedEvent
9+
from ...types.tools import ToolResult, ToolUse
910
from ._executor import ToolExecutor
1011

1112
if TYPE_CHECKING: # pragma: no cover
@@ -24,7 +25,7 @@ async def _execute(
2425
cycle_trace: Trace,
2526
cycle_span: Any,
2627
invocation_state: dict[str, Any],
27-
) -> ToolGenerator:
28+
) -> AsyncGenerator[TypedEvent, None]:
2829
"""Execute tools sequentially.
2930
3031
Args:

0 commit comments

Comments
 (0)