Skip to content

Commit 1edd81a

Browse files
authored
multiagent - use invoke_async instead of stream_async (strands-agents#463)
1 parent 1f64b4b commit 1edd81a

File tree

4 files changed

+30
-47
lines changed

4 files changed

+30
-47
lines changed

src/strands/multiagent/graph.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
import time
2020
from concurrent.futures import ThreadPoolExecutor
2121
from dataclasses import dataclass, field
22-
from typing import Any, Callable, Tuple, cast
22+
from typing import Any, Callable, Tuple
2323

2424
from opentelemetry import trace as trace_api
2525

26-
from ..agent import Agent, AgentResult
26+
from ..agent import Agent
2727
from ..telemetry import get_tracer
2828
from ..types.content import ContentBlock
2929
from ..types.event_loop import Metrics, Usage
@@ -379,15 +379,7 @@ async def _execute_node(self, node: GraphNode) -> None:
379379
)
380380

381381
elif isinstance(node.executor, Agent):
382-
agent_response: AgentResult | None = (
383-
None # Initialize with None to handle case where no result is yielded
384-
)
385-
async for event in node.executor.stream_async(node_input):
386-
if "result" in event:
387-
agent_response = cast(AgentResult, event["result"])
388-
389-
if not agent_response:
390-
raise ValueError(f"Node '{node.node_id}' did not return a result")
382+
agent_response = await node.executor.invoke_async(node_input)
391383

392384
# Extract metrics from agent response
393385
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)

src/strands/multiagent/swarm.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import time
2020
from concurrent.futures import ThreadPoolExecutor
2121
from dataclasses import dataclass, field
22-
from typing import Any, Callable, Tuple, cast
22+
from typing import Any, Callable, Tuple
2323

2424
from ..agent import Agent, AgentResult
2525
from ..agent.state import AgentState
@@ -601,12 +601,7 @@ async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -
601601
# Execute node
602602
result = None
603603
node.reset_executor_state()
604-
async for event in node.executor.stream_async(node_input):
605-
if "result" in event:
606-
result = cast(AgentResult, event["result"])
607-
608-
if not result:
609-
raise ValueError(f"Node '{node_name}' did not return a result")
604+
result = await node.executor.invoke_async(node_input)
610605

611606
execution_time = round((time.time() - start_time) * 1000)
612607

tests/strands/multiagent/test_graph.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen
2929
agent.return_value = mock_result
3030
agent.__call__ = Mock(return_value=mock_result)
3131

32-
async def mock_stream_async(*args, **kwargs):
33-
yield {"result": mock_result}
32+
async def mock_invoke_async(*args, **kwargs):
33+
return mock_result
3434

35-
agent.stream_async = MagicMock(side_effect=mock_stream_async)
35+
agent.invoke_async = MagicMock(side_effect=mock_invoke_async)
3636

3737
return agent
3838

@@ -194,14 +194,14 @@ async def test_graph_execution(mock_strands_tracer, mock_use_span, mock_graph, m
194194
assert result.execution_order[0].node_id == "start_agent"
195195

196196
# Verify agent calls
197-
mock_agents["start_agent"].stream_async.assert_called_once()
197+
mock_agents["start_agent"].invoke_async.assert_called_once()
198198
mock_agents["multi_agent"].invoke_async.assert_called_once()
199-
mock_agents["conditional_agent"].stream_async.assert_called_once()
200-
mock_agents["final_agent"].stream_async.assert_called_once()
201-
mock_agents["no_metrics_agent"].stream_async.assert_called_once()
202-
mock_agents["partial_metrics_agent"].stream_async.assert_called_once()
203-
string_content_agent.stream_async.assert_called_once()
204-
mock_agents["blocked_agent"].stream_async.assert_not_called()
199+
mock_agents["conditional_agent"].invoke_async.assert_called_once()
200+
mock_agents["final_agent"].invoke_async.assert_called_once()
201+
mock_agents["no_metrics_agent"].invoke_async.assert_called_once()
202+
mock_agents["partial_metrics_agent"].invoke_async.assert_called_once()
203+
string_content_agent.invoke_async.assert_called_once()
204+
mock_agents["blocked_agent"].invoke_async.assert_not_called()
205205

206206
# Verify metrics aggregation
207207
assert result.accumulated_usage["totalTokens"] > 0
@@ -261,12 +261,10 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span)
261261
failing_agent.id = "fail_node"
262262
failing_agent.__call__ = Mock(side_effect=Exception("Simulated failure"))
263263

