Skip to content

Commit 5a17523

Browse files
authored
Fix: durable agent context (#223)
* feat: wip on orchestrator state fixing + tracing Signed-off-by: Samantha Coyle <[email protected]> * fix: separate ex/in-ternal triggers + wip fix orchestrators Signed-off-by: Samantha Coyle <[email protected]> * fix: ensure progress on substeps/steps Signed-off-by: Samantha Coyle <[email protected]> * fix: give orchestrators ability to pick up where they left off using same session id Signed-off-by: Samantha Coyle <[email protected]> * style: make linter happy Signed-off-by: Samantha Coyle <[email protected]> * fix: rm extra edge check since captured elsewhere Signed-off-by: Samantha Coyle <[email protected]> * feat: add session context for long term context on durable agents Signed-off-by: Samantha Coyle <[email protected]> * fix: address final merge conflict Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e ruff Signed-off-by: Samantha Coyle <[email protected]> * fix: rm local changes on requirements file for my testing Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e flake8 fix Signed-off-by: Samantha Coyle <[email protected]> * fix: updates for tests too Signed-off-by: Samantha Coyle <[email protected]> * style: add todo comment Signed-off-by: Samantha Coyle <[email protected]> --------- Signed-off-by: Samantha Coyle <[email protected]>
1 parent d4b46a6 commit 5a17523

File tree

4 files changed

+106
-49
lines changed

4 files changed

+106
-49
lines changed

dapr_agents/agents/durableagent/agent.py

Lines changed: 64 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from dapr_agents.types.workflow import DaprWorkflowStatus
2121
from dapr_agents.workflow.agentic import AgenticWorkflow
2222
from dapr_agents.workflow.decorators import message_router, task, workflow
23+
from dapr_agents.memory import ConversationDaprStateMemory
2324

2425
from .schemas import (
2526
AgentTaskResponse,
@@ -59,6 +60,12 @@ class DurableAgent(AgenticWorkflow, AgentBase):
5960
default=None,
6061
description="The current workflow instance ID for this agent.",
6162
)
63+
memory: ConversationDaprStateMemory = Field(
64+
default_factory=lambda: ConversationDaprStateMemory(
65+
store_name="workflowstatestore", session_id="durable_agent_session"
66+
),
67+
description="Persistent memory with session-based state hydration.",
68+
)
6269

6370
@model_validator(mode="before")
6471
def set_agent_and_topic_name(cls, values: dict):
@@ -86,19 +93,23 @@ def model_post_init(self, __context: Any) -> None:
8693
if not self.state:
8794
self.state = {"instances": {}}
8895

89-
# Load the current workflow instance ID from state if it exists
96+
# Load the current workflow instance ID from state using session_id
9097
logger.debug(f"State after loading: {self.state}")
9198
if self.state and self.state.get("instances"):
9299
logger.debug(f"Found {len(self.state['instances'])} instances in state")
93100
for instance_id, instance_data in self.state["instances"].items():
94101
stored_workflow_name = instance_data.get("workflow_name")
102+
stored_session_id = instance_data.get("session_id")
95103
logger.debug(
96-
f"Instance {instance_id}: workflow_name={stored_workflow_name}, current_workflow_name={self._workflow_name}"
104+
f"Instance {instance_id}: workflow_name={stored_workflow_name}, session_id={stored_session_id}, current_workflow_name={self._workflow_name}, current_session_id={self.memory.session_id}"
97105
)
98-
if stored_workflow_name == self._workflow_name:
106+
if (
107+
stored_workflow_name == self._workflow_name
108+
and stored_session_id == self.memory.session_id
109+
):
99110
self.workflow_instance_id = instance_id
100111
logger.debug(
101-
f"Loaded current workflow instance ID from state: {instance_id}"
112+
f"Loaded current workflow instance ID from state using session_id: {instance_id}"
102113
)
103114
break
104115
else:
@@ -256,9 +267,10 @@ def tool_calling_workflow(self, ctx: DaprWorkflowContext, message: TriggerAction
256267
"tool_call": tc,
257268
"instance_id": ctx.instance_id,
258269
"time": ctx.current_utc_datetime.isoformat(),
270+
"execution_order": i, # Add ordering information
259271
},
260272
)
261-
for tc in tool_calls
273+
for i, tc in enumerate(tool_calls)
262274
]
263275
yield self.when_all(parallel)
264276

