Skip to content

Commit 089ccb3

Browse files
authored
feat: Store conversation manager in session (strands-agents#441)
1 parent bdff8d5 commit 089ccb3

18 files changed

+307
-79
lines changed

src/strands/agent/agent.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,11 @@ async def _execute_event_loop_cycle(self, invocation_state: dict[str, Any]) -> A
588588
except ContextWindowOverflowException as e:
589589
# Try reducing the context size and retrying
590590
self.conversation_manager.reduce_context(self, e=e)
591+
592+
# Sync agent after reduce_context to keep conversation_manager_state up to date in the session
593+
if self._session_manager:
594+
self._session_manager.sync_agent(self)
595+
591596
events = self._execute_event_loop_cycle(invocation_state)
592597
async for event in events:
593598
yield event

src/strands/agent/conversation_manager/conversation_manager.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from abc import ABC, abstractmethod
44
from typing import TYPE_CHECKING, Any, Optional
55

6+
from ...types.content import Message
7+
68
if TYPE_CHECKING:
79
from ...agent.agent import Agent
810

@@ -18,8 +20,37 @@ class ConversationManager(ABC):
1820
- Maintain relevant conversation state
1921
"""
2022

23+
def __init__(self) -> None:
24+
"""Initialize the ConversationManager.
25+
26+
Attributes:
27+
removed_message_count: The messages that have been removed from the agents messages array.
28+
These represent messages provided by the user or LLM that have been removed, not messages
29+
included by the conversation manager through something like summarization.
30+
"""
31+
self.removed_message_count = 0
32+
33+
def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]:
34+
"""Restore the Conversation Manager's state from a session.
35+
36+
Args:
37+
state: Previous state of the conversation manager
38+
Returns:
39+
Optional list of messages to prepend to the agents messages. By defualt returns None.
40+
"""
41+
if state.get("__name__") != self.__class__.__name__:
42+
raise ValueError("Invalid conversation manager state.")
43+
self.removed_message_count = state["removed_message_count"]
44+
return None
45+
46+
def get_state(self) -> dict[str, Any]:
47+
"""Get the current state of a Conversation Manager as a Json serializable dictionary."""
48+
return {
49+
"__name__": self.__class__.__name__,
50+
"removed_message_count": self.removed_message_count,
51+
}
52+
2153
@abstractmethod
22-
# pragma: no cover
2354
def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
2455
"""Applies management strategy to the provided agent.
2556
@@ -35,7 +66,6 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
3566
pass
3667

3768
@abstractmethod
38-
# pragma: no cover
3969
def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs: Any) -> None:
4070
"""Called when the model's context window is exceeded.
4171

src/strands/agent/conversation_manager/sliding_window_conversation_manager.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,37 +6,13 @@
66
if TYPE_CHECKING:
77
from ...agent.agent import Agent
88

9-
from ...types.content import Message, Messages
9+
from ...types.content import Messages
1010
from ...types.exceptions import ContextWindowOverflowException
1111
from .conversation_manager import ConversationManager
1212

1313
logger = logging.getLogger(__name__)
1414

1515

16-
def is_user_message(message: Message) -> bool:
17-
"""Check if a message is from a user.
18-
19-
Args:
20-
message: The message object to check.
21-
22-
Returns:
23-
True if the message has the user role, False otherwise.
24-
"""
25-
return message["role"] == "user"
26-
27-
28-
def is_assistant_message(message: Message) -> bool:
29-
"""Check if a message is from an assistant.
30-
31-
Args:
32-
message: The message object to check.
33-
34-
Returns:
35-
True if the message has the assistant role, False otherwise.
36-
"""
37-
return message["role"] == "assistant"
38-
39-
4016
class SlidingWindowConversationManager(ConversationManager):
4117
"""Implements a sliding window strategy for managing conversation history.
4218
@@ -52,6 +28,7 @@ def __init__(self, window_size: int = 40, should_truncate_results: bool = True):
5228
Defaults to 40 messages.
5329
should_truncate_results: Truncate tool results when a message is too large for the model's context window
5430
"""
31+
super().__init__()
5532
self.window_size = window_size
5633
self.should_truncate_results = should_truncate_results
5734

@@ -129,6 +106,9 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs
129106
# If we didn't find a valid trim_index, then we throw
130107
raise ContextWindowOverflowException("Unable to trim conversation context!") from e
131108

109+
# trim_index represents the number of messages being removed from the agents messages array
110+
self.removed_message_count += trim_index
111+
132112
# Overwrite message history
133113
messages[:] = messages[trim_index:]
134114

src/strands/agent/conversation_manager/summarizing_conversation_manager.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import logging
44
from typing import TYPE_CHECKING, Any, List, Optional
55

6+
from typing_extensions import override
7+
68
from ...types.content import Message
79
from ...types.exceptions import ContextWindowOverflowException
810
from .conversation_manager import ConversationManager
@@ -67,6 +69,7 @@ def __init__(
6769
summarization_system_prompt: Optional system prompt override for summarization.
6870
If None, uses the default summarization prompt.
6971
"""
72+
super().__init__()
7073
if summarization_agent is not None and summarization_system_prompt is not None:
7174
raise ValueError(
7275
"Cannot provide both summarization_agent and summarization_system_prompt. "
@@ -77,6 +80,25 @@ def __init__(
7780
self.preserve_recent_messages = preserve_recent_messages
7881
self.summarization_agent = summarization_agent
7982
self.summarization_system_prompt = summarization_system_prompt
83+
self._summary_message: Optional[Message] = None
84+
85+
@override
86+
def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]:
87+
"""Restores the Summarizing Conversation manager from its previous state in a session.
88+
89+
Args:
90+
state: The previous state of the Summarizing Conversation Manager.
91+
92+
Returns:
93+
Optionally returns the previous conversation summary if it exists.
94+
"""
95+
super().restore_from_session(state)
96+
self._summary_message = state.get("summary_message")
97+
return [self._summary_message] if self._summary_message else None
98+
99+
def get_state(self) -> dict[str, Any]:
100+
"""Returns a dictionary representation of the state for the Summarizing Conversation Manager."""
101+
return {"summary_message": self._summary_message, **super().get_state()}
80102

81103
def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
82104
"""Apply management strategy to conversation history.
@@ -128,11 +150,17 @@ def reduce_context(self, agent: "Agent", e: Optional[Exception] = None, **kwargs
128150
messages_to_summarize = agent.messages[:messages_to_summarize_count]
129151
remaining_messages = agent.messages[messages_to_summarize_count:]
130152

153+
# Keep track of the number of messages that have been summarized thus far.
154+
self.removed_message_count += len(messages_to_summarize)
155+
# If there is a summary message, don't count it in the removed_message_count.
156+
if self._summary_message:
157+
self.removed_message_count -= 1
158+
131159
# Generate summary
132-
summary_message = self._generate_summary(messages_to_summarize, agent)
160+
self._summary_message = self._generate_summary(messages_to_summarize, agent)
133161

134162
# Replace the summarized messages with the summary
135-
agent.messages[:] = [summary_message] + remaining_messages
163+
agent.messages[:] = [self._summary_message] + remaining_messages
136164

137165
except Exception as summarization_error:
138166
logger.error("Summarization failed: %s", summarization_error)

src/strands/session/file_session_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ class FileSessionManager(RepositorySessionManager, SessionRepository):
3030
└── agent_<agent_id>/
3131
├── agent.json # Agent metadata
3232
└── messages/
33-
├── message_<created_timestamp>_<id1>.json
34-
└── message_<created_timestamp>_<id2>.json
33+
├── message_<id1>.json
34+
└── message_<id2>.json
3535
3636
"""
3737

src/strands/session/repository_session_manager.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,25 @@ def initialize(self, agent: Agent) -> None:
126126
agent.agent_id,
127127
self.session_id,
128128
)
129-
session_messages = self.session_repository.list_messages(self.session_id, agent.agent_id)
129+
agent.state = AgentState(session_agent.state)
130+
131+
# Restore the conversation manager to its previous state, and get the optional prepend messages
132+
prepend_messsages = agent.conversation_manager.restore_from_session(
133+
session_agent.conversation_manager_state
134+
)
135+
136+
if prepend_messsages is None:
137+
prepend_messsages = []
138+
139+
# List the messages currently in the session, using an offset of the messages previously removed
140+
# by the converstaion manager.
141+
session_messages = self.session_repository.list_messages(
142+
session_id=self.session_id,
143+
agent_id=agent.agent_id,
144+
offset=agent.conversation_manager.removed_message_count,
145+
)
130146
if len(session_messages) > 0:
131147
self._latest_agent_message[agent.agent_id] = session_messages[-1]
132-
agent.messages = [session_message.to_message() for session_message in session_messages]
133148

