@@ -29,10 +29,10 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen
29
29
agent .return_value = mock_result
30
30
agent .__call__ = Mock (return_value = mock_result )
31
31
32
- async def mock_stream_async (* args , ** kwargs ):
33
- yield { "result" : mock_result }
32
+ async def mock_invoke_async (* args , ** kwargs ):
33
+ return mock_result
34
34
35
- agent .stream_async = MagicMock (side_effect = mock_stream_async )
35
+ agent .invoke_async = MagicMock (side_effect = mock_invoke_async )
36
36
37
37
return agent
38
38
@@ -194,14 +194,14 @@ async def test_graph_execution(mock_strands_tracer, mock_use_span, mock_graph, m
194
194
assert result .execution_order [0 ].node_id == "start_agent"
195
195
196
196
# Verify agent calls
197
- mock_agents ["start_agent" ].stream_async .assert_called_once ()
197
+ mock_agents ["start_agent" ].invoke_async .assert_called_once ()
198
198
mock_agents ["multi_agent" ].invoke_async .assert_called_once ()
199
- mock_agents ["conditional_agent" ].stream_async .assert_called_once ()
200
- mock_agents ["final_agent" ].stream_async .assert_called_once ()
201
- mock_agents ["no_metrics_agent" ].stream_async .assert_called_once ()
202
- mock_agents ["partial_metrics_agent" ].stream_async .assert_called_once ()
203
- string_content_agent .stream_async .assert_called_once ()
204
- mock_agents ["blocked_agent" ].stream_async .assert_not_called ()
199
+ mock_agents ["conditional_agent" ].invoke_async .assert_called_once ()
200
+ mock_agents ["final_agent" ].invoke_async .assert_called_once ()
201
+ mock_agents ["no_metrics_agent" ].invoke_async .assert_called_once ()
202
+ mock_agents ["partial_metrics_agent" ].invoke_async .assert_called_once ()
203
+ string_content_agent .invoke_async .assert_called_once ()
204
+ mock_agents ["blocked_agent" ].invoke_async .assert_not_called ()
205
205
206
206
# Verify metrics aggregation
207
207
assert result .accumulated_usage ["totalTokens" ] > 0
@@ -261,12 +261,10 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span)
261
261
failing_agent .id = "fail_node"
262
262
failing_agent .__call__ = Mock (side_effect = Exception ("Simulated failure" ))
263
263
264
- # Create a proper failing async generator for stream_async
265
- async def mock_stream_failure (* args , ** kwargs ):
264
+ async def mock_invoke_failure (* args , ** kwargs ):
266
265
raise Exception ("Simulated failure" )
267
- yield # This will never be reached
268
266
269
- failing_agent .stream_async = mock_stream_failure
267
+ failing_agent .invoke_async = mock_invoke_failure
270
268
271
269
success_agent = create_mock_agent ("success_agent" , "Success" )
272
270
@@ -301,7 +299,7 @@ async def test_graph_edge_cases(mock_strands_tracer, mock_use_span):
301
299
result = await graph .invoke_async ([{"text" : "Original task" }])
302
300
303
301
# Verify entry node was called with original task
304
- entry_agent .stream_async .assert_called_once_with ([{"text" : "Original task" }])
302
+ entry_agent .invoke_async .assert_called_once_with ([{"text" : "Original task" }])
305
303
assert result .status == Status .COMPLETED
306
304
mock_strands_tracer .start_multiagent_span .assert_called ()
307
305
mock_use_span .assert_called_once ()
@@ -482,8 +480,8 @@ def test_graph_synchronous_execution(mock_strands_tracer, mock_use_span, mock_ag
482
480
assert result .execution_order [1 ].node_id == "final_agent"
483
481
484
482
# Verify agent calls
485
- mock_agents ["start_agent" ].stream_async .assert_called_once ()
486
- mock_agents ["final_agent" ].stream_async .assert_called_once ()
483
+ mock_agents ["start_agent" ].invoke_async .assert_called_once ()
484
+ mock_agents ["final_agent" ].invoke_async .assert_called_once ()
487
485
488
486
# Verify return type is GraphResult
489
487
assert isinstance (result , GraphResult )
0 commit comments