@@ -389,6 +401,7 @@ def record_initial_entry(
389401
"workflow_instance_id": instance_id,
390402
"triggering_workflow_instance_id": triggering_workflow_instance_id,
391403
"workflow_name": self._workflow_name,
404+
"session_id": self.memory.session_id,
392405
"start_time": start_time_str,
393406
"trace_context": trace_context,
394407
"status": DaprWorkflowStatus.RUNNING.value,
@@ -430,6 +443,7 @@ def _ensure_instance_exists(
430443
"workflow_instance_id": instance_id,
431444
"triggering_workflow_instance_id": triggering_workflow_instance_id,
432445
"workflow_name": self._workflow_name,
446+
"session_id": self.memory.session_id,
433447
"messages": [],
434448
"tool_history": [],
435449
"status": DaprWorkflowStatus.RUNNING.value,
@@ -581,6 +595,7 @@ def _create_tool_message_objects(self, tool_result: Dict[str, Any]) -> tuple:
581595
tool_call_id=tool_result["tool_call_id"],
582596
name=tool_result["tool_name"],
583597
content=tool_result["execution_result"],
598+
role="tool",
584599
)
585600
agent_msg = DurableAgentMessage(**tool_msg.model_dump())
586601
tool_history_entry = ToolExecutionRecord(**tool_result)
@@ -656,7 +671,11 @@ def _get_last_message_from_state(
656671

657672
@task
658673
async def run_tool(
659-
self, tool_call: Dict[str, Any], instance_id: str, time: datetime
674+
self,
675+
tool_call: Dict[str, Any],
676+
instance_id: str,
677+
time: datetime,
678+
execution_order: int = 0,
660679
) -> Dict[str, Any]:
661680
"""
662681
Executes a tool call atomically by invoking the specified function with the provided arguments
@@ -859,7 +878,7 @@ async def process_broadcast_message(self, message: BroadcastMessage):
859878

860879
# Start the agent's workflow
861880
await self.run_and_monitor_workflow_async(
862-
workflow="ToolCallingWorkflow", input=trigger_message
881+
workflow="AgenticWorkflow", input=trigger_message
863882
)
864883

865884
except Exception as e:
@@ -871,9 +890,9 @@ def _construct_messages_with_instance_history(
871890
self, instance_id: str, input_data: Union[str, Dict[str, Any]]
872891
) -> List[Dict[str, Any]]:
873892
"""
874-
Construct messages using instance-specific chat history instead of global memory.
875-
This ensures proper message sequence for tool calls and prevents OpenAI API errors
876-
in the event an app gets terminated or restarts while the workflow is running.
893+
Construct messages using instance-specific chat history and persistent memory.
894+
This ensures proper message sequence for tool calls and maintains conversation
895+
history across workflow executions using the session_id.
877896
878897
Args:
879898
instance_id: The workflow instance ID
@@ -882,12 +901,13 @@ def _construct_messages_with_instance_history(
882901
Returns:
883902
List of formatted messages with proper sequence
884903
"""
904+
additional_context_messages: List[Dict[str, Any]] = []
885905
if not self.prompt_template:
886906
raise ValueError(
887907
"Prompt template must be initialized before constructing messages."
888908
)
889909

890-
# Get instance-specific chat history instead of global memory
910+
# Get instance-specific chat history
891911
if self.state is None:
892912
logger.warning(
893913
f"Agent state is None for instance {instance_id}, initializing empty state"
@@ -900,48 +920,51 @@ def _construct_messages_with_instance_history(
900920
else:
901921
instance_messages = []
902922

903-
# Always include long-term memory (chat_history) for context
904-
# This ensures agents have access to broadcast messages and persistent context
905-
long_term_memory_data = self.state.get("chat_history", [])
923+
# Get messages from persistent memory (session-based, cross-workflow)
924+
persistent_memory_messages = []
925+
try:
926+
persistent_memory_messages = self.memory.get_messages()
927+
logger.info(
928+
f"Retrieved {len(persistent_memory_messages)} messages for session {self.memory.session_id}"
929+
)
930+
except Exception as e:
931+
logger.warning(f"Failed to retrieve persistent memory: {e}")
906932

907-
# Convert long-term memory to dict format for LLM consumption
933+
# Get long-term memory from workflow state (for broadcast messages and persistent context)
934+
long_term_memory_data = self.state.get("chat_history", [])
908935
long_term_memory_messages = []
909936
for msg in long_term_memory_data:
910937
if isinstance(msg, dict):
911938
long_term_memory_messages.append(msg)
912939
elif hasattr(msg, "model_dump"):
913940
long_term_memory_messages.append(msg.model_dump())
914941

915-
# For broadcast-triggered workflows, also include additional context memory
916-
source = instance_data.get("source") if instance_data else None
917-
additional_context_messages = []
918-
if source and source != "direct":
919-
# Include additional context memory for broadcast-triggered workflows
920-
context_memory_data = self.memory.get_messages()
921-
for msg in context_memory_data:
922-
if isinstance(msg, dict):
923-
additional_context_messages.append(msg)
924-
elif hasattr(msg, "model_dump"):
925-
additional_context_messages.append(msg.model_dump())
926-
927-
# Build chat history with:
928-
# 1. Long-term memory (persistent context, broadcast messages)
929-
# 2. Short-term instance messages (current workflow specific)
930-
# 3. Additional context memory (for broadcast-triggered workflows)
942+
# Build chat history with proper context and order
931943
chat_history = []
932944

933-
# Add long-term memory first (broadcast messages, persistent context)
934-
chat_history.extend(long_term_memory_messages)
945+
# First add persistent memory and long-term memory as user messages for context
946+
# This ensures we have cross-workflow context but doesn't interfere with tool state order
947+
for msg in persistent_memory_messages + long_term_memory_messages:
948+
msg_dict = msg.model_dump() if hasattr(msg, "model_dump") else dict(msg)
949+
if msg_dict in chat_history:
950+
continue
951+
# TODO: We need to properly design session-based memory.
952+
# Convert tool-related messages to user messages to avoid conversation order issues
953+
if msg_dict.get("role") in ["tool", "assistant"] and (
954+
msg_dict.get("tool_calls") or msg_dict.get("tool_call_id")
955+
):
956+
msg_dict = {
957+
"role": "user",
958+
"content": f"[Previous {msg_dict['role']} message: {msg_dict.get('content', '')}]",
959+
}
960+
chat_history.append(msg_dict)
935961

936-
# Add short-term instance messages (current workflow)
962+
# Then add instance messages in their original form to maintain tool state
937963
for msg in instance_messages:
938-
if isinstance(msg, dict):
939-
chat_history.append(msg)
940-
else:
941-
# Convert DurableAgentMessage to dict if needed
942-
chat_history.append(
943-
msg.model_dump() if hasattr(msg, "model_dump") else dict(msg)
944-
)
964+
msg_dict = msg.model_dump() if hasattr(msg, "model_dump") else dict(msg)
965+
if msg_dict in chat_history:
966+
continue
967+
chat_history.append(msg_dict)
945968

946969
# Add additional context memory last (for broadcast-triggered workflows)
947970
chat_history.extend(additional_context_messages)

quickstarts/03-agent-tool-call/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ arize-phoenix>=11.22.0,<12.0.0
33
arize-phoenix-otel>=0.12.0,<0.13.0
44
opentelemetry-instrumentation>=0.56b0
55
# For local development use local changes by commenting out the dapr-agents line above and uncommenting the line below:
6-
# -e ../../
6+
# -e ../../[observability]

quickstarts/05-multi-agent-workflows/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ arize-phoenix>=11.22.0,<12.0.0
33
arize-phoenix-otel>=0.12.0,<0.13.0
44
opentelemetry-instrumentation>=0.56b0
55
# For local development use local changes by commenting out the dapr-agents line above and uncommenting the line below:
6-
# -e ../../[observability[
6+
# -e ../../[observability]

tests/agents/durableagent/test_durable_agent.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import asyncio
66
import os
77
from typing import Any
8-
from unittest.mock import AsyncMock, Mock, patch
8+
from unittest.mock import AsyncMock, Mock, patch, MagicMock
99

1010
import pytest
1111
from dapr.ext.workflow import DaprWorkflowContext
@@ -21,7 +21,7 @@
2121
DurableAgentWorkflowState,
2222
)
2323
from dapr_agents.llm import OpenAIChatClient
24-
from dapr_agents.memory import ConversationListMemory
24+
from dapr_agents.memory import ConversationDaprStateMemory
2525
from dapr_agents.tool.base import AgentTool
2626
from dapr_agents.types import (
2727
AssistantMessage,
@@ -104,13 +104,43 @@ def mock_register_agentic_system(self):
104104
yield
105105

106106

107+
class MockDaprClient:
108+
"""Mock DaprClient that supports context manager protocol"""
109+
110+
def __init__(self):
111+
self.get_state = MagicMock(return_value=Mock(data=None, json=lambda: {}))
112+
self.save_state = MagicMock()
113+
self.delete_state = MagicMock()
114+
self.query_state = MagicMock()
115+
116+
def __enter__(self):
117+
return self
118+
119+
def __exit__(self, exc_type, exc_val, exc_tb):
120+
pass
121+
122+
def __call__(self, *args, **kwargs):
123+
return self
124+
125+
107126
class TestDurableAgent:
108127
"""Test cases for the DurableAgent class."""
109128

110129
@pytest.fixture(autouse=True)
111-
def setup_env(self):
112-
"""Set up environment variables for testing."""
130+
def setup_env(self, monkeypatch):
131+
"""Set up environment variables and mocks for testing."""
113132
os.environ["OPENAI_API_KEY"] = "test-api-key"
133+
134+
# Mock DaprClient to use our context manager supporting mock
135+
mock_client = MockDaprClient()
136+
mock_client.get_state.return_value = Mock(data=None) # Default empty state
137+
138+
# Patch both the client import locations
139+
monkeypatch.setattr("dapr.clients.DaprClient", lambda: mock_client)
140+
monkeypatch.setattr(
141+
"dapr_agents.storage.daprstores.statestore.DaprClient", lambda: mock_client
142+
)
143+
114144
yield
115145
if "OPENAI_API_KEY" in os.environ:
116146
del os.environ["OPENAI_API_KEY"]
@@ -158,7 +188,9 @@ def basic_durable_agent(self, mock_llm):
158188
goal="Help with testing",
159189
instructions=["Be helpful", "Test things"],
160190
llm=mock_llm,
161-
memory=ConversationListMemory(),
191+
memory=ConversationDaprStateMemory(
192+
store_name="teststatestore", session_id="test_session"
193+
),
162194
max_iterations=5,
163195
state_store_name="teststatestore",
164196
message_bus_name="testpubsub",
@@ -174,7 +206,9 @@ def durable_agent_with_tools(self, mock_llm, mock_tool):
174206
goal="Execute tools",
175207
instructions=["Use tools when needed"],
176208
llm=mock_llm,
177-
memory=ConversationListMemory(),
209+
memory=ConversationDaprStateMemory(
210+
store_name="teststatestore", session_id="test_session"
211+
),
178212
tools=[mock_tool],
179213
max_iterations=5,
180214
state_store_name="teststatestore",

0 commit comments

Comments
 (0)