Skip to content

Commit a3e25e2

Browse files
committed
fix: update DurableAgent activity test signatures to (ctx, payload) and use AgentWorkflowEntry
Signed-off-by: Roberto Rodriguez <[email protected]>
1 parent 0347acf commit a3e25e2

File tree

1 file changed

+119
-71
lines changed

1 file changed

+119
-71
lines changed

tests/agents/durableagent/test_durable_agent.py

Lines changed: 119 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22
# Right now we have to do a bunch of patching at the class-level instead of patching at the instance-level.
33
# In future, we should do dependency injection instead of patching at the class-level to make it easier to test.
44
# This applies to all areas in this file where we have with patch.object()...
5-
import asyncio
65
import os
7-
from typing import Any
86
from unittest.mock import AsyncMock, Mock, patch, MagicMock
97

108
import pytest
@@ -314,11 +312,11 @@ async def test_tool_calling_workflow_initialization(
314312
"trace_context": None,
315313
}
316314

317-
workflow_gen = basic_durable_agent.tool_calling_workflow(
315+
workflow_gen = basic_durable_agent.agent_workflow(
318316
mock_workflow_context, message
319317
)
320318
try:
321-
await workflow_gen.__next__()
319+
await workflow_gen.__anext__()
322320
except StopAsyncIteration:
323321
pass
324322

@@ -366,8 +364,17 @@ async def test_call_llm_activity(self, basic_durable_agent):
366364
test_time = datetime.fromisoformat(
367365
"2024-01-01T00:00:00Z".replace("Z", "+00:00")
368366
)
369-
assistant_dict = await basic_durable_agent.call_llm(
370-
instance_id, test_time, "Test task"
367+
368+
# Mock the activity context
369+
mock_ctx = Mock()
370+
371+
assistant_dict = basic_durable_agent.call_llm(
372+
mock_ctx,
373+
{
374+
"instance_id": instance_id,
375+
"time": test_time.isoformat(),
376+
"task": "Test task"
377+
}
371378
)
372379
# The dict dumped from AssistantMessage should have our content
373380
assert assistant_dict["content"] == "Test response"
@@ -407,33 +414,49 @@ async def test_send_response_back_activity(self, basic_durable_agent):
407414
@pytest.mark.asyncio
408415
async def test_finish_workflow_activity(self, basic_durable_agent):
409416
"""Test finishing workflow activity."""
417+
from datetime import datetime, timezone
418+
410419
instance_id = "test-instance-123"
411420
final_output = "Final response"
412-
basic_durable_agent.state["instances"] = {
413-
instance_id: {
414-
"input": "Test task",
415-
"source": "test_source",
416-
"triggering_workflow_instance_id": None,
417-
"workflow_instance_id": instance_id,
418-
"workflow_name": "AgenticWorkflow",
419-
"status": "RUNNING",
420-
"messages": [],
421-
"tool_history": [],
422-
"end_time": None,
423-
"trace_context": None,
424-
}
425-
}
426-
427-
basic_durable_agent.finalize_workflow(
428-
instance_id, final_output, "2024-01-01T00:00:00Z"
421+
# Set up state in the state model using AgentWorkflowEntry
422+
if not hasattr(basic_durable_agent._state_model, 'instances'):
423+
basic_durable_agent._state_model.instances = {}
424+
425+
basic_durable_agent._state_model.instances[instance_id] = AgentWorkflowEntry(
426+
input_value="Test task",
427+
source="test_source",
428+
triggering_workflow_instance_id=None,
429+
workflow_instance_id=instance_id,
430+
workflow_name="AgenticWorkflow",
431+
status="RUNNING",
432+
messages=[],
433+
tool_history=[],
434+
end_time=None,
435+
start_time=datetime.now(timezone.utc),
429436
)
430-
instance_data = basic_durable_agent.state["instances"][instance_id]
431-
assert instance_data["output"] == final_output
432-
assert instance_data["end_time"] is not None
437+
438+
# Mock the activity context and save_state
439+
mock_ctx = Mock()
440+
441+
with patch.object(basic_durable_agent, 'save_state'):
442+
basic_durable_agent.finalize_workflow(
443+
mock_ctx,
444+
{
445+
"instance_id": instance_id,
446+
"final_output": final_output,
447+
"end_time": "2024-01-01T00:00:00Z",
448+
"triggering_workflow_instance_id": None
449+
}
450+
)
451+
entry = basic_durable_agent._state_model.instances[instance_id]
452+
assert entry.output == final_output
453+
assert entry.end_time is not None
433454

434455
@pytest.mark.asyncio
435456
async def test_run_tool(self, basic_durable_agent, mock_tool):
436457
"""Test that run_tool atomically executes and persists tool results."""
458+
from datetime import datetime, timezone
459+
437460
instance_id = "test-instance-123"
438461
tool_call = {
439462
"id": "call_123",
@@ -442,56 +465,65 @@ async def test_run_tool(self, basic_durable_agent, mock_tool):
442465

443466
# Mock the tool executor
444467
with patch.object(
445-
type(basic_durable_agent._tool_executor), "run_tool", new_callable=AsyncMock
468+
type(basic_durable_agent.tool_executor), "run_tool", new_callable=AsyncMock
446469
) as mock_run_tool:
447470
mock_run_tool.return_value = "tool_result"
448471

449-
# Set up instance state
450-
basic_durable_agent.state["instances"] = {
451-
instance_id: {
452-
"input": "Test task",
453-
"source": "test_source",
454-
"triggering_workflow_instance_id": None,
455-
"workflow_instance_id": instance_id,
456-
"workflow_name": "AgenticWorkflow",
457-
"status": "RUNNING",
458-
"messages": [],
459-
"tool_history": [],
460-
"end_time": None,
461-
"trace_context": None,
462-
}
463-
}
464-
465-
from datetime import datetime
472+
# Set up state in the state model using AgentWorkflowEntry
473+
if not hasattr(basic_durable_agent._state_model, 'instances'):
474+
basic_durable_agent._state_model.instances = {}
475+
476+
basic_durable_agent._state_model.instances[instance_id] = AgentWorkflowEntry(
477+
input_value="Test task",
478+
source="test_source",
479+
triggering_workflow_instance_id=None,
480+
workflow_instance_id=instance_id,
481+
workflow_name="AgenticWorkflow",
482+
status="RUNNING",
483+
messages=[],
484+
tool_history=[],
485+
end_time=None,
486+
start_time=datetime.now(timezone.utc),
487+
)
466488

467489
test_time = datetime.fromisoformat(
468490
"2024-01-01T00:00:00Z".replace("Z", "+00:00")
469491
)
470-
result = await basic_durable_agent.run_tool(
471-
tool_call, instance_id, test_time
472-
)
492+
493+
# Mock the activity context and save_state
494+
mock_ctx = Mock()
495+
496+
with patch.object(basic_durable_agent, 'save_state'):
497+
result = await basic_durable_agent.run_tool(
498+
mock_ctx,
499+
{
500+
"tool_call": tool_call,
501+
"instance_id": instance_id,
502+
"time": test_time.isoformat(),
503+
"order": 1
504+
}
505+
)
473506

474507
# Verify tool was executed and result was returned
475508
assert result["tool_call_id"] == "call_123"
476509
assert result["tool_name"] == "test_tool"
477510
assert result["execution_result"] == "tool_result"
478511

479512
# Verify state was updated atomically
480-
instance_data = basic_durable_agent.state["instances"][instance_id]
481-
assert len(instance_data["messages"]) == 1 # Tool message added
482-
assert (
483-
len(instance_data["tool_history"]) == 1
484-
) # Tool execution record added
513+
entry = basic_durable_agent._state_model.instances[instance_id]
514+
assert len(entry.messages) == 1 # Tool message added
515+
assert len(entry.tool_history) == 1 # Tool execution record added
485516

486517
# Verify tool execution record in tool_history
487-
tool_history_entry = instance_data["tool_history"][0]
488-
assert tool_history_entry["tool_call_id"] == "call_123"
489-
assert tool_history_entry["tool_name"] == "test_tool"
490-
assert tool_history_entry["execution_result"] == "tool_result"
518+
tool_history_entry = entry.tool_history[0]
519+
assert tool_history_entry.tool_call_id == "call_123"
520+
assert tool_history_entry.tool_name == "test_tool"
521+
assert tool_history_entry.execution_result == "tool_result"
491522

492523
# Verify agent-level tool_history was also updated
493524
assert len(basic_durable_agent.tool_history) == 1
494525

526+
@pytest.mark.skip(reason="get_source_or_default() method removed in refactored architecture")
495527
def test_get_source_or_default(self, basic_durable_agent):
496528
"""Test get_source_or_default helper method."""
497529
# Test with valid source
@@ -505,29 +537,45 @@ def test_get_source_or_default(self, basic_durable_agent):
505537

506538
def test_record_initial_entry(self, basic_durable_agent):
507539
"""Test record_initial_entry helper method."""
540+
from datetime import datetime, timezone
541+
508542
instance_id = "test-instance-123"
509543
input_data = "Test task"
510544
source = "test_source"
511545
triggering_workflow_instance_id = "parent-instance-123"
512546
start_time = "2024-01-01T00:00:00Z"
513547

514-
basic_durable_agent.record_initial_entry(
515-
instance_id, input_data, source, triggering_workflow_instance_id, start_time
548+
# First, ensure instance exists with ensure_instance_exists
549+
basic_durable_agent.ensure_instance_exists(
550+
instance_id=instance_id,
551+
input_value=input_data,
552+
triggering_workflow_instance_id=None,
553+
time=datetime.now(timezone.utc)
516554
)
517555

518-
# Verify instance was created
519-
assert instance_id in basic_durable_agent.state["instances"]
520-
instance_data = basic_durable_agent.state["instances"][instance_id]
521-
assert instance_data["input"] == input_data
522-
assert instance_data["source"] == source
523-
assert (
524-
instance_data["triggering_workflow_instance_id"]
525-
== triggering_workflow_instance_id
526-
)
527-
# start_time is stored as string in dict format
528-
assert instance_data["start_time"] == "2024-01-01T00:00:00Z"
529-
assert instance_data["workflow_name"] == "AgenticWorkflow"
530-
assert instance_data["status"] == "running"
556+
# Mock the activity context
557+
mock_ctx = Mock()
558+
559+
with patch.object(basic_durable_agent, 'save_state'):
560+
basic_durable_agent.record_initial_entry(
561+
mock_ctx,
562+
{
563+
"instance_id": instance_id,
564+
"input_value": input_data,
565+
"source": source,
566+
"triggering_workflow_instance_id": triggering_workflow_instance_id,
567+
"start_time": start_time,
568+
"trace_context": None
569+
}
570+
)
571+
572+
# Verify instance was updated
573+
assert instance_id in basic_durable_agent._state_model.instances
574+
entry = basic_durable_agent._state_model.instances[instance_id]
575+
assert entry.input_value == input_data
576+
assert entry.source == source
577+
assert entry.triggering_workflow_instance_id == triggering_workflow_instance_id
578+
assert entry.status.lower() == "running"
531579

532580
def test_ensure_instance_exists(self, basic_durable_agent):
533581
"""Test _ensure_instance_exists helper method."""

0 commit comments

Comments
 (0)