264-
# Create a proper failing async generator for stream_async
265-
async def mock_stream_failure(*args, **kwargs):
264+
async def mock_invoke_failure(*args, **kwargs):
266265
raise Exception("Simulated failure")
267-
yield # This will never be reached
268266

269-
failing_agent.stream_async = mock_stream_failure
267+
failing_agent.invoke_async = mock_invoke_failure
270268

271269
success_agent = create_mock_agent("success_agent", "Success")
272270

@@ -301,7 +299,7 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span):
301299
result = await graph.invoke_async([{"text": "Original task"}])
302300

303301
# Verify entry node was called with original task
304-
entry_agent.stream_async.assert_called_once_with([{"text": "Original task"}])
302+
entry_agent.invoke_async.assert_called_once_with([{"text": "Original task"}])
305303
assert result.status == Status.COMPLETED
306304
mock_strands_tracer.start_multiagent_span.assert_called()
307305
mock_use_span.assert_called_once()
@@ -482,8 +480,8 @@ def test_graph_synchronous_execution(mock_strands_tracer, mock_use_span, mock_ag
482480
assert result.execution_order[1].node_id == "final_agent"
483481

484482
# Verify agent calls
485-
mock_agents["start_agent"].stream_async.assert_called_once()
486-
mock_agents["final_agent"].stream_async.assert_called_once()
483+
mock_agents["start_agent"].invoke_async.assert_called_once()
484+
mock_agents["final_agent"].invoke_async.assert_called_once()
487485

488486
# Verify return type is GraphResult
489487
assert isinstance(result, GraphResult)

tests/strands/multiagent/test_swarm.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,10 @@ def create_mock_result():
5656
agent.return_value = create_mock_result()
5757
agent.__call__ = Mock(side_effect=create_mock_result)
5858

59-
async def mock_stream_async(*args, **kwargs):
60-
result = create_mock_result()
61-
yield {"result": result}
59+
async def mock_invoke_async(*args, **kwargs):
60+
return create_mock_result()
6261

63-
agent.stream_async = MagicMock(side_effect=mock_stream_async)
62+
agent.invoke_async = MagicMock(side_effect=mock_invoke_async)
6463

6564
return agent
6665

@@ -227,7 +226,7 @@ async def test_swarm_execution_async(mock_swarm, mock_agents):
227226
assert len(result.results) == 1
228227

229228
# Verify agent was called
230-
mock_agents["coordinator"].stream_async.assert_called()
229+
mock_agents["coordinator"].invoke_async.assert_called()
231230

232231
# Verify metrics aggregation
233232
assert result.accumulated_usage["totalTokens"] >= 0
@@ -264,7 +263,7 @@ def test_swarm_synchronous_execution(mock_agents):
264263
assert result.execution_time >= 0
265264

266265
# Verify agent was called
267-
mock_agents["coordinator"].stream_async.assert_called()
266+
mock_agents["coordinator"].invoke_async.assert_called()
268267

269268
# Verify return type is SwarmResult
270269
assert isinstance(result, SwarmResult)
@@ -350,11 +349,10 @@ def create_handoff_result():
350349
agent.return_value = create_handoff_result()
351350
agent.__call__ = Mock(side_effect=create_handoff_result)
352351

353-
async def mock_stream_async(*args, **kwargs):
354-
result = create_handoff_result()
355-
yield {"result": result}
352+
async def mock_invoke_async(*args, **kwargs):
353+
return create_handoff_result()
356354

357-
agent.stream_async = MagicMock(side_effect=mock_stream_async)
355+
agent.invoke_async = MagicMock(side_effect=mock_invoke_async)
358356
return agent
359357

360358
# Create agents - first one hands off, second one completes
@@ -381,8 +379,8 @@ async def mock_stream_async(*args, **kwargs):
381379
assert result.node_history[1].node_id == "completion_agent"
382380

383381
# Verify both agents were called
384-
handoff_agent.stream_async.assert_called()
385-
completion_agent.stream_async.assert_called()
382+
handoff_agent.invoke_async.assert_called()
383+
completion_agent.invoke_async.assert_called()
386384

387385
# Test handoff when task is already completed
388386
completed_swarm = Swarm(nodes=[handoff_agent, completion_agent])

0 commit comments

Comments
 (0)