134-
agent.state = AgentState(session_agent.state)
149+
# Resore the agents messages array including the optional prepend messages
150+
agent.messages = prepend_messsages + [session_message.to_message() for session_message in session_messages]

src/strands/session/s3_session_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ class S3SessionManager(RepositorySessionManager, SessionRepository):
3131
└── agent_<agent_id>/
3232
├── agent.json # Agent metadata
3333
└── messages/
34-
├── message_<created_timestamp>_<id1>.json
35-
└── message_<created_timestamp>_<id2>.json
34+
├── message_<id1>.json
35+
└── message_<id2>.json
3636
3737
"""
3838

@@ -77,7 +77,7 @@ def __init__(
7777

7878
def _get_session_path(self, session_id: str) -> str:
7979
"""Get session S3 prefix."""
80-
return f"{self.prefix}{SESSION_PREFIX}{session_id}/"
80+
return f"{self.prefix}/{SESSION_PREFIX}{session_id}/"
8181

8282
def _get_agent_path(self, session_id: str, agent_id: str) -> str:
8383
"""Get agent S3 prefix."""

src/strands/session/session_manager.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from abc import ABC, abstractmethod
44
from typing import TYPE_CHECKING, Any
55

6-
from ..hooks.events import AgentInitializedEvent, MessageAddedEvent
6+
from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent
77
from ..hooks.registry import HookProvider, HookRegistry
88
from ..types.content import Message
99

@@ -22,10 +22,18 @@ class SessionManager(HookProvider, ABC):
2222

2323
def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
2424
"""Register hooks for persisting the agent to the session."""
25+
# After the normal Agent initialization behavior, call the session initialize function to restore the agent
2526
registry.add_callback(AgentInitializedEvent, lambda event: self.initialize(event.agent))
27+
28+
# For each message appended to the Agents messages, store that message in the session
2629
registry.add_callback(MessageAddedEvent, lambda event: self.append_message(event.message, event.agent))
30+
31+
# Sync the agent into the session for each message in case the agent state was updated
2732
registry.add_callback(MessageAddedEvent, lambda event: self.sync_agent(event.agent))
2833

34+
# After an agent was invoked, sync it with the session to capture any conversation manager state updates
35+
registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent))
36+
2937
@abstractmethod
3038
def redact_latest_message(self, redact_message: Message, agent: "Agent") -> None:
3139
"""Redact the message most recently appended to the agent in the session.

