Skip to content

Commit 83f291e

Browse files
authored
fix: decouple memory + state to work with one and/or both (#236)
* fix: decouple memory + state to work with one and/or both for agent+orchestrator

  Signed-off-by: Samantha Coyle <[email protected]>

* feat: add memory on llm orchestrator quickstart

  Signed-off-by: Samantha Coyle <[email protected]>

* style: tox -e ruff

  Signed-off-by: Samantha Coyle <[email protected]>

---------

Signed-off-by: Samantha Coyle <[email protected]>
1 parent 779ecbc commit 83f291e

File tree

6 files changed

+111
-44
lines changed

6 files changed

+111
-44
lines changed

dapr_agents/agents/durableagent/agent.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,8 @@ class DurableAgent(AgenticWorkflow, AgentBase):
6060
default=None,
6161
description="The current workflow instance ID for this agent.",
6262
)
63-
memory: ConversationDaprStateMemory = Field(
64-
default_factory=lambda: ConversationDaprStateMemory(
65-
store_name="workflowstatestore", session_id="durable_agent_session"
66-
),
63+
memory: Optional[ConversationDaprStateMemory] = Field(
64+
default=None,
6765
description="Persistent memory with session-based state hydration.",
6866
)
6967

@@ -93,19 +91,31 @@ def model_post_init(self, __context: Any) -> None:
9391
if not self.state:
9492
self.state = {"instances": {}}
9593

94+
if self.memory is not None:
95+
self.memory = ConversationDaprStateMemory(
96+
store_name=self.memory.store_name,
97+
session_id=f"{self.name or 'agent'}_session",
98+
)
99+
logger.info(f"Initialized memory with store name: {self.memory.store_name}")
100+
96101
# Load the current workflow instance ID from state using session_id
97102
logger.debug(f"State after loading: {self.state}")
98103
if self.state and self.state.get("instances"):
99104
logger.debug(f"Found {len(self.state['instances'])} instances in state")
105+
current_session_id = (
106+
self.memory.session_id
107+
if self.memory
108+
else f"{self.name}_default_session"
109+
)
100110
for instance_id, instance_data in self.state["instances"].items():
101111
stored_workflow_name = instance_data.get("workflow_name")
102112
stored_session_id = instance_data.get("session_id")
103113
logger.debug(
104-
f"Instance {instance_id}: workflow_name={stored_workflow_name}, session_id={stored_session_id}, current_workflow_name={self._workflow_name}, current_session_id={self.memory.session_id}"
114+
f"Instance {instance_id}: workflow_name={stored_workflow_name}, session_id={stored_session_id}, current_workflow_name={self._workflow_name}, current_session_id={current_session_id}"
105115
)
106116
if (
107117
stored_workflow_name == self._workflow_name
108-
and stored_session_id == self.memory.session_id
118+
and stored_session_id == current_session_id
109119
):
110120
self.workflow_instance_id = instance_id
111121
logger.debug(
@@ -401,7 +411,9 @@ def record_initial_entry(
401411
"workflow_instance_id": instance_id,
402412
"triggering_workflow_instance_id": triggering_workflow_instance_id,
403413
"workflow_name": self._workflow_name,
404-
"session_id": self.memory.session_id,
414+
"session_id": self.memory.session_id
415+
if self.memory
416+
else f"{self.name}_default_session",
405417
"start_time": start_time_str,
406418
"trace_context": trace_context,
407419
"status": DaprWorkflowStatus.RUNNING.value,
@@ -443,7 +455,9 @@ def _ensure_instance_exists(
443455
"workflow_instance_id": instance_id,
444456
"triggering_workflow_instance_id": triggering_workflow_instance_id,
445457
"workflow_name": self._workflow_name,
446-
"session_id": self.memory.session_id,
458+
"session_id": self.memory.session_id
459+
if self.memory
460+
else f"{self.name}_default_session",
447461
"messages": [],
448462
"tool_history": [],
449463
"status": DaprWorkflowStatus.RUNNING.value,
@@ -462,7 +476,8 @@ def _process_user_message(
462476
return
463477

464478
user_msg = UserMessage(content=user_message_copy.get("content", ""))
465-
self.memory.add_message(user_msg)
479+
if self.memory:
480+
self.memory.add_message(user_msg)
466481

467482
msg_object = DurableAgentMessage(**user_message_copy)
468483
inst = self.state["instances"][instance_id]
@@ -502,7 +517,8 @@ def _save_assistant_message(
502517
if not message_exists:
503518
messages_list.append(agent_msg.model_dump(mode="json"))
504519
inst["last_message"] = agent_msg.model_dump(mode="json")
505-
self.memory.add_message(AssistantMessage(**assistant_message))
520+
if self.memory:
521+
self.memory.add_message(AssistantMessage(**assistant_message))
506522
self.save_state()
507523

508524
def _print_llm_interaction_messages(
@@ -649,8 +665,9 @@ def _update_agent_memory_and_history(
649665
# Update tool history and memory of agent (only if new)
650666
# Note: Memory updates are handled at workflow level to avoid replay issues
651667
self.tool_history.append(tool_history_entry)
652-
# Add the tool message to the agent's memory
653-
self.memory.add_message(tool_message)
668+
# Add the tool message to the agent's memory if enabled
669+
if self.memory:
670+
self.memory.add_message(tool_message)
654671

655672
def _get_last_message_from_state(
656673
self, instance_id: str
@@ -850,8 +867,9 @@ async def process_broadcast_message(self, message: BroadcastMessage):
850867
logger.debug(
851868
f"{self.name} processing broadcast message from '{source}'. Content: {message_content}"
852869
)
853-
# Store the message in local memory
854-
self.memory.add_message(message)
870+
# Store the message in local memory if enabled
871+
if self.memory:
872+
self.memory.add_message(message)
855873

856874
# Define DurableAgentMessage object for state persistence
857875
agent_msg = DurableAgentMessage(**message.model_dump())
@@ -920,15 +938,18 @@ def _construct_messages_with_instance_history(
920938
else:
921939
instance_messages = []
922940

923-
# Get messages from persistent memory (session-based, cross-workflow)
941+
# Get messages from persistent memory (session-based, cross-workflow) if memory is enabled
924942
persistent_memory_messages = []
925-
try:
926-
persistent_memory_messages = self.memory.get_messages()
927-
logger.info(
928-
f"Retrieved {len(persistent_memory_messages)} messages for session {self.memory.session_id}"
929-
)
930-
except Exception as e:
931-
logger.warning(f"Failed to retrieve persistent memory: {e}")
943+
if self.memory:
944+
try:
945+
persistent_memory_messages = self.memory.get_messages()
946+
logger.info(
947+
f"Retrieved {len(persistent_memory_messages)} messages for session {self.memory.session_id}"
948+
)
949+
except Exception as e:
950+
logger.warning(f"Failed to retrieve persistent memory: {e}")
951+
else:
952+
logger.debug("Memory is disabled, skipping persistent memory retrieval")
932953

933954
# Get long-term memory from workflow state (for broadcast messages and persistent context)
934955
long_term_memory_data = self.state.get("chat_history", [])

dapr_agents/memory/daprstatestore.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def add_message(self, message: Union[Dict[str, Any], BaseMessage]) -> None:
8989
message_key = self._get_message_key(message_id)
9090
message.update(
9191
{
92-
"sessionId": self.session_id,
9392
"createdAt": datetime.now().isoformat() + "Z",
9493
}
9594
)

dapr_agents/workflow/orchestrators/llm/orchestrator.py

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,8 @@ class LLMOrchestrator(OrchestratorWorkflowBase):
5454
default=None,
5555
description="The current workflow instance ID for this orchestrator.",
5656
)
57-
memory: ConversationDaprStateMemory = Field(
58-
default_factory=lambda: ConversationDaprStateMemory(
59-
store_name="workflowstatestore", session_id="orchestrator_session"
60-
),
57+
memory: Optional[ConversationDaprStateMemory] = Field(
58+
default=None,
6159
description="Persistent memory with session-based state hydration.",
6260
)
6361

@@ -77,14 +75,25 @@ def model_post_init(self, __context: Any) -> None:
7775
if not self.state:
7876
logger.debug("No state found, initializing empty state")
7977
self.state = {"instances": {}}
80-
else:
81-
logger.debug(f"State loaded successfully: {self.state}")
78+
79+
if self.memory is not None:
80+
self.memory = ConversationDaprStateMemory(
81+
store_name=self.memory.store_name,
82+
session_id=f"{self.name or 'orchestrator'}_session",
83+
)
84+
85+
print(f"sam memory store name is {self.memory.store_name}")
86+
logger.info(f"sam memory store name is {self.memory.store_name}")
8287

8388
# Load the current workflow instance ID from state using session_id)
8489
if self.state and self.state.get("instances"):
8590
logger.debug(f"Found {len(self.state['instances'])} instances in state")
8691

87-
current_session_id = self.memory.session_id
92+
current_session_id = (
93+
self.memory.session_id
94+
if self.memory
95+
else f"{self.name}_default_session"
96+
)
8897
for instance_id, instance_data in self.state["instances"].items():
8998
stored_workflow_name = instance_data.get("workflow_name")
9099
stored_session_id = instance_data.get("session_id")
@@ -654,10 +663,15 @@ async def initialize_workflow_with_plan(
654663
"instances", {}
655664
).items():
656665
stored_session_id = instance_data.get("session_id")
657-
if stored_session_id == self.memory.session_id:
666+
current_session_id = (
667+
self.memory.session_id
668+
if self.memory
669+
else f"{self.name}_default_session"
670+
)
671+
if stored_session_id == current_session_id:
658672
existing_plan = instance_data.get("plan", [])
659673
logger.debug(
660-
f"Found existing plan for session_id {self.memory.session_id} in instance {stored_instance_id}"
674+
f"Found existing plan for session_id {current_session_id} in instance {stored_instance_id}"
661675
)
662676
break
663677

@@ -686,23 +700,36 @@ async def initialize_workflow_with_plan(
686700
structured_mode="json",
687701
)
688702

689-
# Parse the response - now we get a Pydantic model directly
690-
if hasattr(response, "choices") and response.choices:
691-
# If it's still a raw response, parse it
692-
plan_data = response.choices[0].message.content
693-
logger.debug(f"Plan generation response: {plan_data}")
694-
plan_dict = json.loads(plan_data)
695-
# Convert raw dictionaries to Pydantic models
703+
# Parse the response
704+
if isinstance(response, str):
705+
# If it's a raw JSON string
706+
plan_dict = json.loads(response)
696707
plan_objects = [
697708
PlanStep(**step_dict)
698709
for step_dict in plan_dict.get("objects", [])
699710
]
711+
elif hasattr(response, "choices") and response.choices:
712+
# If it's an OpenAI-style response with multiple choices
713+
plan_objects = []
714+
for choice in response.choices:
715+
plan_data = choice.message.content
716+
if isinstance(plan_data, str):
717+
plan_dict = json.loads(plan_data)
718+
plan_objects.extend(
719+
PlanStep(**step_dict)
720+
for step_dict in plan_dict.get("objects", [])
721+
)
722+
elif hasattr(plan_data, "objects"):
723+
plan_objects.extend(plan_data.objects)
700724
else:
701725
# If it's already a Pydantic model
702726
plan_objects = (
703727
response.objects if hasattr(response, "objects") else []
704728
)
705-
logger.debug(f"Plan generation response (Pydantic): {plan_objects}")
729+
730+
logger.debug(
731+
f"Plan generation response with {len(plan_objects)} objects: {plan_objects}"
732+
)
706733

707734
# Format and broadcast message
708735
plan_dicts = self._convert_plan_objects_to_dicts(plan_objects)
@@ -1030,8 +1057,9 @@ async def update_workflow_state(
10301057
workflow_entry["messages"].append(serialized_message)
10311058
workflow_entry["last_message"] = serialized_message
10321059

1033-
# Update the local chat history
1034-
self.memory.add_message(message)
1060+
# Update the local chat history if memory is enabled
1061+
if self.memory:
1062+
self.memory.add_message(message)
10351063

10361064
if final_output is not None:
10371065
workflow_entry["output"] = final_output
@@ -1041,7 +1069,9 @@ async def update_workflow_state(
10411069
# Store workflow instance ID, workflow name, and session_id for session-based state rehydration
10421070
workflow_entry["workflow_instance_id"] = instance_id
10431071
workflow_entry["workflow_name"] = self._workflow_name
1044-
workflow_entry["session_id"] = self.memory.session_id
1072+
workflow_entry["session_id"] = (
1073+
self.memory.session_id if self.memory else f"{self.name}_default_session"
1074+
)
10451075

10461076
# Persist updated state
10471077
self.save_state()

quickstarts/03-agent-tool-call/weather_durable_agent_tracing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ async def main():
3737
message_bus_name="messagepubsub",
3838
state_store_name="workflowstatestore",
3939
agents_registry_store_name="agentstatestore",
40-
history_store_name="historystore",
4140
)
4241

4342
await AIAgent.run("What is the weather in Virginia, New York and Washington DC?")
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
apiVersion: dapr.io/v1alpha1
2+
kind: Component
3+
metadata:
4+
name: conversationstore
5+
spec:
6+
type: state.redis
7+
version: v1
8+
metadata:
9+
- name: redisHost
10+
value: localhost:6379
11+
- name: redisPassword
12+
value: ""
13+
- name: keyPrefix
14+
value: none

quickstarts/05-multi-agent-workflows/services/workflow-llm/app.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dapr_agents import LLMOrchestrator
22
from dapr_agents.llm import DaprChatClient
3+
from dapr_agents.memory import ConversationDaprStateMemory
34
from dotenv import load_dotenv
45
import asyncio
56
import logging
@@ -34,6 +35,9 @@ async def main():
3435
agents_registry_store_name="agentstatestore",
3536
agents_registry_key="agents_registry",
3637
broadcast_topic_name="beacon_channel",
38+
memory=ConversationDaprStateMemory(
39+
store_name="conversationstore", session_id="myuniqueid"
40+
),
3741
max_iterations=3,
3842
).as_service(port=8004)
3943

0 commit comments

Comments (0)