22# Right now we have to do a bunch of patching at the class-level instead of patching at the instance-level.
33# In future, we should do dependency injection instead of patching at the class-level to make it easier to test.
44# This applies to all areas in this file where we have with patch.object()...
5- import asyncio
65import os
7- from typing import Any
86from unittest .mock import AsyncMock , Mock , patch , MagicMock
97
108import pytest
@@ -314,11 +312,11 @@ async def test_tool_calling_workflow_initialization(
314312 "trace_context" : None ,
315313 }
316314
317- workflow_gen = basic_durable_agent .tool_calling_workflow (
315+ workflow_gen = basic_durable_agent .agent_workflow (
318316 mock_workflow_context , message
319317 )
320318 try :
321- await workflow_gen .__next__ ()
319+ await workflow_gen .__anext__ ()
322320 except StopAsyncIteration :
323321 pass
324322
@@ -366,8 +364,17 @@ async def test_call_llm_activity(self, basic_durable_agent):
366364 test_time = datetime .fromisoformat (
367365 "2024-01-01T00:00:00Z" .replace ("Z" , "+00:00" )
368366 )
369- assistant_dict = await basic_durable_agent .call_llm (
370- instance_id , test_time , "Test task"
367+
368+ # Mock the activity context
369+ mock_ctx = Mock ()
370+
371+ assistant_dict = basic_durable_agent .call_llm (
372+ mock_ctx ,
373+ {
374+ "instance_id" : instance_id ,
375+ "time" : test_time .isoformat (),
376+ "task" : "Test task"
377+ }
371378 )
372379 # The dict dumped from AssistantMessage should have our content
373380 assert assistant_dict ["content" ] == "Test response"
@@ -407,33 +414,49 @@ async def test_send_response_back_activity(self, basic_durable_agent):
407414 @pytest .mark .asyncio
408415 async def test_finish_workflow_activity (self , basic_durable_agent ):
409416 """Test finishing workflow activity."""
417+ from datetime import datetime , timezone
418+
410419 instance_id = "test-instance-123"
411420 final_output = "Final response"
412- basic_durable_agent .state ["instances" ] = {
413- instance_id : {
414- "input" : "Test task" ,
415- "source" : "test_source" ,
416- "triggering_workflow_instance_id" : None ,
417- "workflow_instance_id" : instance_id ,
418- "workflow_name" : "AgenticWorkflow" ,
419- "status" : "RUNNING" ,
420- "messages" : [],
421- "tool_history" : [],
422- "end_time" : None ,
423- "trace_context" : None ,
424- }
425- }
426-
427- basic_durable_agent .finalize_workflow (
428- instance_id , final_output , "2024-01-01T00:00:00Z"
421+ # Set up state in the state model using AgentWorkflowEntry
422+ if not hasattr (basic_durable_agent ._state_model , 'instances' ):
423+ basic_durable_agent ._state_model .instances = {}
424+
425+ basic_durable_agent ._state_model .instances [instance_id ] = AgentWorkflowEntry (
426+ input_value = "Test task" ,
427+ source = "test_source" ,
428+ triggering_workflow_instance_id = None ,
429+ workflow_instance_id = instance_id ,
430+ workflow_name = "AgenticWorkflow" ,
431+ status = "RUNNING" ,
432+ messages = [],
433+ tool_history = [],
434+ end_time = None ,
435+ start_time = datetime .now (timezone .utc ),
429436 )
430- instance_data = basic_durable_agent .state ["instances" ][instance_id ]
431- assert instance_data ["output" ] == final_output
432- assert instance_data ["end_time" ] is not None
437+
438+ # Mock the activity context and save_state
439+ mock_ctx = Mock ()
440+
441+ with patch .object (basic_durable_agent , 'save_state' ):
442+ basic_durable_agent .finalize_workflow (
443+ mock_ctx ,
444+ {
445+ "instance_id" : instance_id ,
446+ "final_output" : final_output ,
447+ "end_time" : "2024-01-01T00:00:00Z" ,
448+ "triggering_workflow_instance_id" : None
449+ }
450+ )
451+ entry = basic_durable_agent ._state_model .instances [instance_id ]
452+ assert entry .output == final_output
453+ assert entry .end_time is not None
433454
434455 @pytest .mark .asyncio
435456 async def test_run_tool (self , basic_durable_agent , mock_tool ):
436457 """Test that run_tool atomically executes and persists tool results."""
458+ from datetime import datetime , timezone
459+
437460 instance_id = "test-instance-123"
438461 tool_call = {
439462 "id" : "call_123" ,
@@ -442,56 +465,65 @@ async def test_run_tool(self, basic_durable_agent, mock_tool):
442465
443466 # Mock the tool executor
444467 with patch .object (
445- type (basic_durable_agent ._tool_executor ), "run_tool" , new_callable = AsyncMock
468+ type (basic_durable_agent .tool_executor ), "run_tool" , new_callable = AsyncMock
446469 ) as mock_run_tool :
447470 mock_run_tool .return_value = "tool_result"
448471
449- # Set up instance state
450- basic_durable_agent .state ["instances" ] = {
451- instance_id : {
452- "input" : "Test task" ,
453- "source" : "test_source" ,
454- "triggering_workflow_instance_id" : None ,
455- "workflow_instance_id" : instance_id ,
456- "workflow_name" : "AgenticWorkflow" ,
457- "status" : "RUNNING" ,
458- "messages" : [],
459- "tool_history" : [],
460- "end_time" : None ,
461- "trace_context" : None ,
462- }
463- }
464-
465- from datetime import datetime
472+ # Set up state in the state model using AgentWorkflowEntry
473+ if not hasattr (basic_durable_agent ._state_model , 'instances' ):
474+ basic_durable_agent ._state_model .instances = {}
475+
476+ basic_durable_agent ._state_model .instances [instance_id ] = AgentWorkflowEntry (
477+ input_value = "Test task" ,
478+ source = "test_source" ,
479+ triggering_workflow_instance_id = None ,
480+ workflow_instance_id = instance_id ,
481+ workflow_name = "AgenticWorkflow" ,
482+ status = "RUNNING" ,
483+ messages = [],
484+ tool_history = [],
485+ end_time = None ,
486+ start_time = datetime .now (timezone .utc ),
487+ )
466488
467489 test_time = datetime .fromisoformat (
468490 "2024-01-01T00:00:00Z" .replace ("Z" , "+00:00" )
469491 )
470- result = await basic_durable_agent .run_tool (
471- tool_call , instance_id , test_time
472- )
492+
493+ # Mock the activity context and save_state
494+ mock_ctx = Mock ()
495+
496+ with patch .object (basic_durable_agent , 'save_state' ):
497+ result = await basic_durable_agent .run_tool (
498+ mock_ctx ,
499+ {
500+ "tool_call" : tool_call ,
501+ "instance_id" : instance_id ,
502+ "time" : test_time .isoformat (),
503+ "order" : 1
504+ }
505+ )
473506
474507 # Verify tool was executed and result was returned
475508 assert result ["tool_call_id" ] == "call_123"
476509 assert result ["tool_name" ] == "test_tool"
477510 assert result ["execution_result" ] == "tool_result"
478511
479512 # Verify state was updated atomically
480- instance_data = basic_durable_agent .state ["instances" ][instance_id ]
481- assert len (instance_data ["messages" ]) == 1 # Tool message added
482- assert (
483- len (instance_data ["tool_history" ]) == 1
484- ) # Tool execution record added
513+ entry = basic_durable_agent ._state_model .instances [instance_id ]
514+ assert len (entry .messages ) == 1 # Tool message added
515+ assert len (entry .tool_history ) == 1 # Tool execution record added
485516
486517 # Verify tool execution record in tool_history
487- tool_history_entry = instance_data [ " tool_history" ] [0 ]
488- assert tool_history_entry [ " tool_call_id" ] == "call_123"
489- assert tool_history_entry [ " tool_name" ] == "test_tool"
490- assert tool_history_entry [ " execution_result" ] == "tool_result"
518+ tool_history_entry = entry . tool_history [0 ]
519+ assert tool_history_entry . tool_call_id == "call_123"
520+ assert tool_history_entry . tool_name == "test_tool"
521+ assert tool_history_entry . execution_result == "tool_result"
491522
492523 # Verify agent-level tool_history was also updated
493524 assert len (basic_durable_agent .tool_history ) == 1
494525
526+ @pytest .mark .skip (reason = "get_source_or_default() method removed in refactored architecture" )
495527 def test_get_source_or_default (self , basic_durable_agent ):
496528 """Test get_source_or_default helper method."""
497529 # Test with valid source
@@ -505,29 +537,45 @@ def test_get_source_or_default(self, basic_durable_agent):
505537
506538 def test_record_initial_entry (self , basic_durable_agent ):
507539 """Test record_initial_entry helper method."""
540+ from datetime import datetime , timezone
541+
508542 instance_id = "test-instance-123"
509543 input_data = "Test task"
510544 source = "test_source"
511545 triggering_workflow_instance_id = "parent-instance-123"
512546 start_time = "2024-01-01T00:00:00Z"
513547
514- basic_durable_agent .record_initial_entry (
515- instance_id , input_data , source , triggering_workflow_instance_id , start_time
548+ # First, ensure instance exists with ensure_instance_exists
549+ basic_durable_agent .ensure_instance_exists (
550+ instance_id = instance_id ,
551+ input_value = input_data ,
552+ triggering_workflow_instance_id = None ,
553+ time = datetime .now (timezone .utc )
516554 )
517555
518- # Verify instance was created
519- assert instance_id in basic_durable_agent .state ["instances" ]
520- instance_data = basic_durable_agent .state ["instances" ][instance_id ]
521- assert instance_data ["input" ] == input_data
522- assert instance_data ["source" ] == source
523- assert (
524- instance_data ["triggering_workflow_instance_id" ]
525- == triggering_workflow_instance_id
526- )
527- # start_time is stored as string in dict format
528- assert instance_data ["start_time" ] == "2024-01-01T00:00:00Z"
529- assert instance_data ["workflow_name" ] == "AgenticWorkflow"
530- assert instance_data ["status" ] == "running"
556+ # Mock the activity context
557+ mock_ctx = Mock ()
558+
559+ with patch .object (basic_durable_agent , 'save_state' ):
560+ basic_durable_agent .record_initial_entry (
561+ mock_ctx ,
562+ {
563+ "instance_id" : instance_id ,
564+ "input_value" : input_data ,
565+ "source" : source ,
566+ "triggering_workflow_instance_id" : triggering_workflow_instance_id ,
567+ "start_time" : start_time ,
568+ "trace_context" : None
569+ }
570+ )
571+
572+ # Verify instance was updated
573+ assert instance_id in basic_durable_agent ._state_model .instances
574+ entry = basic_durable_agent ._state_model .instances [instance_id ]
575+ assert entry .input_value == input_data
576+ assert entry .source == source
577+ assert entry .triggering_workflow_instance_id == triggering_workflow_instance_id
578+ assert entry .status .lower () == "running"
531579
532580 def test_ensure_instance_exists (self , basic_durable_agent ):
533581 """Test _ensure_instance_exists helper method."""
0 commit comments