src/strands/types/session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class SessionAgent:
106106

107107
agent_id: str
108108
state: Dict[str, Any]
109+
conversation_manager_state: Dict[str, Any]
109110
created_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
110111
updated_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
111112

@@ -116,6 +117,7 @@ def from_agent(cls, agent: Agent) -> "SessionAgent":
116117
raise ValueError("agent_id needs to be defined.")
117118
return cls(
118119
agent_id=agent.agent_id,
120+
conversation_manager_state=agent.conversation_manager.get_state(),
119121
state=agent.state.get(),
120122
)
121123

tests/fixtures/mock_session_repository.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from strands.session.session_repository import SessionRepository
22
from strands.types.exceptions import SessionException
3+
from strands.types.session import SessionAgent, SessionMessage
34

45

56
class MockedSessionRepository(SessionRepository):
@@ -11,21 +12,20 @@ def __init__(self):
1112
self.agents = {}
1213
self.messages = {}
1314

14-
def create_session(self, session):
15+
def create_session(self, session) -> None:
1516
"""Create a session."""
1617
session_id = session.session_id
1718
if session_id in self.sessions:
1819
raise SessionException(f"Session {session_id} already exists")
1920
self.sessions[session_id] = session
2021
self.agents[session_id] = {}
2122
self.messages[session_id] = {}
22-
return session
2323

