Skip to content

Commit 0d2d01a

Browse files
committed
fix(strands): make integration stateless and fix duplicate tool rendering
Signed-off-by: Tyler Slaton <[email protected]>
1 parent 04d310c commit 0d2d01a

File tree

1 file changed

+36
-35
lines changed
  • integrations/aws-strands/python/src/ag_ui_strands

1 file changed

+36
-35
lines changed

integrations/aws-strands/python/src/ag_ui_strands/agent.py

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import json
99
import uuid
1010
import asyncio
11-
import sys
1211
from typing import AsyncIterator, Any
1312
from strands import Agent as StrandsAgentCore
1413

@@ -93,18 +92,42 @@ async def run(self, input_data: RunAgentInput) -> AsyncIterator[Any]:
9392
has_pending_tool_result = True
9493
logger.debug(f"Has pending tool result detected: tool_call_id={getattr(last_msg, 'tool_call_id', 'unknown')}, thread_id={input_data.thread_id}")
9594

96-
# Get the latest user message
95+
# Convert AG-UI messages to Strands format
96+
# Strands Message format: {"role": str, "content": List[ContentBlock]}
97+
strands_messages = []
98+
for msg in input_data.messages:
99+
# Skip tool messages as Strands handles them internally
100+
if msg.role == "tool":
101+
continue
102+
103+
# Convert message content to Strands format
104+
if isinstance(msg.content, str):
105+
strands_messages.append({
106+
"role": msg.role,
107+
"content": [{"text": msg.content}]
108+
})
109+
elif isinstance(msg.content, list):
110+
# Already in content block format
111+
strands_messages.append({
112+
"role": msg.role,
113+
"content": msg.content
114+
})
115+
116+
# Get the latest user message for state context builder
97117
user_message = "Hello"
98118
if input_data.messages:
99119
for msg in reversed(input_data.messages):
100-
if (msg.role == "user" or msg.role == "tool") and msg.content:
120+
if msg.role == "user" and msg.content:
101121
user_message = msg.content
102122
break
103123

104124
# Optionally allow configuration to adjust the outgoing user message
105125
if self.config.state_context_builder:
106126
try:
107127
user_message = self.config.state_context_builder(input_data, user_message)
128+
# If state_context_builder modifies the message, update the last user message
129+
if strands_messages and strands_messages[-1]["role"] == "user":
130+
strands_messages[-1]["content"] = [{"text": user_message}]
108131
except Exception:
109132
# If the builder fails, keep the original message
110133
pass
@@ -116,10 +139,14 @@ async def run(self, input_data: RunAgentInput) -> AsyncIterator[Any]:
116139
stop_text_streaming = False
117140
halt_event_stream = False
118141

119-
logger.debug(f"Starting agent run: thread_id={input_data.thread_id}, run_id={input_data.run_id}, has_pending_tool_result={has_pending_tool_result}, message_count={len(input_data.messages)}")
142+
logger.debug(f"Starting agent run: thread_id={input_data.thread_id}, run_id={input_data.run_id}, has_pending_tool_result={has_pending_tool_result}, message_count={len(input_data.messages)}, strands_message_count={len(strands_messages)}")
120143

121-
# Stream from Strands agent
122-
agent_stream = self.strands_agent.stream_async(user_message)
144+
# Reset agent messages to prevent accumulation across threads
145+
# The frontend manages conversation history via input_data.messages
146+
self.strands_agent.messages = []
147+
148+
# Stream from Strands agent with full conversation history
149+
agent_stream = self.strands_agent.stream_async(strands_messages if strands_messages else user_message)
123150

124151
try:
125152
async for event in agent_stream:
@@ -196,35 +223,9 @@ async def run(self, input_data: RunAgentInput) -> AsyncIterator[Any]:
196223

197224
logger.debug(f"Processing tool result: tool_name={tool_name}, result_tool_id={result_tool_id}, has_pending_tool_result={has_pending_tool_result}, thread_id={input_data.thread_id}")
198225

199-
if not has_pending_tool_result and not (behavior and behavior.skip_messages_snapshot):
200-
assistant_msg = AssistantMessage(
201-
id=str(uuid.uuid4()),
202-
role="assistant",
203-
tool_calls=[
204-
ToolCall(
205-
id=result_tool_id,
206-
type="function",
207-
function={
208-
"name": tool_name or "default_tool",
209-
"arguments": tool_args or "{}",
210-
},
211-
)
212-
],
213-
)
214-
215-
content_str = json.dumps(result_data) if isinstance(result_data, dict) else str(result_data)
216-
tool_msg = ToolMessage(
217-
id=str(uuid.uuid4()),
218-
role="tool",
219-
content=content_str,
220-
tool_call_id=result_tool_id,
221-
)
222-
223-
all_messages = list(input_data.messages) + [assistant_msg, tool_msg]
224-
yield MessagesSnapshotEvent(
225-
type=EventType.MESSAGES_SNAPSHOT,
226-
messages=all_messages
227-
)
226+
# Skip MessagesSnapshotEvent for tool results to avoid duplicates
227+
# The frontend already has the tool call from TOOL_CALL_* events
228+
# and will construct the messages itself
228229

229230
result_context = ToolResultContext(
230231
input_data=input_data,

0 commit comments

Comments
 (0)