24-
def read_session(self, session_id):
24+
def read_session(self, session_id) -> SessionAgent:
2525
"""Read a session."""
2626
return self.sessions.get(session_id)
2727

28-
def create_agent(self, session_id, session_agent):
28+
def create_agent(self, session_id, session_agent) -> None:
2929
"""Create an agent."""
3030
agent_id = session_agent.agent_id
3131
if session_id not in self.sessions:
@@ -36,13 +36,13 @@ def create_agent(self, session_id, session_agent):
3636
self.messages.setdefault(session_id, {}).setdefault(agent_id, {})
3737
return session_agent
3838

39-
def read_agent(self, session_id, agent_id):
39+
def read_agent(self, session_id, agent_id) -> SessionAgent:
4040
"""Read an agent."""
4141
if session_id not in self.sessions:
4242
return None
4343
return self.agents.get(session_id, {}).get(agent_id)
4444

45-
def update_agent(self, session_id, session_agent):
45+
def update_agent(self, session_id, session_agent) -> None:
4646
"""Update an agent."""
4747
agent_id = session_agent.agent_id
4848
if session_id not in self.sessions:
@@ -51,7 +51,7 @@ def update_agent(self, session_id, session_agent):
5151
raise SessionException(f"Agent {agent_id} does not exist in session {session_id}")
5252
self.agents[session_id][agent_id] = session_agent
5353

54-
def create_message(self, session_id, agent_id, session_message):
54+
def create_message(self, session_id, agent_id, session_message) -> None:
5555
"""Create a message."""
5656
message_id = session_message.message_id
5757
if session_id not in self.sessions:
@@ -62,15 +62,15 @@ def create_message(self, session_id, agent_id, session_message):
6262
raise SessionException(f"Message {message_id} already exists in agent {agent_id} in session {session_id}")
6363
self.messages.setdefault(session_id, {}).setdefault(agent_id, {})[message_id] = session_message
6464

65-
def read_message(self, session_id, agent_id, message_id):
65+
def read_message(self, session_id, agent_id, message_id) -> SessionMessage:
6666
"""Read a message."""
6767
if session_id not in self.sessions:
6868
return None
6969
if agent_id not in self.agents.get(session_id, {}):
7070
return None
7171
return self.messages.get(session_id, {}).get(agent_id, {}).get(message_id)
7272

73-
def update_message(self, session_id, agent_id, session_message):
73+
def update_message(self, session_id, agent_id, session_message) -> None:
7474
"""Update a message."""
7575

7676
message_id = session_message.message_id
@@ -82,7 +82,7 @@ def update_message(self, session_id, agent_id, session_message):
8282
raise SessionException(f"Message {message_id} does not exist in session {session_id}")
8383
self.messages[session_id][agent_id][message_id] = session_message
8484

85-
def list_messages(self, session_id, agent_id, limit=None, offset=0):
85+
def list_messages(self, session_id, agent_id, limit=None, offset=0) -> list[SessionMessage]:
8686
"""List messages."""
8787
if session_id not in self.sessions:
8888
return []

0 commit comments

Comments
 (0)