diff --git a/dapr_agents/__init__.py b/dapr_agents/__init__.py index bf3285b7..5ed603a2 100644 --- a/dapr_agents/__init__.py +++ b/dapr_agents/__init__.py @@ -1,6 +1,5 @@ from dapr_agents.agents.agent import Agent from dapr_agents.agents.durableagent import DurableAgent -from dapr_agents.agents.memory_store import MemoryStore from dapr_agents.executors import DockerCodeExecutor, LocalCodeExecutor from dapr_agents.llm.dapr import DaprChatClient from dapr_agents.llm.elevenlabs import ElevenLabsSpeechClient @@ -23,7 +22,6 @@ __all__ = [ "Agent", "DurableAgent", - "MemoryStore", "DockerCodeExecutor", "LocalCodeExecutor", "ElevenLabsSpeechClient", diff --git a/dapr_agents/agents/__init__.py b/dapr_agents/agents/__init__.py index 65da6a97..0c7b8b70 100644 --- a/dapr_agents/agents/__init__.py +++ b/dapr_agents/agents/__init__.py @@ -1,6 +1,5 @@ from .agent.agent import Agent from .base import AgentBase from .durableagent.agent import DurableAgent -from .memory_store import MemoryStore -__all__ = ["AgentBase", "Agent", "DurableAgent", "MemoryStore"] +__all__ = ["AgentBase", "Agent", "DurableAgent"] diff --git a/dapr_agents/agents/agent/agent.py b/dapr_agents/agents/agent/agent.py index 7e5c13ec..8beb1791 100644 --- a/dapr_agents/agents/agent/agent.py +++ b/dapr_agents/agents/agent/agent.py @@ -21,18 +21,6 @@ class Agent(AgentBase): It integrates tools and processes them based on user inputs and task orchestration. """ - def get_chat_history(self, task: Optional[str] = None) -> List[Dict[str, Any]]: - """ - Retrieves the chat history as a list of dictionaries. - - Args: - task (Optional[str]): The task or query provided by the user. - - Returns: - List[Dict[str, Any]]: The chat history as dictionaries. - """ - return self.memory_store.get_messages() - async def run(self, input_data: Optional[Union[str, Dict[str, Any]]] = None) -> Any: """ Runs the agent with the given input, supporting graceful shutdown. @@ -81,7 +69,7 @@ async def _run_agent( ) -> Any: """ Internal method for running the agent logic. - Formats messages, updates conversation history, and drives the conversation loop. + Formats messages, updates memory, and drives the conversation loop. Args: input_data (Optional[Union[str, Dict[str, Any]]]): Input for the agent, can be a string or dict. @@ -89,7 +77,7 @@ async def _run_agent( Any: The result of the agent's conversation loop. """ logger.debug( - f"Agent run started with input: {input_data if input_data else 'Using session conversation context'}" + f"Agent run started with input: {input_data if input_data else 'Using memory context'}" ) # Construct messages using only input_data; chat history handled internally @@ -103,7 +91,7 @@ async def _run_agent( if input_data and user_message_copy: # Add the new user message to memory only if input_data is provided and user message exists user_msg = UserMessage(content=user_message_copy.get("content", "")) - self.memory_store.add_message(user_msg) + self.memory.add_message(user_msg) # Always print the last user message for context, even if no input_data is provided if user_message_copy is not None: @@ -172,8 +160,8 @@ async def run_and_record(tool_call: ToolCall) -> ToolMessage: ) # Print the tool message for visibility self.text_formatter.print_message(tool_message) - # Add tool message to storage - self.memory_store.add_message(tool_message) + # Add tool message to memory + self.memory.add_message(tool_message) # Append tool message to the persistent audit log tool_execution_record = ToolExecutionRecord( tool_call_id=tool_id, @@ -227,7 +215,7 @@ async def conversation(self, messages: List[Dict[str, Any]]) -> Any: else: assistant = response_message self.text_formatter.print_message(assistant) - self.memory_store.add_message(assistant) + self.memory.add_message(assistant) # Handle tool calls response if assistant is not None and assistant.has_tool_calls(): diff --git a/dapr_agents/agents/base.py b/dapr_agents/agents/base.py index 97366905..e7bfe864 100644 --- a/dapr_agents/agents/base.py +++ b/dapr_agents/agents/base.py @@ -1,4 +1,8 @@ -from dapr_agents.agents.memory_store import MemoryStore +from dapr_agents.memory import ( + MemoryBase, + ConversationListMemory, + ConversationVectorMemory, +) from dapr_agents.agents.utils.text_printer import ColorTextFormatter from dapr_agents.types import MessagePlaceHolder, BaseMessage, ToolExecutionRecord from dapr_agents.tool.executor import AgentToolExecutor @@ -6,7 +10,6 @@ from dapr_agents.prompt import ChatPromptTemplate from dapr_agents.tool.base import AgentTool import re -import json from datetime import datetime import logging import asyncio @@ -86,6 +89,11 @@ class AgentBase(BaseModel, ABC): max_iterations: int = Field( default=10, description="Max iterations for conversation cycles." ) + # TODO(@Sicoyle): Rename this to make clearer + memory: MemoryBase = Field( + default_factory=ConversationListMemory, + description="Handles long-term conversation history (for all workflow instance-ids within the same session) and context storage.", + ) # TODO: we should have a system_template, prompt_template, and response_template, or better separation here. # If we have something like a customer service agent, we want diff templates for different types of interactions. # In future, we could also have a way to dynamically change the template based on the context of the interaction. @@ -93,18 +101,6 @@ class AgentBase(BaseModel, ABC): default="jinja2", description="The format used for rendering the prompt template.", ) - memory_store: Optional["MemoryStore"] = Field( - default=None, - description=( - "Storage for conversation history. " - "If None, a default in-memory Storage will be created. " - "For persistent storage, specify the name of the Dapr State Store to use. " - ), - ) - registry_store: Optional[str] = Field( - default=None, - description="Agent registry store name for storing static agent information. Defaults to memory_store state store name if not provided.", - ) DEFAULT_SYSTEM_PROMPT: ClassVar[str] """Default f-string template; placeholders will be swapped to Jinja if needed.""" @@ -128,7 +124,6 @@ class AgentBase(BaseModel, ABC): _text_formatter: ColorTextFormatter = PrivateAttr( default_factory=ColorTextFormatter ) - _dapr_client: Optional[Any] = PrivateAttr(default=None) model_config = ConfigDict(arbitrary_types_allowed=True) @@ -174,46 +169,6 @@ def model_post_init(self, __context: Any) -> None: if self.llm is None: self.llm = get_default_llm() - # Initialize storage if not provided (in-memory by default) - if self.memory_store is None: - self.memory_store = MemoryStore() - logger.debug("Initialized default in-memory Storage") - - # Initialize Dapr client if storage is persistent - # This is needed for state store access and agent registration - if self.memory_store and self.memory_store.name and self._dapr_client is None: - from dapr.clients import DaprClient - - self._dapr_client = DaprClient() - logger.debug( - f"Initialized Dapr client for agent '{self.name}' with persistent storage" - ) - - # Register agent if it has persistent storage - # This applies to both Agent and DurableAgent with persistent storage - if self.memory_store and self.memory_store.name: - if self.registry_store is None: - self.registry_store = self.memory_store.name - agent_metadata = { - "name": self.name, - "role": self.role, - "goal": self.goal, - "tool_choice": self.tool_choice, - "instructions": self.instructions, - "topic_name": getattr(self, "agent_topic_name", None), - "pubsub_name": getattr(self, "message_bus_name", None), - "orchestrator": False, - "statestore_name": self.memory_store.name, - "registry_name": self.registry_store, - } - - self.register_agent( - store_name=self.registry_store, - store_key="agent_registry", - agent_name=self.name, - agent_metadata=self._serialize_metadata(agent_metadata), - ) - # Centralize prompt template selection logic self.prompt_template = self._initialize_prompt_template() # Ensure LLM client and agent both reference the same template @@ -338,17 +293,33 @@ def text_formatter(self) -> ColorTextFormatter: """Returns the text formatter for the agent.""" return self._text_formatter - @abstractmethod def get_chat_history(self, task: Optional[str] = None) -> List[Dict[str, Any]]: """ - Retrieves the chat history as a list of dictionaries. + Retrieves the chat history from memory as a list of dictionaries. Args: - task (Optional[str]): The task or query provided by the user. + task (Optional[str]): The task or query provided by the user (used for vector search). Returns: List[Dict[str, Any]]: The chat history as dictionaries. """ + if isinstance(self.memory, ConversationVectorMemory) and task: + if ( + hasattr(self.memory.vector_store, "embedding_function") + and self.memory.vector_store.embedding_function + and hasattr( + self.memory.vector_store.embedding_function, "embed_documents" + ) + ): + query_embeddings = self.memory.vector_store.embedding_function.embed( + task + ) + messages = self.memory.get_messages(query_embeddings=query_embeddings) + else: + messages = self.memory.get_messages() + else: + messages = self.memory.get_messages() + return messages @property def chat_history(self) -> List[Dict[str, Any]]: @@ -487,6 +458,10 @@ def construct_messages( else: raise ValueError("Input data must be either a string or dictionary.") + def reset_memory(self): + """Clears all messages stored in the agent's memory.""" + self.memory.reset_memory() + def get_last_message(self) -> Optional[Dict[str, Any]]: """ Retrieves the last message from the chat history. @@ -561,115 +536,6 @@ def get_llm_tools(self) -> List[Union[AgentTool, Dict[str, Any]]]: continue return llm_tools - def register_agent( - self, store_name: str, store_key: str, agent_name: str, agent_metadata: dict - ) -> None: - """ - Merges the existing data with the new data and updates the store. - Only works for agents with Dapr client access (AgenticWorkflow subclasses). - - Args: - store_name (str): The name of the Dapr state store component. - store_key (str): The key to update. - agent_name (str): The name of the agent to register. - agent_metadata (dict): The metadata to register for the agent. - """ - import json - import time - from dapr.clients.grpc._response import StateResponse - from dapr.clients.grpc._state import StateOptions, Concurrency, Consistency - from dapr.clients.grpc._request import ( - TransactionalStateOperation, - TransactionOperationType, - ) - - # Only proceed if agent has Dapr client - if not hasattr(self, "_dapr_client"): - logger.debug( - f"Agent '{self.name}' does not have Dapr client, skipping registration" - ) - return - - # retry the entire operation up to twenty times sleeping 1-2 seconds between each - # TODO: rm the custom retry logic here and use the DaprClient retry_policy instead. - for attempt in range(1, 21): - try: - # Get current registry and etag - response: StateResponse = self._dapr_client.get_state( - store_name=store_name, key=store_key - ) - if not response.etag: - # if there is no etag the following transaction won't work as expected - # so we need to save an empty object with a strong consistency to force the etag to be created - self._dapr_client.save_state( - store_name=store_name, - key=store_key, - value=json.dumps({}), - state_metadata={ - "contentType": "application/json", - "partitionKey": store_key, - }, - options=StateOptions( - concurrency=Concurrency.first_write, - consistency=Consistency.strong, - ), - ) - - # reread to obtain the freshly minted ETag - response = self._dapr_client.get_state( - store_name=store_name, key=store_key - ) - if not response.etag: - raise RuntimeError("ETag still missing after init") - - existing = ( - self._deserialize_state(response.data) if response.data else {} - ) - - if existing.get(agent_name) == agent_metadata: - logger.debug(f"Agent '{agent_name}' already registered") - return - - safe_metadata = self._serialize_metadata(agent_metadata) - - merged = {**existing, agent_name: safe_metadata} - merged_json = json.dumps(merged) - - logger.debug(f"merged data: {merged_json} etag: {response.etag}") - try: - # using the transactional API to be able to later support the Dapr outbox pattern - self._dapr_client.execute_state_transaction( - store_name=store_name, - operations=[ - TransactionalStateOperation( - key=store_key, - data=merged_json, - etag=response.etag, - operation_type=TransactionOperationType.upsert, - ) - ], - transactional_metadata={ - "contentType": "application/json", - "partitionKey": store_key, - }, - ) - except Exception as e: - raise e - return None - except Exception as e: - logger.error(f"Error on transaction attempt: {attempt}: {e}") - # Add random jitter - import random - - delay = 1 + random.uniform(0, 1) # 1-2 seconds - logger.info( - f"Sleeping for {delay:.2f} seconds before retrying transaction..." - ) - time.sleep(delay) - raise Exception( - f"Failed to update state store key: {store_key} after 20 attempts." - ) - def pre_fill_prompt_template(self, **kwargs: Union[str, Callable[[], str]]) -> None: """ Pre-fills the prompt template with specified variables, updating input variables if applicable. @@ -689,44 +555,3 @@ def pre_fill_prompt_template(self, **kwargs: Union[str, Callable[[], str]]) -> N self.prompt_template = self.prompt_template.pre_fill_variables(**kwargs) logger.debug(f"Pre-filled prompt template with variables: {kwargs.keys()}") - - def _deserialize_state(self, raw: Union[bytes, str, dict]) -> dict: - """ - Convert Dapr's raw payload (bytes, JSON string, or already a dict) into a dict. - Raises helpful errors on failure. - """ - if isinstance(raw, dict): - return raw - - if isinstance(raw, bytes): - try: - raw = raw.decode("utf-8") - except UnicodeDecodeError as exc: - raise ValueError("State bytes are not valid UTF-8") from exc - - if isinstance(raw, str): - try: - return json.loads(raw) - except json.JSONDecodeError as exc: - raise ValueError(f"State is not valid JSON: {exc}") from exc - - raise TypeError(f"Unsupported state type {type(raw)!r}") - - def _serialize_metadata(self, metadata: Any) -> Any: - """ - Recursively convert Pydantic models (e.g., AgentTool), lists, dicts to JSON-serializable format. - Handles mixed tools: [AgentTool(...), "string", ...] → [{"name": "..."}, "string", ...] - """ - - def convert(obj: Any) -> Any: - if hasattr(obj, "model_dump"): - return obj.model_dump() - if hasattr(obj, "dict"): - return obj.dict() - if isinstance(obj, (list, tuple)): - return [convert(i) for i in obj] - if isinstance(obj, dict): - return {k: convert(v) for k, v in obj.items()} - return obj - - return convert(metadata) diff --git a/dapr_agents/agents/durableagent/agent.py b/dapr_agents/agents/durableagent/agent.py index 81d7816f..ab5ce41f 100644 --- a/dapr_agents/agents/durableagent/agent.py +++ b/dapr_agents/agents/durableagent/agent.py @@ -8,7 +8,7 @@ from pydantic import Field, model_validator from dapr_agents.agents.base import AgentBase -from dapr_agents.agents.memory_store import DurableAgentWorkflowState +from dapr_agents.agents.durableagent.state import DurableAgentWorkflowState from dapr_agents.types import ( AgentError, AssistantMessage, @@ -20,6 +20,7 @@ from dapr_agents.types.workflow import DaprWorkflowStatus from dapr_agents.workflow.agentic import AgenticWorkflow from dapr_agents.workflow.decorators import message_router, task, workflow +from dapr_agents.memory import ConversationDaprStateMemory from .schemas import ( AgentTaskResponse, @@ -27,7 +28,7 @@ InternalTriggerAction, TriggerAction, ) -from dapr_agents.agents.memory_store import ( +from .state import ( DurableAgentMessage, DurableAgentWorkflowEntry, ) @@ -47,37 +48,6 @@ class DurableAgent(AgenticWorkflow, AgentBase): and refining outputs through iterative feedback loops. """ - def get_chat_history(self, task: Optional[str] = None) -> List[Dict[str, Any]]: - """ - Retrieves the chat history as a list of dictionaries. - - Args: - task (Optional[str]): The task or query provided by the user. - - Returns: - List[Dict[str, Any]]: The chat history as dictionaries. - """ - # Get instance-specific chat history - if self.memory_store._current_state is None: - logger.warning("Agent state is None, initializing empty state") - self.memory_store._current_state = {} - - # Get messages from all instances - all_messages = [] - instance_data = self.memory_store._current_state.get("instances", {}) - logger.debug(f"Current state instances: {instance_data}") - - for instance in instance_data.values(): - messages = instance.get("messages", []) - logger.debug(f"Instance messages: {messages}") - all_messages.extend(messages) - - # Get long-term memory from workflow state (for broadcast messages and persistent context) - long_term_memory = self.memory_store._current_state.get("chat_history", []) - all_messages.extend(long_term_memory) - - return all_messages - agent_topic_name: Optional[str] = Field( default=None, description="The topic name dedicated to this specific agent, derived from the agent's name if not provided.", @@ -90,6 +60,10 @@ def get_chat_history(self, task: Optional[str] = None) -> List[Dict[str, Any]]: default=None, description="The current workflow instance ID for this agent.", ) + memory: Optional[ConversationDaprStateMemory] = Field( + default=None, + description="Persistent memory with session-based state hydration.", + ) @model_validator(mode="before") def set_agent_and_topic_name(cls, values: dict): @@ -114,31 +88,34 @@ def model_post_init(self, __context: Any) -> None: self._workflow_name = "AgenticWorkflow" # Initialize state structure if it doesn't exist - if not self.memory_store._current_state: - self.memory_store._current_state = {"instances": {}} + if not self.state: + self.state = {"instances": {}} - if not self.memory_store.name: - raise ValueError("DurableAgent must have a name for persistent storage") + if self.memory is not None: + self.memory = ConversationDaprStateMemory( + store_name=self.memory.store_name, + session_id=f"{self.name or 'agent'}_session", + ) + logger.info(f"Initialized memory with store name: {self.memory.store_name}") # Load the current workflow instance ID from state using session_id - logger.debug(f"State after loading: {self.memory_store._current_state}") - if self.memory_store._current_state and self.memory_store._current_state.get( - "instances" - ): - logger.debug( - f"Found {len(self.memory_store._current_state['instances'])} instances in state" + logger.debug(f"State after loading: {self.state}") + if self.state and self.state.get("instances"): + logger.debug(f"Found {len(self.state['instances'])} instances in state") + current_session_id = ( + self.memory.session_id + if self.memory + else f"{self.name}_default_session" ) - for instance_id, instance_data in self.memory_store._current_state[ - "instances" - ].items(): + for instance_id, instance_data in self.state["instances"].items(): stored_workflow_name = instance_data.get("workflow_name") stored_session_id = instance_data.get("session_id") logger.debug( - f"Instance {instance_id}: workflow_name={stored_workflow_name}, session_id={stored_session_id}, current_workflow_name={self._workflow_name}, current_session_id={self.memory_store.session_id}" + f"Instance {instance_id}: workflow_name={stored_workflow_name}, session_id={stored_session_id}, current_workflow_name={self._workflow_name}, current_session_id={current_session_id}" ) if ( stored_workflow_name == self._workflow_name - and stored_session_id == self.memory_store.session_id + and stored_session_id == current_session_id ): self.workflow_instance_id = instance_id logger.debug( @@ -152,25 +129,31 @@ def model_post_init(self, __context: Any) -> None: # This ensures our database reflects the actual state of resumed workflows self._sync_workflow_state_after_startup() - async def run( - self, input_data: Union[str, Dict[str, Any]], session_id: Optional[str] = None - ) -> Any: + # Register the agentic system + self._agent_metadata = { + "name": self.name, + "role": self.role, + "goal": self.goal, + "instructions": self.instructions, + "topic_name": self.agent_topic_name, + "pubsub_name": self.message_bus_name, + "orchestrator": False, + } + + self.register_agentic_system() + + async def run(self, input_data: Union[str, Dict[str, Any]]) -> Any: """ Fire up the workflow, wait for it to complete, then return the final serialized_output. Dapr automatically handles resuming any incomplete workflows when the runtime starts. Args: input_data (Union[str, Dict[str, Any]]): The input for the workflow. Can be a string (task) or a dict. - session_id (Optional[str]): Override the default session ID for this run. Returns: Any: The final output from the workflow execution. """ logger.debug(f"DurableAgent.run() called with input: {input_data}") - # Update session ID if provided - if session_id is not None: - self.memory_store.session_id = session_id - # Set up signal handlers for graceful shutdown when using run() method self.setup_signal_handlers() @@ -264,8 +247,6 @@ def tool_calling_workflow(self, ctx: DaprWorkflowContext, message: TriggerAction try: for turn in range(1, self.max_iterations + 1): - self.load_state() - if not ctx.is_replaying: logger.debug( f"Workflow turn {turn}/{self.max_iterations} (Instance ID: {ctx.instance_id})" @@ -430,7 +411,9 @@ def record_initial_entry( "workflow_instance_id": instance_id, "triggering_workflow_instance_id": triggering_workflow_instance_id, "workflow_name": self._workflow_name, - "session_id": self.memory_store.session_id, + "session_id": self.memory.session_id + if self.memory + else f"{self.name}_default_session", "start_time": start_time_str, "trace_context": trace_context, "status": DaprWorkflowStatus.RUNNING.value, @@ -438,10 +421,9 @@ def record_initial_entry( "tool_history": [], "end_time": None, } - if "instances" not in self.memory_store._current_state: - self.memory_store._current_state["instances"] = {} - self.memory_store._current_state["instances"][instance_id] = entry - self.memory_store._update_session_index(instance_id) + if "instances" not in self.state: + self.state["instances"] = {} + self.state["instances"][instance_id] = entry # Note: This is only really needed bc of the in-memory storage solutions. # With persistent storage, this is not needed as we rehydrate the conversation state from the database upon app restart. @@ -453,9 +435,9 @@ def _ensure_instance_exists( time: Optional[datetime] = None, ) -> None: """Ensure the instance entry exists in the state.""" - if instance_id not in self.memory_store._current_state.get("instances", {}): - if "instances" not in self.memory_store._current_state: - self.memory_store._current_state["instances"] = {} + if instance_id not in self.state.get("instances", {}): + if "instances" not in self.state: + self.state["instances"] = {} # Handle time parameter - it might be a datetime object or a string if time: @@ -466,14 +448,16 @@ def _ensure_instance_exists( else: start_time = datetime.now(timezone.utc).isoformat() - self.memory_store._current_state["instances"][instance_id] = { + self.state["instances"][instance_id] = { "input": input, "start_time": start_time, "source": "user_input", "workflow_instance_id": instance_id, "triggering_workflow_instance_id": triggering_workflow_instance_id, "workflow_name": self._workflow_name, - "session_id": self.memory_store.session_id, + "session_id": self.memory.session_id + if self.memory + else f"{self.name}_default_session", "messages": [], "tool_history": [], "status": DaprWorkflowStatus.RUNNING.value, @@ -491,11 +475,14 @@ def _process_user_message( if not (task and user_message_copy): return + user_msg = UserMessage(content=user_message_copy.get("content", "")) + if self.memory: + self.memory.add_message(user_msg) + msg_object = DurableAgentMessage(**user_message_copy) - inst = self.memory_store._current_state["instances"][instance_id] + inst = self.state["instances"][instance_id] inst["messages"].append(msg_object.model_dump(mode="json")) inst["last_message"] = msg_object.model_dump(mode="json") - # TODO: below uses the state mixin. When we remove the state mixin then this should be cleaned up below! self.save_state() def _call_llm(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]: @@ -522,7 +509,7 @@ def _save_assistant_message( assistant_message["name"] = self.name agent_msg = DurableAgentMessage(**assistant_message) - inst = self.memory_store._current_state["instances"][instance_id] + inst = self.state["instances"][instance_id] messages_list = inst["messages"] # Check for duplicate by message ID (idempotent for workflow replay) @@ -530,6 +517,8 @@ def _save_assistant_message( if not message_exists: messages_list.append(agent_msg.model_dump(mode="json")) inst["last_message"] = agent_msg.model_dump(mode="json") + if self.memory: + self.memory.add_message(AssistantMessage(**assistant_message)) self.save_state() def _print_llm_interaction_messages( @@ -643,8 +632,7 @@ def _append_tool_message_to_instance( agent_msg: The DurableAgentMessage object tool_history_entry: The ToolExecutionRecord object """ - wf_instance = self.memory_store._current_state["instances"][instance_id] - logger.debug(f"Before appending - instance state: {wf_instance}") + wf_instance = self.state["instances"][instance_id] # Check if message already exists (idempotent operation for workflow replay) wf_messages = wf_instance["messages"] @@ -652,9 +640,7 @@ def _append_tool_message_to_instance( # Check for duplicate by message ID (idempotent for workflow replay) message_exists = any(msg.get("id") == agent_msg.id for msg in wf_messages) if not message_exists: - msg_dump = agent_msg.model_dump(mode="json") - logger.debug(f"Adding new message to instance: {msg_dump}") - wf_messages.append(msg_dump) + wf_messages.append(agent_msg.model_dump(mode="json")) # Check for duplicate tool history entry by tool_call_id tool_history = wf_instance["tool_history"] @@ -664,11 +650,24 @@ def _append_tool_message_to_instance( for th in tool_history ) if not tool_exists: - history_dump = tool_history_entry.model_dump(mode="json") - logger.debug(f"Adding new tool history entry: {history_dump}") - tool_history.append(history_dump) + tool_history.append(tool_history_entry.model_dump(mode="json")) - logger.debug(f"After appending - instance state: {wf_instance}") + def _update_agent_memory_and_history( + self, tool_message: ToolMessage, tool_history_entry: ToolExecutionRecord + ) -> None: + """ + Update agent's memory and tool history. + + Args: + tool_message: The ToolMessage object + tool_history_entry: The ToolExecutionRecord object + """ + # Update tool history and memory of agent (only if new) + # Note: Memory updates are handled at workflow level to avoid replay issues + self.tool_history.append(tool_history_entry) + # Add the tool message to the agent's memory if enabled + if self.memory: + self.memory.add_message(tool_message) def _get_last_message_from_state( self, instance_id: str @@ -682,9 +681,7 @@ def _get_last_message_from_state( Returns: The last message dict or None if not found """ - instance_data = self.memory_store._current_state.get("instances", {}).get( - instance_id - ) + instance_data = self.state.get("instances", {}).get(instance_id) if instance_data is not None: return instance_data.get("last_message") return None @@ -738,40 +735,23 @@ async def run_tool( "execution_result": str(result) if result is not None else "", } - # Atomically persist both the assistant's tool call and the tool result + # Atomically persist the tool execution result # Get existing input or use placeholder existing_input = ( - self.memory_store._current_state["instances"][instance_id]["input"] - if instance_id in self.memory_store._current_state.get("instances", {}) + self.state["instances"][instance_id]["input"] + if instance_id in self.state.get("instances", {}) else "Tool execution" ) self._ensure_instance_exists(instance_id, existing_input, time=time) - - # Save both the assistant's tool call and the tool result atomically - instance = self.memory_store._current_state["instances"][instance_id] - instance["messages"].append( - { - "role": "tool", - "content": str(result) if result is not None else "", - "name": fn_name, - "tool_call_id": tool_call["id"], - } + tool_msg, agent_msg, tool_history_entry = self._create_tool_message_objects( + tool_result ) - instance["tool_history"].append(tool_result) - tool_record = ToolExecutionRecord(**tool_result) - self.tool_history.append(tool_record) - - self.save_state() - - # Print for visibility - self.text_formatter.print_message( - { - "role": "tool", - "content": str(result) if result is not None else "", - "name": fn_name, - "tool_call_id": tool_call["id"], - } + self._append_tool_message_to_instance( + instance_id, agent_msg, tool_history_entry ) + self._update_agent_memory_and_history(tool_msg, tool_history_entry) + self.save_state() + self.text_formatter.print_message(tool_msg) return tool_result @@ -828,14 +808,14 @@ def finalize_workflow( """ # Ensure the instance entry exists existing_input = ( - self.memory_store._current_state["instances"][instance_id]["input"] - if instance_id in self.memory_store._current_state.get("instances", {}) + self.state["instances"][instance_id]["input"] + if instance_id in self.state.get("instances", {}) else "Workflow completion" ) self._ensure_instance_exists( instance_id, existing_input, triggering_workflow_instance_id, time ) - instance = self.memory_store._current_state["instances"][instance_id] + instance = self.state["instances"][instance_id] instance["output"] = final_output # Convert time to string for JSON serialization if time: @@ -887,15 +867,17 @@ async def process_broadcast_message(self, message: BroadcastMessage): logger.debug( f"{self.name} processing broadcast message from '{source}'. Content: {message_content}" ) + # Store the message in local memory if enabled + if self.memory: + self.memory.add_message(message) + # Define DurableAgentMessage object for state persistence agent_msg = DurableAgentMessage(**message.model_dump()) # Persist to global chat history - if "chat_history" not in self.memory_store._current_state: - self.memory_store._current_state["chat_history"] = [] - self.memory_store._current_state["chat_history"].append( - agent_msg.model_dump(mode="json") - ) + if "chat_history" not in self.state: + self.state["chat_history"] = [] + self.state["chat_history"].append(agent_msg.model_dump(mode="json")) # Save the state after processing the broadcast message self.save_state() @@ -944,22 +926,33 @@ def _construct_messages_with_instance_history( ) # Get instance-specific chat history - if self.memory_store._current_state is None: + if self.state is None: logger.warning( f"Agent state is None for instance {instance_id}, initializing empty state" ) - self.memory_store._current_state = {} + self.state = {} - instance_data = self.memory_store._current_state.get("instances", {}).get( - instance_id - ) + instance_data = self.state.get("instances", {}).get(instance_id) if instance_data is not None: instance_messages = instance_data.get("messages", []) else: instance_messages = [] + # Get messages from persistent memory (session-based, cross-workflow) if memory is enabled + persistent_memory_messages = [] + if self.memory: + try: + persistent_memory_messages = self.memory.get_messages() + logger.info( + f"Retrieved {len(persistent_memory_messages)} messages for session {self.memory.session_id}" + ) + except Exception as e: + logger.warning(f"Failed to retrieve persistent memory: {e}") + else: + logger.debug("Memory is disabled, skipping persistent memory retrieval") + # Get long-term memory from workflow state (for broadcast messages and persistent context) - long_term_memory_data = self.memory_store._current_state.get("chat_history", []) + long_term_memory_data = self.state.get("chat_history", []) long_term_memory_messages = [] for msg in long_term_memory_data: if isinstance(msg, dict): @@ -970,46 +963,29 @@ def _construct_messages_with_instance_history( # Build chat history with proper context and order chat_history = [] - # Add long-term memory messages - for msg in long_term_memory_messages: + # First add persistent memory and long-term memory as user messages for context + # This ensures we have cross-workflow context but doesn't interfere with tool state order + for msg in persistent_memory_messages + long_term_memory_messages: msg_dict = msg.model_dump() if hasattr(msg, "model_dump") else dict(msg) - if msg_dict not in chat_history: - chat_history.append(msg_dict) - - # Add instance messages in chronological order, preserving proper tool call/response pairing - # Track which assistant messages have been added to avoid duplicates - added_assistant_ids = set() + if msg_dict in chat_history: + continue + # TODO: We need to properly design session-based memory. + # Convert tool-related messages to user messages to avoid conversation order issues + if msg_dict.get("role") in ["tool", "assistant"] and ( + msg_dict.get("tool_calls") or msg_dict.get("tool_call_id") + ): + msg_dict = { + "role": "user", + "content": f"[Previous {msg_dict['role']} message: {msg_dict.get('content', '')}]", + } + chat_history.append(msg_dict) + # Then add instance messages in their original form to maintain tool state for msg in instance_messages: msg_dict = msg.model_dump() if hasattr(msg, "model_dump") else dict(msg) - - # For assistant messages with tool calls, add them only once along with all their tool responses - if msg_dict.get("role") == "assistant" and msg_dict.get("tool_calls"): - msg_id = msg_dict.get("id") - if msg_id not in added_assistant_ids: - # Add the assistant message - chat_history.append(msg_dict) - added_assistant_ids.add(msg_id) - - # Add all tool responses for this assistant message's tool calls - tool_call_ids = { - tc.get("id") for tc in msg_dict.get("tool_calls", []) - } - for tool_msg in instance_messages: - tool_dict = ( - tool_msg.model_dump() - if hasattr(tool_msg, "model_dump") - else dict(tool_msg) - ) - if ( - tool_dict.get("role") == "tool" - and tool_dict.get("tool_call_id") in tool_call_ids - ): - chat_history.append(tool_dict) - # For other messages (user, tool, assistant without tool_calls), add them if not already added - elif msg_dict.get("role") != "tool": - # Only add if it's not a tool message (tools are added with their assistant messages) - chat_history.append(msg_dict) + if msg_dict in chat_history: + continue + chat_history.append(msg_dict) # Add additional context memory last (for broadcast-triggered workflows) chat_history.extend(additional_context_messages) diff --git a/dapr_agents/agents/durableagent/state.py b/dapr_agents/agents/durableagent/state.py new file mode 100644 index 00000000..242dca22 --- /dev/null +++ b/dapr_agents/agents/durableagent/state.py @@ -0,0 +1,80 @@ +from pydantic import BaseModel, Field +from typing import List, Optional, Dict, Any +from dapr_agents.types import MessageContent, ToolExecutionRecord +from dapr_agents.types.workflow import DaprWorkflowStatus +from datetime import datetime +import uuid + + +class DurableAgentMessage(MessageContent): + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Unique identifier for the message", + ) + timestamp: datetime = Field( + default_factory=datetime.now, + description="Timestamp when the message was created", + ) + + +class DurableAgentWorkflowEntry(BaseModel): + """Represents a workflow and its associated data, including metadata on the source of the task request.""" + + input: str = Field( + ..., description="The input or description of the Workflow to be performed" + ) + output: Optional[str] = Field( + default=None, description="The output or result of the Workflow, if completed" + ) + start_time: datetime = Field( + default_factory=datetime.now, + description="Timestamp when the workflow was started", + ) + end_time: Optional[datetime] = Field( + default=None, + description="Timestamp when the workflow was completed or failed", + ) + messages: List[DurableAgentMessage] = Field( + default_factory=list, + description="Messages exchanged during the workflow (user, assistant, or tool messages).", + ) + last_message: Optional[DurableAgentMessage] = Field( + default=None, description="Last processed message in the workflow" + ) + tool_history: List[ToolExecutionRecord] = Field( + default_factory=list, description="Tool message exchanged during the workflow" + ) + source: Optional[str] = Field(None, description="Entity that initiated the task.") + workflow_instance_id: Optional[str] = Field( + default=None, + description="The agent's own workflow instance ID.", + ) + triggering_workflow_instance_id: Optional[str] = Field( + default=None, + description="The workflow instance ID of the entity that triggered this agent (for multi-agent communication).", + ) + workflow_name: Optional[str] = Field( + default=None, + description="The name of the workflow.", + ) + trace_context: Optional[Dict[str, Any]] = Field( + default=None, + description="OpenTelemetry trace context for workflow resumption.", + ) + status: str = Field( + default=DaprWorkflowStatus.RUNNING.value, + description="Current status of the workflow.", + ) + + +class DurableAgentWorkflowState(BaseModel): + """Represents the state of multiple Agent workflows.""" + + instances: Dict[str, DurableAgentWorkflowEntry] = Field( + default_factory=dict, + description="Workflow entries indexed by their instance_id.", + ) + chat_history: List[DurableAgentMessage] = Field( + default_factory=list, + description="Chat history of messages exchanged during the workflow.", + ) diff --git a/dapr_agents/agents/memory_store.py b/dapr_agents/agents/memory_store.py deleted file mode 100644 index 11e413d8..00000000 --- a/dapr_agents/agents/memory_store.py +++ /dev/null @@ -1,359 +0,0 @@ -import json -from pydantic import Field, BaseModel, PrivateAttr -from typing import Optional, Type, Dict, Any, List, Union -from datetime import datetime -import uuid -from dapr_agents.types import MessageContent, ToolExecutionRecord -from dapr_agents.types.workflow import DaprWorkflowStatus -import logging -from dapr_agents.types import BaseMessage -from dapr.clients import DaprClient - -logger = logging.getLogger(__name__) - - -class MemoryStore(BaseModel): - """ - Unified storage for both Agent and DurableAgent. - - For regular Agent: - - If `name` is None: Pure in-memory operation (no persistence, no registration) - - If `name` is provided: - - Conversation history: Persistent in Dapr state store - - For DurableAgent: - - Requires `name` to be provided - - Conversation history: Persistent in Dapr state store (via workflow instances) - - Workflow state: Full workflow instance tracking with sessions - """ - - name: Optional[str] = Field( - default=None, - description=( - "Dapr state store name. " - "For Agent: If set, stores conversation and registers agent. If None, pure in-memory. " - "For DurableAgent: Required. Stores workflow state, conversation, and registers agent." - ), - ) - session_id: Optional[str] = Field( - default=None, - description=( - "Session ID to group related conversations and workflow state. " - "If not provided, a unique ID will be generated. " - "Can be overridden in agent.run()." - ), - ) - local_directory: Optional[str] = Field( - default=None, - description=( - "Directory path where state files will be saved locally. " - "If not set, no local saving occurs. " - "Can be absolute or relative to workspace root. " - "Files will be named '{agent_name}_state.json'." - ), - ) - # TODO: remove this when we remove state mixin class and just use db state... - _current_state: dict = PrivateAttr(default_factory=dict) - _agent_name: str = PrivateAttr(default=None) # Set by AgenticWorkflow - _key: str = PrivateAttr(default="workflow_state") - _in_memory_messages: List[Dict[str, Any]] = PrivateAttr(default_factory=list) - _dapr_client: DaprClient = PrivateAttr(default=None) - - def model_post_init(self, __context: Any) -> None: - if self.name is None: - if self._dapr_client is not None: - logger.warning( - "DaprClient initialized but name is None. It will be ignored." - ) - self._dapr_client = None - else: - self._dapr_client = DaprClient() - - def _set_key(self, agent_name: str) -> None: - """Internal method to set the agent name and initialize storage.""" - self._agent_name = agent_name - self._key = f"{agent_name}_workflow_state" - - def _get_instance_key(self, instance_id: str) -> str: - """Get the state store key for a workflow instance.""" - return f"{self._agent_name}_workflow_{instance_id}" - - def _get_session_key(self, session_id: str) -> str: - """Get the state store key for a specific session.""" - return f"{self._agent_name}_session_{session_id}" - - def _get_sessions_index_key(self) -> str: - """Get the state store key for the sessions index.""" - return f"{self._agent_name}_sessions" - - def _get_session_id(self) -> str: - """Get or generate a session ID.""" - if not self.session_id: - self.session_id = str(uuid.uuid4()) - return self.session_id - - def _update_session_index(self, instance_id: str) -> None: - """Update session index with workflow instance.""" - if not self.name: - logger.debug("In-memory mode: skipping session update") - return - - session_id = self._get_session_id() - session_key = self._get_session_key(session_id) - sessions_index_key = self._get_sessions_index_key() - - response = self._dapr_client.get_state(self.name, session_key) - - session_data = {} - is_new_session = not bool(response.data) - - if response.data: - # Safely decode and parse - raw = response.data - if isinstance(raw, (bytes, bytearray)): - try: - raw = raw.decode("utf-8") - except UnicodeDecodeError: - logger.error(f"Failed to decode session data for '{session_key}'") - raw = "" - if isinstance(raw, str) and raw.strip(): - try: - session_data = json.loads(raw) - if not isinstance(session_data, dict): - logger.warning( - f"Session data not a dict, resetting: {type(session_data)}" - ) - session_data = {} - except json.JSONDecodeError as e: - logger.error(f"Invalid session JSON for '{session_key}': {e}") - session_data = {} - else: - session_data = {} - - if is_new_session: - session_data = { - "session_id": session_id, - "workflow_instances": [], - "metadata": { - "agent_name": self._agent_name, - "created_at": datetime.now().isoformat(), - }, - "last_active": datetime.now().isoformat(), - } - logger.debug(f"Created new session '{session_id}'") - - instances = session_data.get("workflow_instances", []) - if instance_id not in instances: - instances.append(instance_id) - session_data["workflow_instances"] = instances - session_data["last_active"] = datetime.now().isoformat() - logger.debug(f"Added instance '{instance_id}' to session '{session_id}'") - - # === 4. Save session === - self._save_state_with_metadata(session_key, session_data) - - # update sessions index - only on first instance - if is_new_session: - index_resp = self._dapr_client.get_state(self.name, sessions_index_key) - index_data = {"sessions": [], "last_updated": datetime.now().isoformat()} - - if index_resp.data: - raw = index_resp.data - if isinstance(raw, (bytes, bytearray)): - raw = raw.decode("utf-8") - if raw.strip(): - try: - parsed = json.loads(raw) - if isinstance(parsed, dict): - index_data["sessions"] = parsed.get("sessions", []) - index_data["last_updated"] = parsed.get( - "last_updated", index_data["last_updated"] - ) - except json.JSONDecodeError: - logger.warning("Corrupted sessions index, resetting") - - if session_id not in index_data["sessions"]: - index_data["sessions"].append(session_id) - index_data["last_updated"] = datetime.now().isoformat() - self._save_state_with_metadata(sessions_index_key, index_data) - logger.debug(f"Registered session '{session_id}' in index") - - # TODO: in future remove this in favor of just using client.save_state when we use objects and not dictionaries in storage. - def _save_state_with_metadata(self, key: str, data: Any) -> None: - """Save state with content type metadata.""" - # Serialize data to JSON string if it's not already - if isinstance(data, dict): - data_to_save = json.dumps(data) - elif isinstance(data, str): - data_to_save = data - else: - data_to_save = json.dumps(data) - - self._dapr_client.save_state( - self.name, - key, - data_to_save, - state_metadata={"contentType": "application/json"}, - ) - - def is_persistent(self) -> bool: - """Check if storage is persistent (has a state store name) or in-memory.""" - return self.name is not None - - def add_message(self, message: Union[Dict[str, Any], "BaseMessage"]) -> None: - """ - Add a single message to storage. - - If name is None: Uses in-memory list - - If name is set: Stores in Dapr state store - - Args: - message (Union[Dict[str, Any], BaseMessage]): The message to add - """ - msg_dict = message.model_dump() if hasattr(message, "model_dump") else message - - if self.is_persistent(): - # Save to state store - messages = self.get_messages() - messages.append(msg_dict) - self._save_messages_to_store(messages) - else: - # In-memory mode - self._in_memory_messages.append(msg_dict) - - def add_messages(self, messages: List[Dict[str, Any]]) -> None: - """ - Add multiple messages to storage. - - If name is None: Uses in-memory list - - If name is set: Stores in Dapr state store - - Args: - messages (List[Dict[str, Any]]): The messages to add - """ - if self.is_persistent(): - # Save to state store - current_messages = self.get_messages() - current_messages.extend(messages) - self._save_messages_to_store(current_messages) - else: - # In-memory mode - self._in_memory_messages.extend(messages) - - def get_messages(self) -> List[Dict[str, Any]]: - """ - Get all messages from storage. - - If name is None: Returns in-memory list - - If name is set: Loads from Dapr state store - - Returns: - List[Dict[str, Any]]: All stored messages - """ - if self.is_persistent(): - # Load from state store - return self._load_messages_from_store() - else: - # In-memory mode - return self._in_memory_messages.copy() - - def _get_messages_key(self) -> str: - """Get the state store key for conversation messages.""" - session_id = self._get_session_id() - return f"{self._agent_name}_messages_{session_id}" - - def _save_messages_to_store(self, messages: List[Dict[str, Any]]) -> None: - """Save messages to the Dapr state store.""" - key = self._get_messages_key() - data = json.dumps({"messages": messages}) - - self._dapr_client.save_state( - store_name=self.name, - key=key, - value=data, - state_metadata={"contentType": "application/json"}, - ) - - def _load_messages_from_store(self) -> List[Dict[str, Any]]: - """Load messages from the Dapr state store.""" - key = self._get_messages_key() - response = self._dapr_client.get_state(store_name=self.name, key=key) - if response.data: - data = ( - json.loads(response.data) - if isinstance(response.data, (str, bytes)) - else response.data - ) - return data.get("messages", []) - return [] - - -class DurableAgentMessage(MessageContent): - id: str = Field( - default_factory=lambda: str(uuid.uuid4()), - description="Unique identifier for the message", - ) - timestamp: datetime = Field( - default_factory=datetime.now, - description="Timestamp when the message was created", - ) - - -class DurableAgentWorkflowEntry(BaseModel): - """Represents a workflow and its associated data, including metadata on the source of the task request.""" - - input: str = Field( - ..., description="The input or description of the Workflow to be performed" - ) - output: Optional[str] = Field( - default=None, description="The output or result of the Workflow, if completed" - ) - start_time: datetime = Field( - default_factory=datetime.now, - description="Timestamp when the workflow was started", - ) - end_time: Optional[datetime] = Field( - default=None, - description="Timestamp when the workflow was completed or failed", - ) - messages: List[DurableAgentMessage] = Field( - default_factory=list, - description="Messages exchanged during the workflow (user, assistant, or tool messages).", - ) - last_message: Optional[DurableAgentMessage] = Field( - default=None, description="Last processed message in the workflow" - ) - tool_history: List[ToolExecutionRecord] = Field( - default_factory=list, description="Tool message exchanged during the workflow" - ) - source: Optional[str] = Field(None, description="Entity that initiated the task.") - workflow_instance_id: Optional[str] = Field( - default=None, - description="The agent's own workflow instance ID.", - ) - triggering_workflow_instance_id: Optional[str] = Field( - default=None, - description="The workflow instance ID of the entity that triggered this agent (for multi-agent communication).", - ) - workflow_name: Optional[str] = Field( - default=None, - description="The name of the workflow.", - ) - trace_context: Optional[Dict[str, Any]] = Field( - default=None, - description="OpenTelemetry trace context for workflow resumption.", - ) - status: str = Field( - default=DaprWorkflowStatus.RUNNING.value, - description="Current status of the workflow.", - ) - - -class DurableAgentWorkflowState(BaseModel): - """Represents the state of multiple Agent workflows.""" - - instances: Dict[str, DurableAgentWorkflowEntry] = Field( - default_factory=dict, - description="Workflow entries indexed by their instance_id.", - ) - chat_history: List[DurableAgentMessage] = Field( - default_factory=list, - description="Chat history of messages exchanged during the workflow.", - ) diff --git a/dapr_agents/memory/__init__.py b/dapr_agents/memory/__init__.py index 0596d93a..f76795ac 100644 --- a/dapr_agents/memory/__init__.py +++ b/dapr_agents/memory/__init__.py @@ -1,4 +1,5 @@ from .base import MemoryBase +from .daprstatestore import ConversationDaprStateMemory from .liststore import ConversationListMemory from .vectorstore import ConversationVectorMemory @@ -6,4 +7,5 @@ "MemoryBase", "ConversationListMemory", "ConversationVectorMemory", + "ConversationDaprStateMemory", ] diff --git a/dapr_agents/memory/base.py b/dapr_agents/memory/base.py index 9d4d531a..c49e09e0 100644 --- a/dapr_agents/memory/base.py +++ b/dapr_agents/memory/base.py @@ -66,6 +66,16 @@ def get_messages(self) -> List[Dict[str, Any]]: """ pass + @abstractmethod + def reset_memory(self): + """ + Clears all messages from the memory storage. + + Note: + This method must be implemented by subclasses. + """ + pass + @staticmethod def _convert_to_dict(message: Union[Dict, BaseMessage]) -> Dict: """ diff --git a/dapr_agents/memory/daprstatestore.py b/dapr_agents/memory/daprstatestore.py new file mode 100644 index 00000000..f2fa3734 --- /dev/null +++ b/dapr_agents/memory/daprstatestore.py @@ -0,0 +1,184 @@ +import json +import logging +import uuid +from datetime import datetime +from typing import Any, Dict, List, Optional, Union + +from pydantic import Field, model_validator + +from dapr_agents.memory import MemoryBase +from dapr_agents.storage.daprstores.statestore import DaprStateStore +from dapr_agents.types import BaseMessage + +logger = logging.getLogger(__name__) + + +def generate_numeric_session_id() -> int: + """ + Generates a random numeric session ID by extracting digits from a UUID. + + Returns: + int: A numeric session ID. + """ + return int("".join(filter(str.isdigit, str(uuid.uuid4())))) + + +class ConversationDaprStateMemory(MemoryBase): + """ + Manages conversation memory stored in a Dapr state store. Each message in the conversation is saved + individually with a unique key and includes a session ID and timestamp for querying and retrieval. + """ + + store_name: str = Field( + default="statestore", description="The name of the Dapr state store." + ) + session_id: Optional[Union[str, int]] = Field( + default=None, description="Unique identifier for the conversation session." + ) + + dapr_store: Optional[DaprStateStore] = Field( + default=None, init=False, description="Dapr State Store." + ) + + @model_validator(mode="before") + def set_session_id(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """ + Sets a numeric session ID if none is provided. + + Args: + values (Dict[str, Any]): The dictionary of attribute values before initialization. + + Returns: + Dict[str, Any]: Updated values including the generated session ID if not provided. + """ + if not values.get("session_id"): + values["session_id"] = generate_numeric_session_id() + return values + + def model_post_init(self, __context: Any) -> None: + """ + Initializes the Dapr state store after validation. + """ + self.dapr_store = DaprStateStore(store_name=self.store_name) + logger.info( + f"ConversationDaprStateMemory initialized with session ID: {self.session_id}" + ) + super().model_post_init(__context) + + def _get_message_key(self, message_id: str) -> str: + """ + Generates a unique key for each message using session_id and message_id. + + Args: + message_id (str): A unique identifier for the message. + + Returns: + str: A composite key for storing individual messages. + """ + return f"{self.session_id}:{message_id}" + + def add_message(self, message: Union[Dict[str, Any], BaseMessage]) -> None: + """ + Adds a single message to the memory and saves it to the Dapr state store. + + Args: + message (Union[Dict[str, Any], BaseMessage]): The message to add to the memory. + """ + message = self._convert_to_dict(message) + message_id = str(uuid.uuid4()) + message_key = self._get_message_key(message_id) + message.update( + { + "createdAt": datetime.now().isoformat() + "Z", + } + ) + existing = self.get_messages() + existing.append(message) + logger.debug( + f"Adding message {message} with key {message_key} to session {self.session_id}" + ) + self.dapr_store.save_state( + self.session_id, json.dumps(existing), {"contentType": "application/json"} + ) + + def add_messages(self, messages: List[Union[Dict[str, Any], BaseMessage]]) -> None: + """ + Adds multiple messages to the memory and saves each one individually to the Dapr state store. + + Args: + messages (List[Union[Dict[str, Any], BaseMessage]]): A list of messages to add to the memory. + """ + logger.info(f"Adding {len(messages)} messages to session {self.session_id}") + for message in messages: + self.add_message(message) + + def add_interaction( + self, + user_message: Union[Dict[str, Any], BaseMessage], + assistant_message: Union[Dict[str, Any], BaseMessage], + ) -> None: + """ + Adds a user-assistant interaction to the memory storage and saves it to the state store. + + Args: + user_message (Union[Dict[str, Any], BaseMessage]): The user message. + assistant_message (Union[Dict[str, Any], BaseMessage]): The assistant message. + """ + self.add_messages([user_message, assistant_message]) + + def _decode_message(self, message_data: Union[bytes, str]) -> Dict[str, Any]: + """ + Decodes the message data if it's in bytes, otherwise parses it as a JSON string. + + Args: + message_data (Union[bytes, str]): The message data to decode. + + Returns: + Dict[str, Any]: The decoded message as a dictionary. + """ + if isinstance(message_data, bytes): + message_data = message_data.decode("utf-8") + return json.loads(message_data) + + def get_messages(self, limit: int = 100) -> List[Dict[str, Any]]: + """ + Retrieves messages stored in the state store for the current session_id, with an optional limit. + + Args: + limit (int, optional): The maximum number of messages to retrieve. Defaults to 100. + + Returns: + List[Dict[str, Any]]: A list of message dicts with all fields. + """ + response = self.query_messages(session_id=self.session_id) + if response and hasattr(response, "data") and response.data: + raw_messages = json.loads(response.data) + if raw_messages: + messages = raw_messages[:limit] + logger.info( + f"Retrieved {len(messages)} messages for session {self.session_id}" + ) + return messages + return [] + + def query_messages(self, session_id: str) -> Any: + """ + Queries messages from the state store for the given session_id. + + Args: + session_id (str): The session ID to query messages for. + + Returns: + Any: The response object from the Dapr state store, typically with a 'data' attribute containing the messages as JSON. + """ + logger.debug(f"Executing query for session {self.session_id}") + states_metadata = {"contentType": "application/json"} + response = self.dapr_store.get_state(session_id, state_metadata=states_metadata) + return response + + def reset_memory(self) -> None: + """ + Clears all messages stored in the memory and resets the state store for the current session. + """ + self.dapr_store.delete_state(self.session_id) + logger.info(f"Memory reset for session {self.session_id} completed.") diff --git a/dapr_agents/memory/liststore.py b/dapr_agents/memory/liststore.py index 654a4b03..0c93ad75 100644 --- a/dapr_agents/memory/liststore.py +++ b/dapr_agents/memory/liststore.py @@ -55,3 +55,7 @@ def get_messages(self) -> List[Dict[str, Any]]: List[Dict[str, Any]]: A list containing copies of all stored messages as dictionaries. """ return self.messages.copy() + + def reset_memory(self): + """Clears all messages stored in the memory, resetting the memory to an empty state.""" + self.messages.clear() diff --git a/dapr_agents/observability/wrappers/workflow.py b/dapr_agents/observability/wrappers/workflow.py index 3fcf4ef3..8e331c86 100644 --- a/dapr_agents/observability/wrappers/workflow.py +++ b/dapr_agents/observability/wrappers/workflow.py @@ -403,7 +403,7 @@ def _store_context_in_workflow_state( """ try: # Import here to avoid circular imports - from dapr_agents.agents.memory_store import DurableAgentWorkflowEntry + from dapr_agents.agents.durableagent.state import DurableAgentWorkflowEntry # This is a simplified approach - in practice, you'd need to access # the actual workflow state through the Dapr Workflow runtime diff --git a/dapr_agents/storage/daprstores/__init__.py b/dapr_agents/storage/daprstores/__init__.py new file mode 100644 index 00000000..b1f89a29 --- /dev/null +++ b/dapr_agents/storage/daprstores/__init__.py @@ -0,0 +1,4 @@ +from .base import DaprStoreBase +from .statestore import DaprStateStore + +__all__ = ["DaprStoreBase", "DaprStateStore"] diff --git a/dapr_agents/storage/daprstores/base.py b/dapr_agents/storage/daprstores/base.py new file mode 100644 index 00000000..d06e59d4 --- /dev/null +++ b/dapr_agents/storage/daprstores/base.py @@ -0,0 +1,24 @@ +from dapr.clients import DaprClient +from pydantic import BaseModel, Field, ConfigDict +from typing import Optional, Any + + +class DaprStoreBase(BaseModel): + """ + Pydantic-based Dapr store base model with configuration options for store name, address, host, and port. + """ + + store_name: str = Field(..., description="The name of the Dapr store.") + client: Optional[DaprClient] = Field( + default=None, init=False, description="Dapr client for store operations." + ) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def model_post_init(self, __context: Any) -> None: + """ + Post-initialization to set Dapr settings based on provided or environment values for host and port. + """ + + # Complete post-initialization + super().model_post_init(__context) diff --git a/dapr_agents/storage/daprstores/secretstore.py b/dapr_agents/storage/daprstores/secretstore.py new file mode 100644 index 00000000..d9b12cfc --- /dev/null +++ b/dapr_agents/storage/daprstores/secretstore.py @@ -0,0 +1,39 @@ +from dapr_agents.storage.daprstores.base import DaprStoreBase +from typing import Dict, Optional + + +class DaprSecretStore(DaprStoreBase): + def get_secret( + self, key: str, secret_metadata: Optional[Dict[str, str]] = {} + ) -> Optional[Dict[str, str]]: + """ + Retrieves a secret from the secret store using the provided key. + + Args: + key (str): The key for the secret. + secret_metadata (Dict[str, str], optional): Metadata for the secret request. + + Returns: + Optional[Dict[str, str]]: The secret stored in the secret store, or None if not found. + """ + response = self.client.get_secret( + store_name=self.store_name, key=key, secret_metadata=secret_metadata + ) + return response.secret + + def get_bulk_secret( + self, secret_metadata: Optional[Dict[str, str]] = {} + ) -> Dict[str, Dict[str, str]]: + """ + Retrieves all granted secrets from the secret store. + + Args: + secret_metadata (Dict[str, str], optional): Metadata for the secret request. + + Returns: + Dict[str, Dict[str, str]]: A dictionary of secrets. + """ + response = self.client.get_bulk_secret( + store_name=self.store_name, secret_metadata=secret_metadata + ) + return response.secrets diff --git a/dapr_agents/storage/daprstores/statestore.py b/dapr_agents/storage/daprstores/statestore.py new file mode 100644 index 00000000..d2d41e68 --- /dev/null +++ b/dapr_agents/storage/daprstores/statestore.py @@ -0,0 +1,151 @@ +from dapr.clients.grpc._response import ( + BulkStatesResponse, + BulkStateItem, + StateResponse, + QueryResponse, +) +from dapr.clients import DaprClient +from dapr.clients.grpc._state import StateItem +from dapr_agents.storage.daprstores.base import DaprStoreBase +from typing import Optional, Union, Dict, List, Tuple + + +class DaprStateStore(DaprStoreBase): + def get_state( + self, + key: str, + state_metadata: Optional[Dict[str, str]] = dict(), + ) -> StateResponse: + """ + Retrieves a value from the state store using the provided key. + + Args: + key (str): The key for the state store item. + state_metadata (Dict[str, str], optional): Dapr metadata for state request + + Returns: + StateResponse: gRPC metadata returned from callee and value obtained from the state store + """ + with DaprClient() as client: + response: StateResponse = client.get_state( + store_name=self.store_name, key=key, state_metadata=state_metadata + ) + return response + + def try_get_state( + self, key: str, state_metadata: Optional[Dict[str, str]] = dict() + ) -> Tuple[bool, Optional[dict]]: + """ + Attempts to retrieve a value from the state store using the provided key. + + Args: + key (str): The key for the state store item. + state_metadata (Dict[str, str], optional): Dapr metadata for state request. + + Returns: + Tuple[bool, Optional[dict]]: A tuple where the first element is a boolean indicating whether the state exists, + and the second element is the retrieved state data or None if not found. + """ + with DaprClient() as client: + response: StateResponse = client.get_state( + store_name=self.store_name, key=key, state_metadata=state_metadata + ) + if response and response.data: + return True, response.json() + return False, None + + def get_bulk_state( + self, + keys: List[str], + parallelism: int = 1, + states_metadata: Optional[Dict[str, str]] = None, + ) -> List[BulkStateItem]: + """ + Retrieves multiple values from the state store in bulk using a list of keys. + + Args: + keys (List[str]): The keys to retrieve in bulk. + parallelism (int, optional): Number of keys to retrieve in parallel. + states_metadata (Dict[str, str], optional): Metadata for state request. + + Returns: + List[BulkStateItem]: A list of BulkStateItem objects representing the retrieved state. + """ + states_metadata = states_metadata or {} + + with DaprClient() as client: + response: BulkStatesResponse = client.get_bulk_state( + store_name=self.store_name, + keys=keys, + parallelism=parallelism, + states_metadata=states_metadata, + ) + + if response and response.items: + return response.items + return [] + + def save_state( + self, + key: str, + value: Union[str, bytes], + state_metadata: Optional[Dict[str, str]] = dict(), + ): + """ + Saves a key-value pair in the state store. + + Args: + key (str): The key to save. + value (Union[str, bytes]): The value to save. + state_metadata (Dict[str, str], optional): Dapr metadata for state request + """ + with DaprClient() as client: + client.save_state( + store_name=self.store_name, + key=key, + value=value, + state_metadata=state_metadata, + ) + + def save_bulk_state( + self, states: List[StateItem], metadata: Optional[Dict[str, str]] = None + ) -> None: + """ + Saves multiple key-value pairs to the state store in bulk. + + Args: + states (List[StateItem]): The list of key-value pairs to save. + metadata (Dict[str, str], optional): Metadata for the save request. + """ + with DaprClient() as client: + client.save_bulk_state( + store_name=self.store_name, states=states, metadata=metadata + ) + + def delete_state(self, key: str): + """ + Deletes a key-value pair from the state store. + + Args: + key (str): The key to delete. + """ + with DaprClient() as client: + client.delete_state(store_name=self.store_name, key=key) + + def query_state( + self, query: str, states_metadata: Optional[Dict[str, str]] = None + ) -> QueryResponse: + """ + Queries the state store with a specific query. + + Args: + query (str): The query to be executed (in JSON format). + states_metadata (Dict[str, str], optional): Custom metadata for the state request. + + Returns: + QueryResponse: Contains query results and metadata. + """ + with DaprClient() as client: + client.query_state( + store_name=self.store_name, query=query, states_metadata=states_metadata + ) diff --git a/dapr_agents/workflow/agentic.py b/dapr_agents/workflow/agentic.py index c26d48b0..49b1974b 100644 --- a/dapr_agents/workflow/agentic.py +++ b/dapr_agents/workflow/agentic.py @@ -1,9 +1,9 @@ import asyncio import json import logging -import random import time -from typing import Any, Callable, Dict, Optional, Tuple, Type, List, Union +from typing import Any, Callable, Dict, Optional, Tuple, Type, List + from cloudevents.http.conversion import from_http from cloudevents.http.event import CloudEvent from dapr.clients import DaprClient @@ -18,7 +18,12 @@ from pydantic import BaseModel, Field, PrivateAttr from dapr_agents.agents.utils.text_printer import ColorTextFormatter -from dapr_agents.memory import MemoryBase, ConversationVectorMemory +from dapr_agents.memory import ( + ConversationListMemory, + ConversationVectorMemory, + MemoryBase, +) +from dapr_agents.storage.daprstores.statestore import DaprStateStore from dapr_agents.workflow.base import WorkflowApp from dapr_agents.workflow.mixins import ( MessagingMixin, @@ -26,7 +31,6 @@ ServiceMixin, StateManagementMixin, ) -from dapr_agents.agents.memory_store import MemoryStore logger = logging.getLogger(__name__) @@ -52,34 +56,52 @@ class AgenticWorkflow( default=None, description="Default topic for broadcasting messages. Set explicitly for multi-agent setups.", ) - - memory_store: MemoryStore = Field( - ..., - description="The durable storage for workflow state and agent registration.", + state_store_name: str = Field( + ..., description="Dapr state store for workflow state." ) - - registry_store: Optional[str] = Field( + state_key: str = Field( + default="workflow_state", + description="Dapr state key for workflow state storage.", + ) + state: dict = Field( + default_factory=dict, description="Current state of the workflow." + ) + state_format: Optional[Type[BaseModel]] = Field( default=None, - description="Agent registry store name for storing static agent information.", + description=( + "Optional Pydantic model used to validate the persisted workflow " + "state. If provided, state loaded from storage is coerced to this " + "schema." + ), + ) + agents_registry_store_name: str = Field( + ..., description="Dapr state store for agent metadata." + ) + agents_registry_key: str = Field( + default="agents_registry", description="Key for agents registry in state store." ) - # TODO: test this is respected by runtime. max_iterations: int = Field( default=10, description="Maximum iterations for workflows.", ge=1 ) - - # Long term memory based on an execution run, so should be in the execution config class! - memory: Optional[MemoryBase] = Field( - default=None, + memory: MemoryBase = Field( + default_factory=ConversationListMemory, description="Handles conversation history storage.", ) - + save_state_locally: bool = Field( + default=True, description="Whether to save workflow state locally." + ) + local_state_path: Optional[str] = Field( + default=None, description="Local path for saving state files." + ) client: Optional[DaprClient] = Field( default=None, init=False, description="Dapr client instance." ) # Private internal attributes (not schema/validated) + _state_store_client: Optional[DaprStateStore] = PrivateAttr(default=None) _text_formatter: ColorTextFormatter = PrivateAttr(default=ColorTextFormatter) + _agent_metadata: Optional[Dict[str, Any]] = PrivateAttr(default=None) _workflow_name: str = PrivateAttr(default=None) _dapr_client: Optional[DaprClient] = PrivateAttr(default=None) _is_running: bool = PrivateAttr(default=False) @@ -106,15 +128,9 @@ def model_post_init(self, __context: Any) -> None: """ self._dapr_client = DaprClient() self._text_formatter = ColorTextFormatter() - - # Set storage key based on agent name - self.memory_store._set_key(self.name) - - logger.info(f"State store '{self.memory_store.name}' initialized.") + self._state_store_client = DaprStateStore(store_name=self.state_store_name) + logger.info(f"State store '{self.state_store_name}' initialized.") self.initialize_state() - if self.registry_store is None: - self.registry_store = self.memory_store.name - super().model_post_init(__context) def get_chat_history(self, task: Optional[str] = None) -> List[Dict[str, Any]]: @@ -138,33 +154,12 @@ def get_chat_history(self, task: Optional[str] = None) -> List[Dict[str, Any]]: query_embeddings = self.memory.vector_store.embedding_function.embed( task ) - vector_messages = self.memory.get_messages( - query_embeddings=query_embeddings - ) - if vector_messages: - return vector_messages - - # Get messages from storage - if self.memory_store._current_state is None: - logger.debug("Agent state is None, initializing empty state") - self.memory_store._current_state = {} - - # Get messages from all instances - all_messages = [] - for instance in self.memory_store._current_state.get("instances", {}).values(): - messages = instance.get("messages", []) - all_messages.extend(messages) - - # Get long-term memory from workflow state - long_term_memory = self.memory_store._current_state.get("chat_history", []) - all_messages.extend(long_term_memory) - - # If we have vector memory but no task, also include vector memory messages - if isinstance(self.memory, ConversationVectorMemory): - vector_messages = self.memory.get_messages() - all_messages.extend(vector_messages) - - return all_messages + messages = self.memory.get_messages(query_embeddings=query_embeddings) + else: + messages = self.memory.get_messages() + else: + messages = self.memory.get_messages() + return messages @property def chat_history(self) -> List[Dict[str, Any]]: @@ -217,12 +212,15 @@ def get_agents_metadata( """ try: agents_metadata = ( - self.get_data_from_store(self.registry_store, "agent_registry") or {} + self.get_data_from_store( + self.agents_registry_store_name, self.agents_registry_key + ) + or {} ) if agents_metadata: logger.info( - f"Agents found in '{self.registry_store}' for key 'agent_registry'." + f"Agents found in '{self.agents_registry_store_name}' for key '{self.agents_registry_key}'." ) filtered = { name: metadata @@ -237,7 +235,7 @@ def get_agents_metadata( return filtered logger.info( - f"No agents found in '{self.registry_store}' for key 'agent_registry'." + f"No agents found in '{self.agents_registry_store_name}' for key '{self.agents_registry_key}'." ) return {} except Exception as e: @@ -265,80 +263,17 @@ def print_interaction( ] self._text_formatter.print_colored_text(interaction_text) - async def run_workflow_from_request(self, request: Request) -> JSONResponse: - """ - Run a workflow instance triggered by an HTTP POST request. - - Args: - request: The incoming FastAPI request. - - Returns: - JSONResponse: HTTP response with workflow instance ID or error. - """ - try: - workflow_name = request.query_params.get("name") or self._workflow_name - if not workflow_name: - return JSONResponse( - content={"error": "No workflow name specified."}, - status_code=status.HTTP_400_BAD_REQUEST, - ) - - if workflow_name not in self.workflows: - return JSONResponse( - content={ - "error": f"Unknown workflow '{workflow_name}'. Available: {list(self.workflows.keys())}" - }, - status_code=status.HTTP_400_BAD_REQUEST, - ) - - try: - event: CloudEvent = from_http( - dict(request.headers), await request.body() - ) - input_data = event.data - except Exception: - input_data = await request.json() - - logger.info(f"Starting workflow '{workflow_name}' with input: {input_data}") - instance_id = await self.run_and_monitor_workflow_async( - workflow=workflow_name, input=input_data - ) - - return JSONResponse( - content={ - "message": "Workflow initiated successfully.", - "workflow_instance_id": instance_id, - }, - status_code=status.HTTP_202_ACCEPTED, - ) - except Exception as e: - logger.error(f"Error starting workflow: {str(e)}", exc_info=True) - return JSONResponse( - content={"error": "Failed to start workflow", "details": str(e)}, - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - - # TODO: Delete this once we rm orchestrators in favor of agents as tools. def register_agent( self, store_name: str, store_key: str, agent_name: str, agent_metadata: dict ) -> None: """ Merges the existing data with the new data and updates the store. - Only works for agents with Dapr client access (AgenticWorkflow subclasses). Args: store_name (str): The name of the Dapr state store component. - store_key (str): The key to update. - agent_name (str): The name of the agent to register. - agent_metadata (dict): The metadata to register for the agent. + key (str): The key to update. + data (dict): The data to update the store with. """ - # Only proceed if agent has Dapr client - if not hasattr(self, "_dapr_client"): - logger.debug( - f"Agent '{self.name}' does not have Dapr client, skipping registration" - ) - return - # retry the entire operation up to twenty times sleeping 1-2 seconds between each # TODO: rm the custom retry logic here and use the DaprClient retry_policy instead. for attempt in range(1, 21): @@ -364,9 +299,7 @@ def register_agent( ) # raise an exception to retry the entire operation raise Exception(f"No etag found for key: {store_key}") - existing_data = ( - self._deserialize_state(response.data) if response.data else {} - ) + existing_data = json.loads(response.data) if response.data else {} if (agent_name, agent_metadata) in existing_data.items(): logger.debug(f"agent {agent_name} already registered.") return None @@ -396,6 +329,8 @@ def register_agent( except Exception as e: logger.error(f"Error on transaction attempt: {attempt}: {e}") # Add random jitter + import random + delay = 1 + random.uniform(0, 1) # 1-2 seconds logger.info( f"Sleeping for {delay:.2f} seconds before retrying transaction..." @@ -405,24 +340,73 @@ def register_agent( f"Failed to update state store key: {store_key} after 20 attempts." ) - def _deserialize_state(self, raw: Union[bytes, str, dict]) -> dict: + def register_agentic_system(self) -> None: """ - Convert Dapr's raw payload (bytes, JSON string, or already a dict) into a dict. - Raises helpful errors on failure. + Register this agent's metadata in the Dapr state store. + + Raises: + Exception: If registration fails. + """ + try: + self.register_agent( + store_name=self.agents_registry_store_name, + store_key=self.agents_registry_key, + agent_name=self.name, + agent_metadata=self._agent_metadata, + ) + except Exception as e: + logger.error(f"Failed to register metadata for agent {self.name}: {e}") + raise e + + async def run_workflow_from_request(self, request: Request) -> JSONResponse: """ - if isinstance(raw, dict): - return raw + Run a workflow instance triggered by an HTTP POST request. - if isinstance(raw, bytes): - try: - raw = raw.decode("utf-8") - except UnicodeDecodeError as exc: - raise ValueError("State bytes are not valid UTF-8") from exc + Args: + request: The incoming FastAPI request. + + Returns: + JSONResponse: HTTP response with workflow instance ID or error. + """ + try: + workflow_name = request.query_params.get("name") or self._workflow_name + if not workflow_name: + return JSONResponse( + content={"error": "No workflow name specified."}, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + if workflow_name not in self.workflows: + return JSONResponse( + content={ + "error": f"Unknown workflow '{workflow_name}'. Available: {list(self.workflows.keys())}" + }, + status_code=status.HTTP_400_BAD_REQUEST, + ) - if isinstance(raw, str): try: - return json.loads(raw) - except json.JSONDecodeError as exc: - raise ValueError(f"State is not valid JSON: {exc}") from exc + event: CloudEvent = from_http( + dict(request.headers), await request.body() + ) + input_data = event.data + except Exception: + input_data = await request.json() + + logger.info(f"Starting workflow '{workflow_name}' with input: {input_data}") + instance_id = await self.run_and_monitor_workflow_async( + workflow=workflow_name, input=input_data + ) - raise TypeError(f"Unsupported state type {type(raw)!r}") + return JSONResponse( + content={ + "message": "Workflow initiated successfully.", + "workflow_instance_id": instance_id, + }, + status_code=status.HTTP_202_ACCEPTED, + ) + except Exception as e: + logger.error(f"Error starting workflow: {str(e)}", exc_info=True) + return JSONResponse( + content={"error": "Failed to start workflow", "details": str(e)}, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) diff --git a/dapr_agents/workflow/mixins/state.py b/dapr_agents/workflow/mixins/state.py index 0ad484dd..03a78362 100644 --- a/dapr_agents/workflow/mixins/state.py +++ b/dapr_agents/workflow/mixins/state.py @@ -3,10 +3,9 @@ import os import tempfile import threading -from datetime import datetime from typing import Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError logger = logging.getLogger(__name__) @@ -18,146 +17,6 @@ class StateManagementMixin: Mixin providing workflow state initialization, validation, and persistence. """ - def _reconcile_workflow_statuses(self) -> None: - """ - Reconcile workflow statuses between our Redis state and Dapr's actual workflow state. - - This method checks Dapr's actual status and updates our state to match, - preventing stale "running" workflows from blocking new executions. - """ - from dapr.clients import DaprClient - - instances = self.memory_store._current_state["instances"] - updated_instances = [] - - for instance_id, instance_data in instances.items(): - our_status = instance_data["status"].lower() - - # Only check running instances (completed/failed instances are already finalized) - if our_status in ["running", "pending"]: - try: - # Query Dapr for the actual workflow status - with DaprClient() as client: - state = client.get_workflow( - instance_id=instance_id, - ) - - dapr_status = state.runtime_status.upper() - - # If Dapr says FAILED/TERMINATED but we say RUNNING, update our state - if dapr_status in ["FAILED", "TERMINATED", "CANCELED"]: - logger.warning( - f"Workflow {instance_id} is {dapr_status} in Dapr but 'running' in Redis. " - f"Updating Redis state to match Dapr." - ) - instance_data["status"] = dapr_status.lower() - instance_data["end_time"] = datetime.now().isoformat() - updated_instances.append(instance_id) - - # Save the updated instance back to Redis - instance_key = self.memory_store._get_instance_key( - instance_id - ) - self.memory_store._save_state_with_metadata( - instance_key, instance_data - ) - - elif dapr_status == "COMPLETED": - logger.info( - f"Workflow {instance_id} completed in Dapr. Updating Redis state." - ) - instance_data["status"] = "completed" - if not instance_data["end_time"]: - instance_data["end_time"] = datetime.now().isoformat() - updated_instances.append(instance_id) - - # Save the updated instance - instance_key = self.memory_store._get_instance_key( - instance_id - ) - self.memory_store._save_state_with_metadata( - instance_key, instance_data - ) - - except Exception as e: - logger.debug( - f"Could not query Dapr status for workflow {instance_id}: {e}. " - f"Instance may have been purged or not exist in Dapr yet." - ) - - if updated_instances: - logger.info( - f"Reconciled {len(updated_instances)} workflow status(es) with Dapr: {updated_instances}" - ) - - def _has_valid_message_sequence(self, instance_data: dict) -> bool: - """ - Validate that all assistant messages with tool_calls have corresponding tool responses. - This prevents loading instances with incomplete tool call sequences that would break LLM calls. - - Args: - instance_data: The workflow instance data to validate - - Returns: - bool: True if message sequence is valid, False otherwise - """ - messages = instance_data["messages"] - - # Collect all tool_call_ids that need responses - pending_tool_calls = set() - - for msg in messages: - msg_dict = ( - msg - if isinstance(msg, dict) - else (msg.model_dump() if hasattr(msg, "model_dump") else {}) - ) - role = msg_dict.get("role", []) - - if role == "assistant" and msg_dict.get("tool_calls", []): - # Add all tool_call_ids from this assistant message - for tool_call in msg_dict.get("tool_calls", []): - if isinstance(tool_call, dict): - pending_tool_calls.add(tool_call["id"]) - - elif role == "tool": - # Remove this tool_call_id as it has a response - tool_call_id = msg_dict.get("tool_call_id", []) - if tool_call_id in pending_tool_calls: - pending_tool_calls.remove(tool_call_id) - - # If there are still pending tool calls, the sequence is invalid - if pending_tool_calls: - logger.debug( - f"Invalid message sequence: pending tool_call_ids: {pending_tool_calls}" - ) - return False - - return True - - # TODO: Delete this once we rm orchestrators in favor of agents as tools. - @property - def state(self) -> dict: - """ - Get the current workflow state. - - Returns: - dict: The current workflow state. - """ - return self.memory_store._current_state if hasattr(self, "storage") else {} - - # TODO: Delete this once we rm orchestrators in favor of agents as tools. - @state.setter - def state(self, value: dict) -> None: - """ - Set the current workflow state. - - Args: - value (dict): The new workflow state. - """ - if hasattr(self, "storage"): - self.memory_store._current_state = value - def initialize_state(self) -> None: """ Initialize workflow state from provided value or storage. @@ -166,30 +25,52 @@ def initialize_state(self) -> None: RuntimeError: If state initialization or loading from storage fails. """ try: - if self.memory_store._current_state is None: + if self.state is None: logger.debug("No user-provided state. Attempting to load from storage.") - self.memory_store._current_state = self.load_state() + self.state = self.load_state() - if isinstance(self.memory_store._current_state, BaseModel): + if isinstance(self.state, BaseModel): logger.debug( "User provided a state as a Pydantic model. Converting to dict." ) - self.memory_store._current_state = ( - self.memory_store._current_state.model_dump() - ) + self.state = self.state.model_dump() - if not isinstance(self.memory_store._current_state, dict): + if not isinstance(self.state, dict): raise TypeError( - f"Invalid state type: {type(self.memory_store._current_state)}. Expected dict." + f"Invalid state type: {type(self.state)}. Expected dict." ) - logger.debug( - f"Workflow state initialized with {len(self.memory_store._current_state)} key(s)." - ) + logger.debug(f"Workflow state initialized with {len(self.state)} key(s).") self.save_state() except Exception as e: raise RuntimeError(f"Error initializing workflow state: {e}") from e + def validate_state(self, state_data: dict) -> dict: + """ + Validate the workflow state against ``state_format`` if provided. + + Args: + state_data: The raw state data to validate. + + Returns: + dict: The validated and structured state. + + Raises: + ValidationError: If the state data does not conform to the expected schema. + """ + try: + if not self.state_format: + logger.warning( + "No schema (state_format) provided; returning state as-is." + ) + return state_data + + logger.debug("Validating workflow state against schema.") + validated_state: BaseModel = self.state_format(**state_data) + return validated_state.model_dump() + except ValidationError as e: + raise ValidationError(f"Invalid workflow state: {e.errors()}") from e + def load_state(self) -> dict: """ Load the workflow state from the configured Dapr state store. @@ -204,105 +85,45 @@ def load_state(self) -> dict: """ try: if ( - not self._dapr_client - or not self.memory_store.name - or not self.memory_store._key + not self._state_store_client + or not self.state_store_name + or not self.state_key ): logger.error("State store is not configured. Cannot load state.") raise RuntimeError( - "State store is not configured. Please provide 'storage.name'." + "State store is not configured. Please provide 'state_store_name' and 'state_key'." ) # For durable agents, always load from database to ensure it's the source of truth - response = self._dapr_client.get_state( - self.memory_store.name, self.memory_store._key + has_state, state_data = self._state_store_client.try_get_state( + self.state_key ) - if response.data: - state_data = self._deserialize_state(response.data) - self.memory_store._current_state = state_data - else: - self.memory_store._current_state = {} - - # Load workflow instances from ALL sessions to support workflow resumption after restart - # This ensures that if the app crashes mid-workflow and restarts, all in-flight - # workflows across all sessions will be loaded and can be resumed by Dapr - # Always ensure "instances" key exists - self.memory_store._current_state.setdefault("instances", {}) - - # Get all sessions for this agent - sessions_index_key = self.memory_store._get_sessions_index_key() - response = self._dapr_client.get_state( - self.memory_store.name, sessions_index_key - ) - - if response.data: - index_data = self._deserialize_state(response.data) - session_ids = index_data.get("sessions", []) + if has_state and state_data: logger.debug( - f"Found {len(session_ids)} session(s) for agent '{self.memory_store._agent_name}'" + f"Existing state found for key '{self.state_key}'. Validating it." ) - - # Load workflow instances from each session - for session_id in session_ids: - session_key = self.memory_store._get_session_key(session_id) - response = self._dapr_client.get_state( - self.memory_store.name, session_key + if not isinstance(state_data, dict): + raise TypeError( + f"Invalid state type retrieved: {type(state_data)}. Expected dict." ) - if response.data: - session_data = self._deserialize_state(response.data) + # Set self.state to the loaded data + if self.state_format: + loaded_state = self.validate_state(state_data) + else: + loaded_state = state_data - instance_ids = session_data.get("workflow_instances", []) - logger.debug( - f"Loading {len(instance_ids)} instance(s) from session '{session_id}'" - ) + self.state = loaded_state + logger.debug(f"Set self.state to loaded data: {self.state}") - # Load each instance - for instance_id in instance_ids: - instance_key = self.memory_store._get_instance_key( - instance_id - ) - response = self._dapr_client.get_state( - self.memory_store.name, instance_key - ) - if response.data: - instance_data = self._deserialize_state(response.data) - - # Validate message sequence before loading, but ONLY for completed workflows - # Running workflows are expected to have incomplete sequences mid-execution - status = instance_data["status"].lower() - if status in ["running", "pending"]: - # Always load running/pending instances (they're allowed to be incomplete) - self.memory_store._current_state["instances"][ - instance_id - ] = instance_data - logger.debug( - f"Loaded active workflow instance {instance_id} from key '{instance_key}' (session: {session_id}, status: {status})" - ) - elif self._has_valid_message_sequence(instance_data): - # For completed/failed instances, validate message sequence - self.memory_store._current_state["instances"][ - instance_id - ] = instance_data - logger.debug( - f"Loaded completed workflow instance {instance_id} from key '{instance_key}' (session: {session_id}, status: {status})" - ) - else: - logger.warning( - f"Skipping completed instance {instance_id} due to invalid message sequence (incomplete tool calls, status: {status})" - ) - - # Reconcile workflow statuses with Dapr's actual state - self._reconcile_workflow_statuses() + return loaded_state logger.debug( - f"Set self.memory_store._current_state to loaded data: {self.memory_store._current_state}" + f"No existing state found for key '{self.state_key}'. Initializing empty state." ) - return self.memory_store._current_state + return {} except Exception as e: - logger.error( - f"Failed to load state for key '{self.memory_store._key}': {e}" - ) + logger.error(f"Failed to load state for key '{self.state_key}': {e}") raise RuntimeError(f"Error loading workflow state: {e}") from e def get_local_state_file_path(self) -> str: @@ -312,17 +133,9 @@ def get_local_state_file_path(self) -> str: Returns: str: The absolute path to the local state file. """ - if not self.memory_store.local_directory: - return os.path.join(os.getcwd(), f"{self.name}_state.json") - os.makedirs(self.memory_store.local_directory, exist_ok=True) - - # If relative path, make it absolute from workspace root - if not os.path.isabs(self.memory_store.local_directory): - abs_path = os.path.join(os.getcwd(), self.memory_store.local_directory) - else: - abs_path = self.memory_store.local_directory - - return os.path.join(abs_path, f"{self.name}_state.json") + directory = self.local_state_path or os.getcwd() + os.makedirs(directory, exist_ok=True) + return os.path.join(directory, f"{self.state_key}.json") def save_state_to_disk( self, state_data: str, filename: Optional[str] = None @@ -338,9 +151,10 @@ def save_state_to_disk( RuntimeError: If saving to disk fails. """ try: - file_path = filename or self.get_local_state_file_path() - save_directory = os.path.dirname(file_path) + save_directory = self.local_state_path or os.getcwd() os.makedirs(save_directory, exist_ok=True) + filename = filename or f"{self.name}_state.json" + file_path = os.path.join(save_directory, filename) with tempfile.NamedTemporaryFile( "w", dir=save_directory, delete=False @@ -395,102 +209,44 @@ def save_state( """ try: if ( - not self._dapr_client - or not self.memory_store.name - or not self.memory_store._key + not self._state_store_client + or not self.state_store_name + or not self.state_key ): logger.error("State store is not configured. Cannot save state.") raise RuntimeError( - "State store is not configured. Please provide 'storage.name'." + "State store is not configured. Please provide 'state_store_name' and 'state_key'." ) - self.memory_store._current_state = state or self.memory_store._current_state - if not self.memory_store._current_state: + self.state = state or self.state + if not self.state: logger.warning("Skipping state save: Empty state.") return - if isinstance(self.memory_store._current_state, BaseModel): - state_to_save = self.memory_store._current_state.model_dump_json() - elif isinstance(self.memory_store._current_state, dict): - state_to_save = json.dumps(self.memory_store._current_state) - elif isinstance(self.memory_store._current_state, str): + if isinstance(self.state, BaseModel): + state_to_save = self.state.model_dump_json() + elif isinstance(self.state, dict): + state_to_save = json.dumps(self.state) + elif isinstance(self.state, str): try: - json.loads(self.memory_store._current_state) + json.loads(self.state) except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON string provided as state: {e}") - state_to_save = self.memory_store._current_state + state_to_save = self.state else: raise TypeError( - f"Invalid state type: {type(self.memory_store._current_state)}. Expected dict, BaseModel, or JSON string." + f"Invalid state type: {type(self.state)}. Expected dict, BaseModel, or JSON string." ) - # Save each workflow instance separately - if "instances" in self.memory_store._current_state: - for instance_id, instance_data in self.memory_store._current_state[ - "instances" - ].items(): - instance_key = self.memory_store._get_instance_key(instance_id) - # Handle both dict and already-serialized string - if isinstance(instance_data, dict): - instance_json = json.dumps(instance_data) - elif isinstance(instance_data, str): - instance_json = instance_data - else: - instance_json = json.dumps(instance_data) - self._dapr_client.save_state( - self.memory_store.name, instance_key, instance_json - ) - logger.debug( - f"Saved workflow instance {instance_id} to key '{instance_key}'" - ) - - # Save other state data (like chat_history) to main key - other_state = { - k: v - for k, v in self.memory_store._current_state.items() - if k != "instances" - } - if other_state: - other_state_json = json.dumps(other_state) - self._dapr_client.save_state( - self.memory_store.name, self.memory_store._key, other_state_json - ) - logger.debug( - f"Saved non-instance state to key '{self.memory_store._key}'" - ) + self._state_store_client.save_state(self.state_key, state_to_save) + logger.debug(f"Successfully saved state for key '{self.state_key}'.") - if self.memory_store.local_directory is not None: + if self.save_state_locally: self.save_state_to_disk(state_data=state_to_save) if force_reload: - self.memory_store._current_state = self.load_state() - logger.debug( - f"State reloaded after saving for key '{self.memory_store._key}'." - ) + self.state = self.load_state() + logger.debug(f"State reloaded after saving for key '{self.state_key}'.") except Exception as e: - logger.error( - f"Failed to save state for key '{self.memory_store._key}': {e}" - ) + logger.error(f"Failed to save state for key '{self.state_key}': {e}") raise - - def _deserialize_state(self, raw: Union[bytes, str, dict]) -> dict: - """ - Convert Dapr's raw payload (bytes, JSON string, or already a dict) into a dict. - Raises helpful errors on failure. - """ - if isinstance(raw, dict): - return raw - - if isinstance(raw, bytes): - try: - raw = raw.decode("utf-8") - except UnicodeDecodeError as exc: - raise ValueError("State bytes are not valid UTF-8") from exc - - if isinstance(raw, str): - try: - return json.loads(raw) - except json.JSONDecodeError as exc: - raise ValueError(f"State is not valid JSON: {exc}") from exc - - raise TypeError(f"Unsupported state type {type(raw)!r}") diff --git a/dapr_agents/workflow/orchestrators/base.py b/dapr_agents/workflow/orchestrators/base.py index 6bd01f0c..8ec06f86 100644 --- a/dapr_agents/workflow/orchestrators/base.py +++ b/dapr_agents/workflow/orchestrators/base.py @@ -1,6 +1,7 @@ import logging from abc import ABC, abstractmethod from typing import Any, Optional + from dapr.ext.workflow import DaprWorkflowContext from pydantic import Field, model_validator @@ -32,22 +33,15 @@ def model_post_init(self, __context: Any) -> None: super().model_post_init(__context) # Prepare agent metadata - agent_metadata = { + self._agent_metadata = { "name": self.name, "topic_name": self.orchestrator_topic_name, "pubsub_name": self.message_bus_name, "orchestrator": True, } - if self.memory_store and self.memory_store.name: - if self.registry_store is None: - self.registry_store = self.memory_store.name - self.register_agent( - store_name=self.registry_store, - store_key="agent_registry", - agent_name=self.name, - agent_metadata=self._serialize_metadata(agent_metadata), - ) + # Register agent metadata + self.register_agentic_system() # Start the runtime if it's not already running self.start_runtime() @@ -66,6 +60,11 @@ def main_workflow(self, ctx: DaprWorkflowContext, message: Any) -> Any: """ pass + @abstractmethod + async def process_agent_response(self, message: Any) -> None: + """Process responses from agents.""" + pass + @abstractmethod async def broadcast_message_to_agents(self, **kwargs) -> None: """Broadcast a message to all registered agents.""" @@ -75,22 +74,3 @@ async def broadcast_message_to_agents(self, **kwargs) -> None: async def trigger_agent(self, name: str, instance_id: str, **kwargs) -> None: """Trigger a specific agent to perform an action.""" pass - - def _serialize_metadata(self, metadata: Any) -> Any: - """ - Recursively convert Pydantic models (e.g., AgentTool), lists, dicts to JSON-serializable format. - Handles mixed tools: [AgentTool(...), "string", ...] → [{"name": "..."}, "string", ...] - """ - - def convert(obj: Any) -> Any: - if hasattr(obj, "model_dump"): - return obj.model_dump() - if hasattr(obj, "dict"): - return obj.dict() - if isinstance(obj, (list, tuple)): - return [convert(i) for i in obj] - if isinstance(obj, dict): - return {k: convert(v) for k, v in obj.items()} - return obj - - return convert(metadata) diff --git a/dapr_agents/workflow/orchestrators/llm/orchestrator.py b/dapr_agents/workflow/orchestrators/llm/orchestrator.py index 2f4ad380..aad561b5 100644 --- a/dapr_agents/workflow/orchestrators/llm/orchestrator.py +++ b/dapr_agents/workflow/orchestrators/llm/orchestrator.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional from dapr.ext.workflow import DaprWorkflowContext -from dapr_agents.types.workflow import DaprWorkflowStatus from pydantic import Field from dapr_agents.workflow.decorators import message_router, task, workflow @@ -29,6 +28,7 @@ from dapr_agents.workflow.orchestrators.llm.state import ( LLMWorkflowEntry, LLMWorkflowMessage, + LLMWorkflowState, PlanStep, TaskResult, ) @@ -37,7 +37,7 @@ restructure_plan, update_step_statuses, ) -from dapr_agents.agents.memory_store import MemoryStore +from dapr_agents.memory import ConversationDaprStateMemory logger = logging.getLogger(__name__) @@ -54,11 +54,9 @@ class LLMOrchestrator(OrchestratorWorkflowBase): default=None, description="The current workflow instance ID for this orchestrator.", ) - memory_store: MemoryStore = Field( - default_factory=lambda: MemoryStore( - name="workflowstatestore", session_id="orchestrator_session" - ), - description="Persistent storage with session-based state hydration.", + memory: Optional[ConversationDaprStateMemory] = Field( + default=None, + description="Persistent memory with session-based state hydration.", ) def model_post_init(self, __context: Any) -> None: @@ -74,28 +72,29 @@ def model_post_init(self, __context: Any) -> None: # TODO(@Sicoyle): fix this later!! self._is_orchestrator = True # Flag for PubSub deduplication to prevent orchestrator workflows from being triggered multiple times - if not self.memory_store._current_state: - self.memory_store._current_state = {"instances": {}} + if not self.state: + logger.debug("No state found, initializing empty state") + self.state = {"instances": {}} + + if self.memory is not None: + self.memory = ConversationDaprStateMemory( + store_name=self.memory.store_name, + session_id=f"{self.name or 'orchestrator'}_session", + ) - if not self.memory_store.name: - raise ValueError("LLMOrchestrator must have a name for persistent storage") + print(f"sam memory store name is {self.memory.store_name}") + logger.info(f"sam memory store name is {self.memory.store_name}") # Load the current workflow instance ID from state using session_id) - if self.memory_store._current_state and self.memory_store._current_state.get( - "instances" - ): - logger.debug( - f"Found {len(self.memory_store._current_state['instances'])} instances in state" - ) + if self.state and self.state.get("instances"): + logger.debug(f"Found {len(self.state['instances'])} instances in state") current_session_id = ( - self.memory_store.session_id - if self.memory_store + self.memory.session_id + if self.memory else f"{self.name}_default_session" ) - for instance_id, instance_data in self.memory_store._current_state[ - "instances" - ].items(): + for instance_id, instance_data in self.state["instances"].items(): stored_workflow_name = instance_data.get("workflow_name") stored_session_id = instance_data.get("session_id") logger.debug( @@ -196,14 +195,9 @@ def main_workflow(self, ctx: DaprWorkflowContext, message: TriggerAction): # Step 1: Retrieve initial task and ensure state entry exists task = message.get("task") instance_id = ctx.instance_id - self.memory_store._current_state.setdefault("instances", {}).setdefault( + self.state.setdefault("instances", {}).setdefault( instance_id, LLMWorkflowEntry(input=task).model_dump(mode="json") ) - - # Update session index to track this workflow instance - if not ctx.is_replaying: - self.memory_store._update_session_index(instance_id) - # Initialize plan as empty list - it will be set after turn 1 plan = [] final_summary: Optional[str] = None @@ -221,7 +215,7 @@ def main_workflow(self, ctx: DaprWorkflowContext, message: TriggerAction): # On turn 1, atomically generate plan and broadcast task if turn == 1: if not ctx.is_replaying: - logger.debug(f"Initial message from User -> {self.name}") + logger.info(f"Initial message from User -> {self.name}") init_result = yield ctx.call_activity( self.initialize_workflow_with_plan, @@ -232,7 +226,7 @@ def main_workflow(self, ctx: DaprWorkflowContext, message: TriggerAction): "wf_time": ctx.current_utc_datetime.isoformat(), }, ) - logger.debug(f"Workflow initialized with plan: {init_result['status']}") + logger.info(f"Workflow initialized with plan: {init_result['status']}") plan = init_result["plan"] # Determine next step and dispatch @@ -241,16 +235,7 @@ def main_workflow(self, ctx: DaprWorkflowContext, message: TriggerAction): # If plan is empty, read from workflow state if not plan_objects: - if ( - self.memory_store._current_state - and "instances" in self.memory_store._current_state - and instance_id in self.memory_store._current_state["instances"] - ): - plan_objects = self.memory_store._current_state["instances"][ - instance_id - ].get("plan", []) - else: - plan_objects = [] + plan_objects = self.state["instances"][instance_id].get("plan", []) plan = plan_objects next_step = yield ctx.call_activity( self.generate_next_step, @@ -320,7 +305,7 @@ def main_workflow(self, ctx: DaprWorkflowContext, message: TriggerAction): else: task_results = yield event_data if not ctx.is_replaying: - logger.debug(f"{task_results['name']} sent a response.") + logger.info(f"{task_results['name']} sent a response.") # Atomically process agent response, update history, check progress, and update plan response_result = yield ctx.call_activity( @@ -674,68 +659,77 @@ async def initialize_workflow_with_plan( try: # Look for existing plan using session_id existing_plan = None - if ( - self.memory_store._current_state - and self.memory_store._current_state.get("instances") - ): - logger.debug( - f"Found {len(self.memory_store._current_state['instances'])} instances in state" + for stored_instance_id, instance_data in self.state.get( + "instances", {} + ).items(): + stored_session_id = instance_data.get("session_id") + current_session_id = ( + self.memory.session_id + if self.memory + else f"{self.name}_default_session" ) - for instance_id, instance_data in self.memory_store._current_state.get( - "instances", {} - ).items(): - # stored_session_id = instance_data.get("session_id") - # if stored_session_id == self.memory.session_id: + if stored_session_id == current_session_id: existing_plan = instance_data.get("plan", []) - logger.info( - f"Found existing plan for session_id {self.memory_store.session_id}" + logger.debug( + f"Found existing plan for session_id {current_session_id} in instance {stored_instance_id}" ) - # break + break - if existing_plan: - logger.info( - f"Found existing plan in workflow state, reusing it: {len(existing_plan)} steps" - ) - plan_objects = existing_plan + if existing_plan: + logger.debug( + f"Found existing plan in workflow state, reusing it: {len(existing_plan)} steps" + ) + plan_objects = existing_plan + else: + # Generate new plan using the LLM + logger.debug( + "No existing plan found in workflow state, generating new plan" + ) + response = self.llm.generate( + messages=[ + { + "role": "user", + "content": TASK_PLANNING_PROMPT.format( + task=task, + agents=agents, + plan_schema=schemas.plan, + ), + } + ], + response_format=IterablePlanStep, + structured_mode="json", + ) + + # Parse the response + if isinstance(response, str): + # If it's a raw JSON string + plan_dict = json.loads(response) + plan_objects = [ + PlanStep(**step_dict) + for step_dict in plan_dict.get("objects", []) + ] + elif hasattr(response, "choices") and response.choices: + # If it's an OpenAI-style response with multiple choices + plan_objects = [] + for choice in response.choices: + plan_data = choice.message.content + if isinstance(plan_data, str): + plan_dict = json.loads(plan_data) + plan_objects.extend( + PlanStep(**step_dict) + for step_dict in plan_dict.get("objects", []) + ) + elif hasattr(plan_data, "objects"): + plan_objects.extend(plan_data.objects) else: - # Generate new plan using the LLM - logger.info( - "No existing plan found in workflow state, generating new plan" - ) - response = self.llm.generate( - messages=[ - { - "role": "user", - "content": TASK_PLANNING_PROMPT.format( - task=task, - agents=agents, - plan_schema=schemas.plan, - ), - } - ], - response_format=IterablePlanStep, - structured_mode="json", + # If it's already a Pydantic model + plan_objects = ( + response.objects if hasattr(response, "objects") else [] ) - # Parse the response - now we get a Pydantic model directly - if hasattr(response, "choices") and response.choices: - # If it's still a raw response, parse it - plan_data = response.choices[0].message.content - logger.debug(f"Plan generation response: {plan_data}") - plan_dict = json.loads(plan_data) - # Convert raw dictionaries to Pydantic models - plan_objects = [ - PlanStep(**step_dict) - for step_dict in plan_dict.get("objects", []) - ] - else: - # If it's already a Pydantic model - plan_objects = ( - response.objects if hasattr(response, "objects") else [] - ) - logger.debug( - f"Plan generation response (Pydantic): {plan_objects}" - ) + logger.debug( + f"Plan generation response with {len(plan_objects)} objects: {plan_objects}" + ) # Format and broadcast message plan_dicts = self._convert_plan_objects_to_dicts(plan_objects) @@ -1047,13 +1041,7 @@ async def update_workflow_state( Raises: ValueError: If the workflow instance ID is not found in the local state. """ - if ( - not self.memory_store._current_state - or "instances" not in self.memory_store._current_state - ): - raise ValueError("No workflow instances found in local state.") - - workflow_entry = self.memory_store._current_state["instances"].get(instance_id) + workflow_entry = self.state["instances"].get(instance_id) if not workflow_entry: raise ValueError( f"No workflow entry found for instance_id {instance_id} in local state." @@ -1069,12 +1057,9 @@ async def update_workflow_state( workflow_entry["messages"].append(serialized_message) workflow_entry["last_message"] = serialized_message - # Update the local chat history - if not self.memory_store._current_state: - self.memory_store._current_state = {} - if "chat_history" not in self.memory_store._current_state: - self.memory_store._current_state["chat_history"] = [] - self.memory_store._current_state["chat_history"].append(serialized_message) + # Update the local chat history if memory is enabled + if self.memory: + self.memory.add_message(message) if final_output is not None: workflow_entry["output"] = final_output @@ -1085,9 +1070,7 @@ async def update_workflow_state( workflow_entry["workflow_instance_id"] = instance_id workflow_entry["workflow_name"] = self._workflow_name workflow_entry["session_id"] = ( - self.memory_store.session_id - if self.memory_store - else f"{self.name}_default_session" + self.memory.session_id if self.memory else f"{self.name}_default_session" ) # Persist updated state @@ -1164,14 +1147,8 @@ async def trigger_agent_internal( f"Triggering agent {name} for step {step}, substep {substep} (Instance ID: {instance_id})" ) - # Get the workflow entry from self.memory_store._current_state - if ( - not self.memory_store._current_state - or "instances" not in self.memory_store._current_state - ): - raise ValueError("No workflow instances found in local state.") - - workflow_entry = self.memory_store._current_state["instances"].get(instance_id) + # Get the workflow entry from self.state + workflow_entry = self.state["instances"].get(instance_id) if not workflow_entry: raise ValueError(f"No workflow entry found for instance_id: {instance_id}") @@ -1224,13 +1201,7 @@ async def update_task_history_internal( await self.update_workflow_state(instance_id=instance_id, message=results) # Retrieve Workflow state - if ( - not self.memory_store._current_state - or "instances" not in self.memory_store._current_state - ): - raise ValueError("No workflow instances found in local state.") - - workflow_entry = self.memory_store._current_state["instances"].get(instance_id) + workflow_entry = self.state["instances"].get(instance_id) if not workflow_entry: raise ValueError(f"No workflow entry found for instance_id: {instance_id}") @@ -1363,13 +1334,9 @@ async def rollback_workflow_initialization(self, instance_id: str): Rollback workflow initialization by clearing partial state. """ try: - if ( - self.memory_store._current_state - and "instances" in self.memory_store._current_state - and instance_id in self.memory_store._current_state["instances"] - ): + if instance_id in self.state["instances"]: # Clear the plan if it was partially created - self.memory_store._current_state["instances"][instance_id]["plan"] = [] + self.state["instances"][instance_id]["plan"] = [] self.save_state() logger.info(f"Rolled back workflow initialization for {instance_id}") except Exception as e: @@ -1382,24 +1349,16 @@ async def rollback_agent_trigger( Rollback agent trigger by reverting step status. """ try: - if ( - self.memory_store._current_state - and "instances" in self.memory_store._current_state - ): - workflow_entry = self.memory_store._current_state["instances"].get( - instance_id - ) - if workflow_entry and "plan" in workflow_entry: - plan = workflow_entry["plan"] - step_entry = find_step_in_plan(plan, step_id, substep_id) - if step_entry and step_entry["status"] == "in_progress": - step_entry["status"] = "not_started" - await self.update_workflow_state( - instance_id=instance_id, plan=plan - ) - logger.info( - f"Rolled back agent trigger for step {step_id}, substep {substep_id}" - ) + workflow_entry = self.state["instances"].get(instance_id) + if workflow_entry and "plan" in workflow_entry: + plan = workflow_entry["plan"] + step_entry = find_step_in_plan(plan, step_id, substep_id) + if step_entry and step_entry["status"] == "in_progress": + step_entry["status"] = "not_started" + await self.update_workflow_state(instance_id=instance_id, plan=plan) + logger.info( + f"Rolled back agent trigger for step {step_id}, substep {substep_id}" + ) except Exception as e: logger.error(f"Failed to rollback agent trigger: {e}") @@ -1410,44 +1369,35 @@ async def rollback_agent_response_processing( Rollback agent response processing by reverting changes. """ try: - if ( - self.memory_store._current_state - and "instances" in self.memory_store._current_state - ): - workflow_entry = self.memory_store._current_state["instances"].get( - instance_id - ) - if workflow_entry: - # Remove the last task result if it was added - if ( - "task_history" in workflow_entry - and workflow_entry["task_history"] - ): - # Find and remove the last entry for this agent/step - task_history = workflow_entry["task_history"] - for i in range(len(task_history) - 1, -1, -1): - task = task_history[i] - if ( - task.get("agent") == agent - and task.get("step") == step_id - and task.get("substep") == substep_id - ): - task_history.pop(i) - break - - # Revert step status if it was changed - if "plan" in workflow_entry: - plan = workflow_entry["plan"] - step_entry = find_step_in_plan(plan, step_id, substep_id) - if step_entry and step_entry["status"] == "completed": - step_entry["status"] = "in_progress" - await self.update_workflow_state( - instance_id=instance_id, plan=plan - ) + workflow_entry = self.state["instances"].get(instance_id) + if workflow_entry: + # Remove the last task result if it was added + if "task_history" in workflow_entry and workflow_entry["task_history"]: + # Find and remove the last entry for this agent/step + task_history = workflow_entry["task_history"] + for i in range(len(task_history) - 1, -1, -1): + task = task_history[i] + if ( + task.get("agent") == agent + and task.get("step") == step_id + and task.get("substep") == substep_id + ): + task_history.pop(i) + break + + # Revert step status if it was changed + if "plan" in workflow_entry: + plan = workflow_entry["plan"] + step_entry = find_step_in_plan(plan, step_id, substep_id) + if step_entry and step_entry["status"] == "completed": + step_entry["status"] = "in_progress" + await self.update_workflow_state( + instance_id=instance_id, plan=plan + ) - logger.info( - f"Rolled back agent response processing for {agent} at step {step_id}, substep {substep_id}" - ) + logger.info( + f"Rolled back agent response processing for {agent} at step {step_id}, substep {substep_id}" + ) except Exception as e: logger.error(f"Failed to rollback agent response processing: {e}") @@ -1456,22 +1406,16 @@ async def rollback_workflow_finalization(self, instance_id: str): Rollback workflow finalization to ensure consistent state. """ try: - if ( - self.memory_store._current_state - and "instances" in self.memory_store._current_state - ): - workflow_entry = self.memory_store._current_state["instances"].get( - instance_id - ) - if workflow_entry: - # Clear final output if it was set - if "output" in workflow_entry: - workflow_entry["output"] = None - if "end_time" in workflow_entry: - workflow_entry["end_time"] = None - - self.save_state() - logger.info(f"Rolled back workflow finalization for {instance_id}") + workflow_entry = self.state["instances"].get(instance_id) + if workflow_entry: + # Clear final output if it was set + if "output" in workflow_entry: + workflow_entry["output"] = None + if "end_time" in workflow_entry: + workflow_entry["end_time"] = None + + self.save_state() + logger.info(f"Rolled back workflow finalization for {instance_id}") except Exception as e: logger.error(f"Failed to rollback workflow finalization: {e}") @@ -1550,34 +1494,28 @@ async def ensure_workflow_state_consistency(self, instance_id: str): Ensures workflow state is consistent after compensation. """ try: - if ( - self.memory_store._current_state - and "instances" in self.memory_store._current_state - ): - workflow_entry = self.memory_store._current_state["instances"].get( - instance_id + workflow_entry = self.state["instances"].get(instance_id) + if not workflow_entry: + logger.warning( + f"No workflow entry found for {instance_id} during consistency check" ) - if not workflow_entry: - logger.warning( - f"No workflow entry found for {instance_id} during consistency check" - ) - return + return - # Ensure plan exists and is valid - if "plan" not in workflow_entry or not workflow_entry["plan"]: - workflow_entry["plan"] = [] + # Ensure plan exists and is valid + if "plan" not in workflow_entry or not workflow_entry["plan"]: + workflow_entry["plan"] = [] - # Ensure task_history exists - if "task_history" not in workflow_entry: - workflow_entry["task_history"] = [] + # Ensure task_history exists + if "task_history" not in workflow_entry: + workflow_entry["task_history"] = [] - # Ensure messages exists - if "messages" not in workflow_entry: - workflow_entry["messages"] = [] + # Ensure messages exists + if "messages" not in workflow_entry: + workflow_entry["messages"] = [] - # Save the consistent state - self.save_state() - logger.info(f"Ensured workflow state consistency for {instance_id}") + # Save the consistent state + self.save_state() + logger.info(f"Ensured workflow state consistency for {instance_id}") except Exception as e: logger.error(f"Failed to ensure workflow state consistency: {e}") diff --git a/quickstarts/01-hello-world/02_build_agentdurable.py b/quickstarts/01-hello-world/02_build_agentdurable.py deleted file mode 100644 index 42f72df9..00000000 --- a/quickstarts/01-hello-world/02_build_agentdurable.py +++ /dev/null @@ -1,39 +0,0 @@ -import asyncio -from dapr_agents import tool, DurableAgent, MemoryStore -from dapr_agents import OpenAIChatClient - - -@tool -def my_weather_func() -> str: - """Get current weather.""" - return "It's 72°F and sunny" - - -async def main(): - weather_agent = DurableAgent( - name="WeatherAgent", - role="Weather Assistant", - instructions=["Help users with weather information"], - tools=[my_weather_func], - message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="session", - ), - llm=OpenAIChatClient(model="gpt-3.5-turbo"), - ) - try: - # Can override session_id per run - response = await weather_agent.run( - "What's the weather?", - # session_id="new_session" # Optional: override default session - ) - print(response) - except Exception as e: - print(f"Error: {e}") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/quickstarts/01-hello-world/03_durable_agent.py b/quickstarts/01-hello-world/03_durable_agent.py index 97dc5989..63e5083a 100644 --- a/quickstarts/01-hello-world/03_durable_agent.py +++ b/quickstarts/01-hello-world/03_durable_agent.py @@ -12,7 +12,8 @@ from typing import List from pydantic import BaseModel, Field -from dapr_agents import tool, DurableAgent, MemoryStore +from dapr_agents import tool, DurableAgent +from dapr_agents.memory import ConversationDaprStateMemory from dotenv import load_dotenv @@ -65,11 +66,12 @@ async def main(): ], tools=[search_flights], message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="agent_session", + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="registrystatestore", + agents_registry_key="agents_registry", + memory=ConversationDaprStateMemory( + store_name="conversationstore", session_id="my-unique-id" ), # llm=llm, # if you don't set the llm attribute, it will be by default set to DaprChatClient() ) @@ -79,6 +81,8 @@ async def main(): except Exception as e: print(f"Error starting service: {e}") + finally: + travel_planner.graceful_shutdown() if __name__ == "__main__": diff --git a/quickstarts/01-hello-world/05_agent_with_vectorstore.py b/quickstarts/01-hello-world/05_agent_with_vectorstore.py index 5cbc292e..c7cdc2cd 100644 --- a/quickstarts/01-hello-world/05_agent_with_vectorstore.py +++ b/quickstarts/01-hello-world/05_agent_with_vectorstore.py @@ -7,7 +7,6 @@ from dapr_agents.storage.vectorstores import ChromaVectorStore from dapr_agents.tool import tool from dapr_agents.types.document import Document -from dapr_agents import MemoryStore load_dotenv() @@ -110,15 +109,8 @@ async def main(): "Provide relevant information from stored documents", ], tools=[search_documents, add_document, add_machine_learning_doc], + vector_store=vector_store, llm=OpenAIChatClient(model="gpt-3.5-turbo"), - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="session", - ), - # Note: Regular Agent always uses in-memory conversation history - # The 'name' field only registers the agent for discovery by orchestrators ) try: logging.info("Starting Vector Database Agent...") diff --git a/quickstarts/01-hello-world/components/conversationmemory.yaml b/quickstarts/01-hello-world/components/conversationmemory.yaml new file mode 100644 index 00000000..d6a28c60 --- /dev/null +++ b/quickstarts/01-hello-world/components/conversationmemory.yaml @@ -0,0 +1,14 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: conversationstore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: redisPassword + value: "" + - name: keyPrefix + value: none \ No newline at end of file diff --git a/quickstarts/01-hello-world/components/statestore.yaml b/quickstarts/01-hello-world/components/statestore.yaml index d0ed4be6..d753b092 100644 --- a/quickstarts/01-hello-world/components/statestore.yaml +++ b/quickstarts/01-hello-world/components/statestore.yaml @@ -1,7 +1,7 @@ apiVersion: dapr.io/v1alpha1 kind: Component metadata: - name: statestore + name: registrystatestore spec: type: state.redis version: v1 @@ -13,6 +13,4 @@ spec: - name: enableTLS value: "false" - name: keyPrefix - value: none - - name: actorStateStore - value: "true" \ No newline at end of file + value: none \ No newline at end of file diff --git a/quickstarts/05-multi-agent-workflows/components/memorystore.yaml b/quickstarts/01-hello-world/components/workflowstate.yaml similarity index 90% rename from quickstarts/05-multi-agent-workflows/components/memorystore.yaml rename to quickstarts/01-hello-world/components/workflowstate.yaml index f90b7ce9..8f7c55c2 100644 --- a/quickstarts/05-multi-agent-workflows/components/memorystore.yaml +++ b/quickstarts/01-hello-world/components/workflowstate.yaml @@ -1,7 +1,7 @@ apiVersion: dapr.io/v1alpha1 kind: Component metadata: - name: memorystore + name: workflowstatestore spec: type: state.redis version: v1 diff --git a/quickstarts/03-agent-tool-call/components/agentstatestore.yaml b/quickstarts/03-agent-tool-call/components/agentstatestore.yaml new file mode 100644 index 00000000..c7bc15d6 --- /dev/null +++ b/quickstarts/03-agent-tool-call/components/agentstatestore.yaml @@ -0,0 +1,16 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: agentstatestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: redisPassword + value: "" + - name: enableTLS + value: "false" + - name: keyPrefix + value: none diff --git a/quickstarts/03-agent-tool-call/components/historystore.yaml b/quickstarts/03-agent-tool-call/components/historystore.yaml new file mode 100644 index 00000000..65bd3540 --- /dev/null +++ b/quickstarts/03-agent-tool-call/components/historystore.yaml @@ -0,0 +1,12 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: historystore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: redisPassword + value: "" \ No newline at end of file diff --git a/quickstarts/03-agent-tool-call/components/statestore.yaml b/quickstarts/03-agent-tool-call/components/workflowstatestore.yaml similarity index 90% rename from quickstarts/03-agent-tool-call/components/statestore.yaml rename to quickstarts/03-agent-tool-call/components/workflowstatestore.yaml index 66800d1f..55a4231e 100644 --- a/quickstarts/03-agent-tool-call/components/statestore.yaml +++ b/quickstarts/03-agent-tool-call/components/workflowstatestore.yaml @@ -1,7 +1,7 @@ apiVersion: dapr.io/v1alpha1 kind: Component metadata: - name: statestore + name: workflowstatestore spec: type: state.redis version: v1 diff --git a/quickstarts/03-agent-tool-call/weather_agent_dapr.py b/quickstarts/03-agent-tool-call/weather_agent_dapr.py index 91c11cf1..b061b105 100644 --- a/quickstarts/03-agent-tool-call/weather_agent_dapr.py +++ b/quickstarts/03-agent-tool-call/weather_agent_dapr.py @@ -1,7 +1,8 @@ import asyncio from weather_tools import tools -from dapr_agents import Agent, MemoryStore +from dapr_agents import Agent from dotenv import load_dotenv +from dapr_agents.memory import ConversationDaprStateMemory load_dotenv() @@ -14,12 +15,7 @@ "If you perform any additional actions (like jumping), summarize those actions and their results.", "At the end, provide a concise summary that combines the weather information for all requested locations and any other actions you performed.", ], - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="session", - ), + memory=ConversationDaprStateMemory(store_name="historystore", session_id="some-id"), tools=tools, ) diff --git a/quickstarts/03-agent-tool-call/weather_durable_agent_tracing.py b/quickstarts/03-agent-tool-call/weather_durable_agent_tracing.py index a2bcb9e2..54f83e80 100644 --- a/quickstarts/03-agent-tool-call/weather_durable_agent_tracing.py +++ b/quickstarts/03-agent-tool-call/weather_durable_agent_tracing.py @@ -1,6 +1,6 @@ import asyncio from weather_tools import tools -from dapr_agents import DurableAgent, MemoryStore +from dapr_agents import DurableAgent from dotenv import load_dotenv load_dotenv() @@ -35,12 +35,8 @@ async def main(): ], tools=tools, message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + agents_registry_store_name="agentstatestore", ) await AIAgent.run("What is the weather in Virginia, New York and Washington DC?") diff --git a/quickstarts/03-durable-agent-multitool-dapr/components/statestore.yaml b/quickstarts/03-durable-agent-multitool-dapr/components/statestore.yaml index 9bb4e07c..2fc32cd0 100644 --- a/quickstarts/03-durable-agent-multitool-dapr/components/statestore.yaml +++ b/quickstarts/03-durable-agent-multitool-dapr/components/statestore.yaml @@ -1,7 +1,7 @@ apiVersion: dapr.io/v1alpha1 kind: Component metadata: - name: statestore + name: workflowstatestore spec: type: state.redis version: v1 diff --git a/quickstarts/03-durable-agent-multitool-dapr/multi_tool_agent_dapr.py b/quickstarts/03-durable-agent-multitool-dapr/multi_tool_agent_dapr.py index d95cb853..9a6fc5e2 100644 --- a/quickstarts/03-durable-agent-multitool-dapr/multi_tool_agent_dapr.py +++ b/quickstarts/03-durable-agent-multitool-dapr/multi_tool_agent_dapr.py @@ -1,4 +1,5 @@ -from dapr_agents import DurableAgent, MemoryStore +from dapr_agents import DurableAgent +from dapr_agents.llm.dapr import DaprChatClient from dotenv import load_dotenv from multi_tools import tools import asyncio @@ -36,12 +37,10 @@ async def main(): "Use the weather tool for location-based weather.", ], message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="workflowstatestore", + agents_registry_key="agents_registry", tools=tools, llm=llm_provider, ) diff --git a/quickstarts/03-durable-agent-tool-call/components/statestore.yaml b/quickstarts/03-durable-agent-tool-call/components/statestore.yaml index c0d8c42d..2dec8bef 100644 --- a/quickstarts/03-durable-agent-tool-call/components/statestore.yaml +++ b/quickstarts/03-durable-agent-tool-call/components/statestore.yaml @@ -1,7 +1,7 @@ apiVersion: dapr.io/v1alpha1 kind: Component metadata: - name: statestore + name: agentstatestore spec: type: state.redis version: v1 @@ -11,6 +11,4 @@ spec: - name: redisPassword value: "" - name: keyPrefix - value: none - - name: actorStateStore - value: "true" \ No newline at end of file + value: none \ No newline at end of file diff --git a/quickstarts/04-llm-based-workflows/components/statestore.yaml b/quickstarts/03-durable-agent-tool-call/components/workflowstate.yaml similarity index 100% rename from quickstarts/04-llm-based-workflows/components/statestore.yaml rename to quickstarts/03-durable-agent-tool-call/components/workflowstate.yaml diff --git a/quickstarts/03-durable-agent-tool-call/durable_weather_agent.py b/quickstarts/03-durable-agent-tool-call/durable_weather_agent.py index b5dd83ad..b68ca7d5 100644 --- a/quickstarts/03-durable-agent-tool-call/durable_weather_agent.py +++ b/quickstarts/03-durable-agent-tool-call/durable_weather_agent.py @@ -1,4 +1,4 @@ -from dapr_agents import DurableAgent, MemoryStore +from dapr_agents import DurableAgent from dotenv import load_dotenv from weather_tools import tools import asyncio @@ -19,12 +19,10 @@ async def main(): "Use tools when appropriate to fetch weather data.", ], message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="agentstatestore", + agents_registry_key="agents_registry", tools=tools, ) # 2️⃣ Start the agent service diff --git a/quickstarts/03-durable-agent-tool-call/durable_weather_agent_hf.py b/quickstarts/03-durable-agent-tool-call/durable_weather_agent_hf.py index bd49cd20..6a579172 100644 --- a/quickstarts/03-durable-agent-tool-call/durable_weather_agent_hf.py +++ b/quickstarts/03-durable-agent-tool-call/durable_weather_agent_hf.py @@ -1,4 +1,4 @@ -from dapr_agents import DurableAgent, HFHubChatClient, MemoryStore +from dapr_agents import DurableAgent, HFHubChatClient from dotenv import load_dotenv from weather_tools import tools import asyncio @@ -23,12 +23,10 @@ async def main(): ], llm=llm, message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="agentstatestore", + agents_registry_key="agents_registry", tools=tools, ) # Start the agent service diff --git a/quickstarts/03-durable-agent-tool-call/durable_weather_agent_nv.py b/quickstarts/03-durable-agent-tool-call/durable_weather_agent_nv.py index f80a7d2f..b96dc13b 100644 --- a/quickstarts/03-durable-agent-tool-call/durable_weather_agent_nv.py +++ b/quickstarts/03-durable-agent-tool-call/durable_weather_agent_nv.py @@ -1,4 +1,4 @@ -from dapr_agents import DurableAgent, NVIDIAChatClient, MemoryStore +from dapr_agents import DurableAgent, NVIDIAChatClient from dotenv import load_dotenv from weather_tools import tools import asyncio @@ -23,12 +23,10 @@ async def main(): ], llm=llm, message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="memorystore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="agentstatestore", + agents_registry_key="agents_registry", tools=tools, ) # Start the agent service diff --git a/quickstarts/03-durable-agent-tool-call/durable_weather_agent_tracing.py b/quickstarts/03-durable-agent-tool-call/durable_weather_agent_tracing.py index 47dce4a2..5ea79d4a 100644 --- a/quickstarts/03-durable-agent-tool-call/durable_weather_agent_tracing.py +++ b/quickstarts/03-durable-agent-tool-call/durable_weather_agent_tracing.py @@ -4,7 +4,7 @@ from dotenv import load_dotenv from weather_tools import tools -from dapr_agents import DurableAgent, MemoryStore +from dapr_agents import DurableAgent logging.basicConfig(level=logging.INFO) @@ -35,12 +35,10 @@ async def main(): "Use tools when appropriate to fetch weather data.", ], message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="agentstatestore", + agents_registry_key="agents_registry", tools=tools, ) # 2️⃣ Start the agent service diff --git a/quickstarts/05-multi-agent-workflows/components/registrystore.yaml b/quickstarts/04-llm-based-workflows/components/workflowstate.yaml similarity index 73% rename from quickstarts/05-multi-agent-workflows/components/registrystore.yaml rename to quickstarts/04-llm-based-workflows/components/workflowstate.yaml index ca2146c3..2b09d93b 100644 --- a/quickstarts/05-multi-agent-workflows/components/registrystore.yaml +++ b/quickstarts/04-llm-based-workflows/components/workflowstate.yaml @@ -1,7 +1,7 @@ apiVersion: dapr.io/v1alpha1 kind: Component metadata: - name: registrystore + name: workflowstatestore spec: type: state.redis version: v1 @@ -10,5 +10,5 @@ spec: value: localhost:6379 - name: redisPassword value: "" - - name: keyPrefix - value: none \ No newline at end of file + - name: actorStateStore + value: "true" \ No newline at end of file diff --git a/quickstarts/05-multi-agent-workflow-k8s/components/statestore.yaml b/quickstarts/05-multi-agent-workflow-k8s/components/agentstate.yaml similarity index 93% rename from quickstarts/05-multi-agent-workflow-k8s/components/statestore.yaml rename to quickstarts/05-multi-agent-workflow-k8s/components/agentstate.yaml index b7e10452..99efff07 100644 --- a/quickstarts/05-multi-agent-workflow-k8s/components/statestore.yaml +++ b/quickstarts/05-multi-agent-workflow-k8s/components/agentstate.yaml @@ -1,7 +1,7 @@ apiVersion: dapr.io/v1alpha1 kind: Component metadata: - name: statestore + name: agentstatestore spec: type: state.redis version: v1 diff --git a/quickstarts/05-multi-agent-workflow-k8s/components/workflowstate.yaml b/quickstarts/05-multi-agent-workflow-k8s/components/workflowstate.yaml new file mode 100644 index 00000000..8062a122 --- /dev/null +++ b/quickstarts/05-multi-agent-workflow-k8s/components/workflowstate.yaml @@ -0,0 +1,14 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: workflowstatestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: dapr-redis-master:6379 + - name: redisPassword + secretKeyRef: + name: dapr-redis + key: "redis-password" \ No newline at end of file diff --git a/quickstarts/05-multi-agent-workflows/components/agentstate.yaml b/quickstarts/05-multi-agent-workflows/components/agentstate.yaml new file mode 100644 index 00000000..df889845 --- /dev/null +++ b/quickstarts/05-multi-agent-workflows/components/agentstate.yaml @@ -0,0 +1,16 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: agentstatestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: redisPassword + value: "" + - name: keyPrefix + value: none + - name: actorStateStore + value: "true" \ No newline at end of file diff --git a/quickstarts/05-multi-agent-workflows/components/conversationmemory.yaml b/quickstarts/05-multi-agent-workflows/components/conversationmemory.yaml new file mode 100644 index 00000000..d6a28c60 --- /dev/null +++ b/quickstarts/05-multi-agent-workflows/components/conversationmemory.yaml @@ -0,0 +1,14 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: conversationstore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: redisPassword + value: "" + - name: keyPrefix + value: none \ No newline at end of file diff --git a/quickstarts/05-multi-agent-workflows/components/workflowstate.yaml b/quickstarts/05-multi-agent-workflows/components/workflowstate.yaml new file mode 100644 index 00000000..46c22080 --- /dev/null +++ b/quickstarts/05-multi-agent-workflows/components/workflowstate.yaml @@ -0,0 +1,12 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: workflowstatestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: redisPassword + value: "" \ No newline at end of file diff --git a/quickstarts/05-multi-agent-workflows/services/elf/app.py b/quickstarts/05-multi-agent-workflows/services/elf/app.py index 0226263b..c5ef0405 100644 --- a/quickstarts/05-multi-agent-workflows/services/elf/app.py +++ b/quickstarts/05-multi-agent-workflows/services/elf/app.py @@ -1,4 +1,4 @@ -from dapr_agents import DurableAgent, MemoryStore +from dapr_agents import DurableAgent from dotenv import load_dotenv import asyncio import logging @@ -34,12 +34,10 @@ async def main(): "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.", ], message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="memorystore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="agentstatestore", + agents_registry_key="agents_registry", broadcast_topic_name="beacon_channel", ) diff --git a/quickstarts/05-multi-agent-workflows/services/hobbit/app.py b/quickstarts/05-multi-agent-workflows/services/hobbit/app.py index 822b9b43..41f52032 100644 --- a/quickstarts/05-multi-agent-workflows/services/hobbit/app.py +++ b/quickstarts/05-multi-agent-workflows/services/hobbit/app.py @@ -1,4 +1,4 @@ -from dapr_agents import DurableAgent, MemoryStore +from dapr_agents import DurableAgent from dotenv import load_dotenv import asyncio import logging @@ -34,12 +34,10 @@ async def main(): "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.", ], message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="memorystore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="agentstatestore", + agents_registry_key="agents_registry", broadcast_topic_name="beacon_channel", ) diff --git a/quickstarts/05-multi-agent-workflows/services/wizard/app.py b/quickstarts/05-multi-agent-workflows/services/wizard/app.py index 45e41f7e..75b6e920 100644 --- a/quickstarts/05-multi-agent-workflows/services/wizard/app.py +++ b/quickstarts/05-multi-agent-workflows/services/wizard/app.py @@ -1,4 +1,4 @@ -from dapr_agents import DurableAgent, MemoryStore +from dapr_agents import DurableAgent from dotenv import load_dotenv import asyncio import logging @@ -34,12 +34,10 @@ async def main(): "Respond concisely, accurately, and relevantly, ensuring clarity and strict alignment with the task.", ], message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="memorystore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="agentstatestore", + agents_registry_key="agents_registry", broadcast_topic_name="beacon_channel", ) diff --git a/quickstarts/05-multi-agent-workflows/services/workflow-llm/app.py b/quickstarts/05-multi-agent-workflows/services/workflow-llm/app.py index 40dcce81..125146a4 100644 --- a/quickstarts/05-multi-agent-workflows/services/workflow-llm/app.py +++ b/quickstarts/05-multi-agent-workflows/services/workflow-llm/app.py @@ -1,6 +1,6 @@ from dapr_agents import LLMOrchestrator from dapr_agents.llm import DaprChatClient -from dapr_agents import MemoryStore +from dapr_agents.memory import ConversationDaprStateMemory from dotenv import load_dotenv import asyncio import logging @@ -30,14 +30,14 @@ async def main(): name="LLMOrchestrator", llm=llm, message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="memorystore", - # Optional - local_directory="./local-state", - session_id="session", - ), - registry_store="registrystore", + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="agentstatestore", + agents_registry_key="agents_registry", broadcast_topic_name="beacon_channel", + memory=ConversationDaprStateMemory( + store_name="conversationstore", session_id="myuniqueid" + ), max_iterations=3, ).as_service(port=8004) diff --git a/quickstarts/05-multi-agent-workflows/services/workflow-random/app.py b/quickstarts/05-multi-agent-workflows/services/workflow-random/app.py index 529d64d2..b30a42bd 100644 --- a/quickstarts/05-multi-agent-workflows/services/workflow-random/app.py +++ b/quickstarts/05-multi-agent-workflows/services/workflow-random/app.py @@ -1,4 +1,4 @@ -from dapr_agents import RandomOrchestrator, MemoryStore +from dapr_agents import RandomOrchestrator from dotenv import load_dotenv import asyncio import logging @@ -9,12 +9,10 @@ async def main(): workflow_service = RandomOrchestrator( name="RandomOrchestrator", message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="agentstatestore", + agents_registry_key="agents_registry", broadcast_topic_name="beacon_channel", max_iterations=3, ).as_service(port=8004) diff --git a/quickstarts/05-multi-agent-workflows/services/workflow-roundrobin/app.py b/quickstarts/05-multi-agent-workflows/services/workflow-roundrobin/app.py index dd60106c..491e2215 100644 --- a/quickstarts/05-multi-agent-workflows/services/workflow-roundrobin/app.py +++ b/quickstarts/05-multi-agent-workflows/services/workflow-roundrobin/app.py @@ -1,4 +1,4 @@ -from dapr_agents import RoundRobinOrchestrator, MemoryStore +from dapr_agents import RoundRobinOrchestrator from dotenv import load_dotenv import asyncio import logging @@ -9,12 +9,10 @@ async def main(): workflow_service = RoundRobinOrchestrator( name="RoundRobinOrchestrator", message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="agentstatestore", + agents_registry_key="agents_registry", broadcast_topic_name="beacon_channel", max_iterations=3, ).as_service(port=8004) diff --git a/quickstarts/06-document-agent-chainlit/app.py b/quickstarts/06-document-agent-chainlit/app.py index 48656a00..0e8fbd49 100644 --- a/quickstarts/06-document-agent-chainlit/app.py +++ b/quickstarts/06-document-agent-chainlit/app.py @@ -4,8 +4,9 @@ from unstructured.partition.pdf import partition_pdf from dapr_agents import Agent +from dapr_agents.memory import ConversationDaprStateMemory from dapr_agents.types import AssistantMessage -from dapr_agents import OpenAIChatClient, MemoryStore +from dapr_agents import OpenAIChatClient load_dotenv() @@ -20,11 +21,8 @@ name="KnowledgeBase", role="Content Expert", instructions=instructions, - memory_store=MemoryStore( - name="conversationstore", - # Optional - local_directory="./local-state", - session_id="session", + memory=ConversationDaprStateMemory( + store_name="conversationstore", session_id="my-unique-id" ), llm=OpenAIChatClient(model="gpt-3.5-turbo"), ) diff --git a/quickstarts/07-agent-mcp-client-sse/app.py b/quickstarts/07-agent-mcp-client-sse/app.py index b40d51ca..c9e471f6 100644 --- a/quickstarts/07-agent-mcp-client-sse/app.py +++ b/quickstarts/07-agent-mcp-client-sse/app.py @@ -2,7 +2,7 @@ import logging from dotenv import load_dotenv -from dapr_agents import DurableAgent, MemoryStore +from dapr_agents import DurableAgent from dapr_agents.tool.mcp import MCPClient @@ -27,12 +27,10 @@ async def main(): ], tools=tools, message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="agentstatestore", + agents_registry_key="agents_registry", ).as_service(port=8001) # Start the FastAPI agent service diff --git a/quickstarts/07-agent-mcp-client-sse/components/statestore.yaml b/quickstarts/07-agent-mcp-client-sse/components/statestore.yaml index cd1a2ae2..c63520e5 100644 --- a/quickstarts/07-agent-mcp-client-sse/components/statestore.yaml +++ b/quickstarts/07-agent-mcp-client-sse/components/statestore.yaml @@ -1,7 +1,7 @@ apiVersion: dapr.io/v1alpha1 kind: Component metadata: - name: statestore + name: agentstatestore spec: type: state.redis version: v1 diff --git a/quickstarts/07-agent-mcp-client-sse/components/workflowstate.yaml b/quickstarts/07-agent-mcp-client-sse/components/workflowstate.yaml new file mode 100644 index 00000000..5b307894 --- /dev/null +++ b/quickstarts/07-agent-mcp-client-sse/components/workflowstate.yaml @@ -0,0 +1,12 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: workflowstatestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: redisPassword + value: "" \ No newline at end of file diff --git a/quickstarts/07-agent-mcp-client-streamablehttp/app.py b/quickstarts/07-agent-mcp-client-streamablehttp/app.py index 9ad6815e..b27f1498 100644 --- a/quickstarts/07-agent-mcp-client-streamablehttp/app.py +++ b/quickstarts/07-agent-mcp-client-streamablehttp/app.py @@ -2,7 +2,7 @@ import logging from dotenv import load_dotenv -from dapr_agents import DurableAgent, MemoryStore +from dapr_agents import DurableAgent from dapr_agents.tool.mcp import MCPClient @@ -29,12 +29,10 @@ async def main(): ], tools=tools, message_bus_name="messagepubsub", - memory_store=MemoryStore( - name="statestore", - # Optional - local_directory="./local-state", - session_id="session", - ), + state_store_name="workflowstatestore", + state_key="workflow_state", + agents_registry_store_name="agentstatestore", + agents_registry_key="agents_registry", ).as_service(port=8001) # Start the FastAPI agent service diff --git a/quickstarts/07-agent-mcp-client-streamablehttp/components/statestore.yaml b/quickstarts/07-agent-mcp-client-streamablehttp/components/statestore.yaml index cd1a2ae2..c63520e5 100644 --- a/quickstarts/07-agent-mcp-client-streamablehttp/components/statestore.yaml +++ b/quickstarts/07-agent-mcp-client-streamablehttp/components/statestore.yaml @@ -1,7 +1,7 @@ apiVersion: dapr.io/v1alpha1 kind: Component metadata: - name: statestore + name: agentstatestore spec: type: state.redis version: v1 diff --git a/quickstarts/07-agent-mcp-client-streamablehttp/components/workflowstate.yaml b/quickstarts/07-agent-mcp-client-streamablehttp/components/workflowstate.yaml new file mode 100644 index 00000000..5b307894 --- /dev/null +++ b/quickstarts/07-agent-mcp-client-streamablehttp/components/workflowstate.yaml @@ -0,0 +1,12 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: workflowstatestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: redisPassword + value: "" \ No newline at end of file diff --git a/quickstarts/09-agents-as-tasks-in-workflows/components/workflowstate.yaml b/quickstarts/09-agents-as-tasks-in-workflows/components/workflowstate.yaml index 0681c0c2..2b09d93b 100644 --- a/quickstarts/09-agents-as-tasks-in-workflows/components/workflowstate.yaml +++ b/quickstarts/09-agents-as-tasks-in-workflows/components/workflowstate.yaml @@ -1,7 +1,7 @@ apiVersion: dapr.io/v1alpha1 kind: Component metadata: - name: statestore + name: workflowstatestore spec: type: state.redis version: v1 diff --git a/tests/agents/agent/test_agent.py b/tests/agents/agent/test_agent.py index 0d10ce4d..a42eadb8 100644 --- a/tests/agents/agent/test_agent.py +++ b/tests/agents/agent/test_agent.py @@ -267,8 +267,8 @@ def test_agent_properties(self, basic_agent): @pytest.mark.asyncio async def test_agent_with_memory_context(self, basic_agent): - """Test agent using storage context when no input is provided.""" - basic_agent.memory_store.add_message(UserMessage(content="Previous message")) + """Test agent using memory context when no input is provided.""" + basic_agent.memory.add_message(UserMessage(content="Previous message")) mock_response = Mock(spec=LLMChatResponse) assistant_msg = AssistantMessage(content="Response") diff --git a/tests/agents/durableagent/test_durable_agent.py b/tests/agents/durableagent/test_durable_agent.py index f68d6d37..90d10d5b 100644 --- a/tests/agents/durableagent/test_durable_agent.py +++ b/tests/agents/durableagent/test_durable_agent.py @@ -10,23 +10,25 @@ import pytest from dapr.ext.workflow import DaprWorkflowContext -from dapr_agents import DurableAgent -from dapr_agents.agents.memory_store import ( - DurableAgentMessage, - DurableAgentWorkflowState, - MemoryStore, -) +from dapr_agents.agents.durableagent.agent import DurableAgent from dapr_agents.agents.durableagent.schemas import ( AgentTaskResponse, BroadcastMessage, ) +from dapr_agents.agents.durableagent.state import ( + DurableAgentMessage, + DurableAgentWorkflowEntry, + DurableAgentWorkflowState, +) from dapr_agents.llm import OpenAIChatClient +from dapr_agents.memory import ConversationDaprStateMemory from dapr_agents.tool.base import AgentTool from dapr_agents.types import ( AssistantMessage, LLMChatCandidate, LLMChatResponse, ToolExecutionRecord, + ToolMessage, ) @@ -59,11 +61,22 @@ def mock_workflow_app_post_init(self, __context: Any) -> None: def mock_agentic_post_init(self, __context: Any) -> None: self._text_formatter = Mock() self.client = Mock() - self._client = Mock() + self._state_store_client = Mock() # Configure the mock to return a tuple as expected by try_get_state - self._client.get_state.return_value = (False, None) + self._state_store_client.try_get_state.return_value = (False, None) # Configure the mock for save_state method - self._client.save_state.return_value = None + self._state_store_client.save_state.return_value = None + self._agent_metadata = { + "name": getattr(self, "name", "TestAgent"), + "role": getattr(self, "role", "Test Role"), + "goal": getattr(self, "goal", "Test Goal"), + "instructions": getattr(self, "instructions", []), + "topic_name": getattr( + self, "agent_topic_name", getattr(self, "name", "TestAgent") + ), + "pubsub_name": getattr(self, "message_bus_name", "testpubsub"), + "orchestrator": False, + } self._workflow_name = "AgenticWorkflow" self._is_running = False self._shutdown_event = asyncio.Event() @@ -81,13 +94,12 @@ def mock_agentic_post_init(self, __context: Any) -> None: ) # No-op for testing - def mock_register_agent(self, store_name, store_key, agent_name, agent_metadata): + def mock_register_agentic_system(self): pass - # register_agent is now in AgentBase - from dapr_agents.agents.base import AgentBase - - monkeypatch.setattr(AgentBase, "register_agent", mock_register_agent) + monkeypatch.setattr( + agentic.AgenticWorkflow, "register_agentic_system", mock_register_agentic_system + ) yield @@ -96,11 +108,10 @@ class MockDaprClient: """Mock DaprClient that supports context manager protocol""" def __init__(self): - self.get_state = MagicMock( - return_value=Mock(data=None, json=lambda: {}, etag="test-etag") - ) + self.get_state = MagicMock(return_value=Mock(data=None, json=lambda: {})) self.save_state = MagicMock() - self.execute_state_transaction = MagicMock() + self.delete_state = MagicMock() + self.query_state = MagicMock() def __enter__(self): return self @@ -124,12 +135,10 @@ def setup_env(self, monkeypatch): mock_client = MockDaprClient() mock_client.get_state.return_value = Mock(data=None) # Default empty state - # Patch DaprClient factory to return our mock; MemoryStore will pick it up + # Patch both the client import locations monkeypatch.setattr("dapr.clients.DaprClient", lambda: mock_client) - from dapr_agents.agents.memory_store import MemoryStore as _MS - monkeypatch.setattr( - _MS, "_update_session_index", lambda self, instance_id: None + "dapr_agents.storage.daprstores.statestore.DaprClient", lambda: mock_client ) yield @@ -173,45 +182,51 @@ def mock_workflow_context(self): @pytest.fixture def basic_durable_agent(self, mock_llm): """Create a basic durable agent instance for testing.""" - memory_store = MemoryStore(name="teststatestore", session_id="test_session") return DurableAgent( name="TestDurableAgent", role="Test Durable Assistant", goal="Help with testing", instructions=["Be helpful", "Test things"], llm=mock_llm, - memory_store=memory_store, + memory=ConversationDaprStateMemory( + store_name="teststatestore", session_id="test_session" + ), max_iterations=5, + state_store_name="teststatestore", message_bus_name="testpubsub", + agents_registry_store_name="testregistry", ) @pytest.fixture def durable_agent_with_tools(self, mock_llm, mock_tool): """Create a durable agent with tools for testing.""" - memory_store = MemoryStore(name="teststatestore", session_id="test_session") return DurableAgent( name="ToolDurableAgent", role="Tool Durable Assistant", goal="Execute tools", instructions=["Use tools when needed"], llm=mock_llm, - memory_store=memory_store, + memory=ConversationDaprStateMemory( + store_name="teststatestore", session_id="test_session" + ), tools=[mock_tool], max_iterations=5, + state_store_name="teststatestore", message_bus_name="testpubsub", + agents_registry_store_name="testregistry", ) def test_durable_agent_initialization(self, mock_llm): """Test durable agent initialization with basic parameters.""" - memory_store = MemoryStore(name="teststatestore") agent = DurableAgent( name="TestDurableAgent", role="Test Durable Assistant", goal="Help with testing", instructions=["Be helpful"], llm=mock_llm, - memory_store=memory_store, + state_store_name="teststatestore", message_bus_name="testpubsub", + agents_registry_store_name="testregistry", ) assert agent.name == "TestDurableAgent" @@ -220,6 +235,7 @@ def test_durable_agent_initialization(self, mock_llm): assert agent.instructions == ["Be helpful"] assert agent.max_iterations == 10 # default value assert agent.tool_history == [] + assert agent.state_store_name == "teststatestore" assert agent.message_bus_name == "testpubsub" assert agent.agent_topic_name == "TestDurableAgent" assert agent.state is not None @@ -228,33 +244,45 @@ def test_durable_agent_initialization(self, mock_llm): def test_durable_agent_initialization_with_custom_topic(self, mock_llm): """Test durable agent initialization with custom topic name.""" - memory_store = MemoryStore(name="teststatestore") agent = DurableAgent( name="TestDurableAgent", role="Test Durable Assistant", goal="Help with testing", llm=mock_llm, - memory_store=memory_store, agent_topic_name="custom-topic", + state_store_name="teststatestore", message_bus_name="testpubsub", + agents_registry_store_name="testregistry", ) assert agent.agent_topic_name == "custom-topic" def test_durable_agent_initialization_name_from_role(self, mock_llm): """Test durable agent initialization with name derived from role.""" - memory_store = MemoryStore(name="teststatestore") agent = DurableAgent( role="Test Durable Assistant", goal="Help with testing", llm=mock_llm, - memory_store=memory_store, + state_store_name="teststatestore", message_bus_name="testpubsub", + agents_registry_store_name="testregistry", ) assert agent.name == "Test Durable Assistant" assert agent.agent_topic_name == "Test Durable Assistant" + def test_durable_agent_metadata(self, basic_durable_agent): + """Test durable agent metadata creation.""" + metadata = basic_durable_agent._agent_metadata + + assert metadata is not None + assert metadata["name"] == "TestDurableAgent" + assert metadata["role"] == "Test Durable Assistant" + assert metadata["goal"] == "Help with testing" + assert metadata["topic_name"] == "TestDurableAgent" + assert metadata["pubsub_name"] == "testpubsub" + assert metadata["orchestrator"] is False + @pytest.fixture def mock_wf_client(self): client = Mock() @@ -270,6 +298,51 @@ async def test_run_method(self, basic_durable_agent, mock_wf_client): result = await basic_durable_agent.run("test input") assert result == {"output": "test"} + @pytest.mark.asyncio + async def test_tool_calling_workflow_initialization( + self, basic_durable_agent, mock_workflow_context + ): + """Test workflow initialization on first iteration.""" + message = { + "task": "Test task", + "iteration": 0, + "workflow_instance_id": "parent-instance-123", + } + + mock_workflow_context.instance_id = "test-instance-123" + mock_workflow_context.call_activity.side_effect = [ + {"content": "Test response"}, + {"message": "Test response"}, + "stop", + ] + + basic_durable_agent.state["instances"]["test-instance-123"] = { + "input": "Test task", + "source": None, + "triggering_workflow_instance_id": "parent-instance-123", + "workflow_instance_id": "test-instance-123", + "workflow_name": "AgenticWorkflow", + "status": "RUNNING", + "messages": [], + "tool_history": [], + "end_time": None, + "trace_context": None, + } + + workflow_gen = basic_durable_agent.tool_calling_workflow( + mock_workflow_context, message + ) + try: + await workflow_gen.__next__() + except StopAsyncIteration: + pass + + assert "test-instance-123" in basic_durable_agent.state["instances"] + instance_data = basic_durable_agent.state["instances"]["test-instance-123"] + assert instance_data["input"] == "Test task" + assert instance_data["source"] is None + assert instance_data["triggering_workflow_instance_id"] == "parent-instance-123" + @pytest.mark.asyncio async def test_call_llm_activity(self, basic_durable_agent): """Test that call_llm unwraps an LLMChatResponse properly.""" @@ -288,7 +361,7 @@ async def test_call_llm_activity(self, basic_durable_agent): instance_id = "test-instance-123" # set up a minimal instance record - basic_durable_agent.memory_store._current_state["instances"] = { + basic_durable_agent.state["instances"] = { instance_id: { "input": "Test task", "source": "test_source", @@ -351,7 +424,7 @@ async def test_finish_workflow_activity(self, basic_durable_agent): """Test finishing workflow activity.""" instance_id = "test-instance-123" final_output = "Final response" - basic_durable_agent.memory_store._current_state["instances"] = { + basic_durable_agent.state["instances"] = { instance_id: { "input": "Test task", "source": "test_source", @@ -369,9 +442,7 @@ async def test_finish_workflow_activity(self, basic_durable_agent): basic_durable_agent.finalize_workflow( instance_id, final_output, "2024-01-01T00:00:00Z" ) - instance_data = basic_durable_agent.memory_store._current_state["instances"][ - instance_id - ] + instance_data = basic_durable_agent.state["instances"][instance_id] assert instance_data["output"] == final_output assert instance_data["end_time"] is not None @@ -391,7 +462,7 @@ async def test_run_tool(self, basic_durable_agent, mock_tool): mock_run_tool.return_value = "tool_result" # Set up instance state - basic_durable_agent.memory_store._current_state["instances"] = { + basic_durable_agent.state["instances"] = { instance_id: { "input": "Test task", "source": "test_source", @@ -421,9 +492,7 @@ async def test_run_tool(self, basic_durable_agent, mock_tool): assert result["execution_result"] == "tool_result" # Verify state was updated atomically - instance_data = basic_durable_agent.memory_store._current_state[ - "instances" - ][instance_id] + instance_data = basic_durable_agent.state["instances"][instance_id] assert len(instance_data["messages"]) == 1 # Tool message added assert ( len(instance_data["tool_history"]) == 1 @@ -462,12 +531,8 @@ def test_record_initial_entry(self, basic_durable_agent): ) # Verify instance was created - assert ( - instance_id in basic_durable_agent.memory_store._current_state["instances"] - ) - instance_data = basic_durable_agent.memory_store._current_state["instances"][ - instance_id - ] + assert instance_id in basic_durable_agent.state["instances"] + instance_data = basic_durable_agent.state["instances"][instance_id] assert instance_data["input"] == input_data assert instance_data["source"] == source assert ( @@ -493,12 +558,8 @@ def test_ensure_instance_exists(self, basic_durable_agent): instance_id, "Test input", triggering_workflow_instance_id, test_time ) - assert ( - instance_id in basic_durable_agent.memory_store._current_state["instances"] - ) - instance_data = basic_durable_agent.memory_store._current_state["instances"][ - instance_id - ] + assert instance_id in basic_durable_agent.state["instances"] + instance_data = basic_durable_agent.state["instances"][instance_id] assert ( instance_data["triggering_workflow_instance_id"] == triggering_workflow_instance_id @@ -509,9 +570,7 @@ def test_ensure_instance_exists(self, basic_durable_agent): # Test that existing instance is not overwritten original_input = "Original input" - basic_durable_agent.memory_store._current_state["instances"][instance_id][ - "input" - ] = original_input + basic_durable_agent.state["instances"][instance_id]["input"] = original_input basic_durable_agent._ensure_instance_exists( instance_id, "different-parent", "2024-01-02T00:00:00Z" @@ -519,9 +578,7 @@ def test_ensure_instance_exists(self, basic_durable_agent): # Input should remain unchanged assert ( - basic_durable_agent.memory_store._current_state["instances"][instance_id][ - "input" - ] + basic_durable_agent.state["instances"][instance_id]["input"] == original_input ) @@ -532,7 +589,7 @@ def test_process_user_message(self, basic_durable_agent): user_message_copy = {"role": "user", "content": "Hello, world!"} # Set up instance - basic_durable_agent.memory_store._current_state["instances"][instance_id] = { + basic_durable_agent.state["instances"][instance_id] = { "input": "Test task", "source": "test_source", "triggering_workflow_instance_id": None, @@ -545,16 +602,14 @@ def test_process_user_message(self, basic_durable_agent): "trace_context": None, } - # Mock memory_store.add_message - with patch.object(type(basic_durable_agent.memory_store), "add_message"): + # Mock memory.add_message + with patch.object(type(basic_durable_agent.memory), "add_message"): basic_durable_agent._process_user_message( instance_id, task, user_message_copy ) # Verify message was added to instance - instance_data = basic_durable_agent.memory_store._current_state["instances"][ - instance_id - ] + instance_data = basic_durable_agent.state["instances"][instance_id] assert len(instance_data["messages"]) == 1 assert instance_data["messages"][0]["role"] == "user" assert instance_data["messages"][0]["content"] == "Hello, world!" @@ -566,7 +621,7 @@ def test_save_assistant_message(self, basic_durable_agent): assistant_message = {"role": "assistant", "content": "Hello back!"} # Set up instance - basic_durable_agent.memory_store._current_state["instances"][instance_id] = { + basic_durable_agent.state["instances"][instance_id] = { "input": "Test task", "source": "test_source", "triggering_workflow_instance_id": None, @@ -579,14 +634,12 @@ def test_save_assistant_message(self, basic_durable_agent): "trace_context": None, } - # Mock memory_store.add_message - with patch.object(type(basic_durable_agent.memory_store), "add_message"): + # Mock memory.add_message + with patch.object(type(basic_durable_agent.memory), "add_message"): basic_durable_agent._save_assistant_message(instance_id, assistant_message) # Verify message was added to instance - instance_data = basic_durable_agent.memory_store._current_state["instances"][ - instance_id - ] + instance_data = basic_durable_agent.state["instances"][instance_id] assert len(instance_data["messages"]) == 1 assert instance_data["messages"][0]["role"] == "assistant" assert instance_data["messages"][0]["content"] == "Hello back!" @@ -597,7 +650,7 @@ def test_get_last_message_from_state(self, basic_durable_agent): instance_id = "test-instance-123" # Set up instance with last_message - basic_durable_agent.memory_store._current_state["instances"][instance_id] = { + basic_durable_agent.state["instances"][instance_id] = { "input": "Test task", "source": "test_source", "triggering_workflow_instance_id": None, @@ -657,7 +710,7 @@ def test_append_tool_message_to_instance(self, basic_durable_agent): instance_id = "test-instance-123" # Set up instance - basic_durable_agent.memory_store._current_state["instances"][instance_id] = { + basic_durable_agent.state["instances"][instance_id] = { "input": "Test task", "source": "test_source", "triggering_workflow_instance_id": None, @@ -684,21 +737,46 @@ def test_append_tool_message_to_instance(self, basic_durable_agent): ) # Verify instance was updated - instance_data = basic_durable_agent.memory_store._current_state["instances"][ - instance_id - ] + instance_data = basic_durable_agent.state["instances"][instance_id] assert len(instance_data["messages"]) == 1 assert instance_data["messages"][0]["role"] == "assistant" assert len(instance_data["tool_history"]) == 1 assert instance_data["tool_history"][0]["tool_call_id"] == "call_123" + def test_update_agent_memory_and_history(self, basic_durable_agent): + """Test _update_agent_memory_and_history helper method.""" + + tool_msg = ToolMessage( + tool_call_id="call_123", name="test_tool", content="Tool result" + ) + tool_history_entry = ToolExecutionRecord( + tool_call_id="call_123", + tool_name="test_tool", + execution_result="tool_result", + ) + + # Mock the memory add_message method + with patch.object( + type(basic_durable_agent.memory), "add_message" + ) as mock_add_message: + basic_durable_agent._update_agent_memory_and_history( + tool_msg, tool_history_entry + ) + + # Verify memory was updated + mock_add_message.assert_called_once_with(tool_msg) + + # Verify agent-level tool_history was updated + assert len(basic_durable_agent.tool_history) == 1 + assert basic_durable_agent.tool_history[0].tool_call_id == "call_123" + def test_construct_messages_with_instance_history(self, basic_durable_agent): """Test _construct_messages_with_instance_history helper method.""" instance_id = "test-instance-123" input_data = "Test input" # Set up instance with messages - basic_durable_agent.memory_store._current_state["instances"][instance_id] = { + basic_durable_agent.state["instances"][instance_id] = { "input": "Test task", "source": "test_source", "triggering_workflow_instance_id": None, @@ -761,15 +839,10 @@ async def test_send_message_to_agent(self, basic_durable_agent): # This needs refactoring / better implementation on this test since the actual implementation would depend on the pubsub msg broker. await basic_durable_agent.send_message_to_agent("TargetAgent", task_response) - def test_register_agent(self, basic_durable_agent): - """Test registering agent.""" + def test_register_agentic_system(self, basic_durable_agent): + """Test registering agentic system.""" # TODO(@Sicoyle): fix this to add assertions. - basic_durable_agent.register_agent( - store_name="test_store", - store_key="agent_registry", - agent_name=basic_durable_agent.name, - agent_metadata="test metadata", - ) + basic_durable_agent.register_agentic_system() @pytest.mark.asyncio async def test_process_broadcast_message(self, basic_durable_agent): @@ -797,8 +870,8 @@ def test_durable_agent_workflow_name(self, basic_durable_agent): def test_durable_agent_state_initialization(self, basic_durable_agent): """Test that the agent state is properly initialized.""" validated_state = DurableAgentWorkflowState.model_validate( - basic_durable_agent.memory_store._current_state + basic_durable_agent.state ) assert isinstance(validated_state, DurableAgentWorkflowState) - assert "instances" in basic_durable_agent.memory_store._current_state - assert basic_durable_agent.memory_store._current_state["instances"] == {} + assert "instances" in basic_durable_agent.state + assert basic_durable_agent.state["instances"] == {} diff --git a/tests/agents/durableagent/test_mcp_streamable_http.py b/tests/agents/durableagent/test_mcp_streamable_http.py index 1d6bcb3a..09f4db8d 100644 --- a/tests/agents/durableagent/test_mcp_streamable_http.py +++ b/tests/agents/durableagent/test_mcp_streamable_http.py @@ -1,8 +1,8 @@ import pytest from unittest.mock import AsyncMock, Mock from dapr_agents.agents.durableagent.agent import DurableAgent -from dapr_agents.agents.memory_store import DurableAgentWorkflowEntry -from dapr_agents.agents.memory_store import DurableAgentWorkflowState +from dapr_agents.agents.durableagent.state import DurableAgentWorkflowEntry +from dapr_agents.agents.durableagent.state import DurableAgentWorkflowState from dapr_agents.tool.base import AgentTool @@ -23,7 +23,32 @@ def _setdefault(self, key, default): DurableAgentWorkflowState.__getitem__ = _getitem DurableAgentWorkflowState.setdefault = _setdefault + # Patch DaprStateStore to use a mock DaprClient that supports context manager + import dapr_agents.storage.daprstores.statestore as statestore + class MockDaprClient: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def save_state(self, *args, **kwargs): + pass + + def get_state(self, *args, **kwargs): + class R: + data = "{}" + etag = "etag" + + return R() + + def execute_state_transaction(self, *args, **kwargs): + pass + + statestore.DaprClient = MockDaprClient + # Patch DaprStateStore to use a mock DaprClient that supports context manager + from dapr_agents.workflow import agentic from dapr_agents.workflow import base # Mock the WorkflowApp initialization to prevent DaprClient creation @@ -40,13 +65,12 @@ def mock_workflow_app_post_init(self, __context): ) # Patch out agent registration logic (skip state store entirely) - def mock_register_agent(self, store_name, store_key, agent_name, agent_metadata): + def mock_register_agentic_system(self): pass - # register_agent is now in AgentBase - from dapr_agents.agents.base import AgentBase - - monkeypatch.setattr(AgentBase, "register_agent", mock_register_agent) + monkeypatch.setattr( + agentic.AgenticWorkflow, "register_agentic_system", mock_register_agentic_system + ) yield @@ -115,20 +139,19 @@ async def fake_call_tool(*args, **kwargs): @pytest.fixture def durable_agent_with_mcp_tool(mock_mcp_tool, mock_mcp_session): from dapr_agents.tool.executor import AgentToolExecutor - from dapr_agents.agents.memory_store import MemoryStore agent_tool = AgentTool.from_mcp(mock_mcp_tool, session=mock_mcp_session) tool_executor = AgentToolExecutor(tools=[agent_tool]) - memory_store = MemoryStore(name="teststatestore") agent = DurableAgent( name="TestDurableAgent", role="Math Assistant", goal="Help humans do math", instructions=["Test math instructions"], tools=[agent_tool], - memory_store=memory_store, state=DurableAgentWorkflowState().model_dump(), + state_store_name="teststatestore", message_bus_name="testpubsub", + agents_registry_store_name="testregistry", ) agent.__pydantic_private__["_tool_executor"] = tool_executor return agent @@ -150,9 +173,7 @@ async def test_execute_tool_activity_with_mcp_tool(durable_agent_with_mcp_tool): "end_time": None, "trace_context": None, } - durable_agent_with_mcp_tool.memory_store._current_state["instances"][ - instance_id - ] = workflow_entry + durable_agent_with_mcp_tool.state["instances"][instance_id] = workflow_entry # Print available tool names for debugging tool_names = [t.name for t in durable_agent_with_mcp_tool.tool_executor.tools] @@ -170,9 +191,7 @@ async def test_execute_tool_activity_with_mcp_tool(durable_agent_with_mcp_tool): await durable_agent_with_mcp_tool.run_tool( tool_call, instance_id, "2024-01-01T00:00:00Z" ) - instance_data = durable_agent_with_mcp_tool.memory_store._current_state[ - "instances" - ][instance_id] + instance_data = durable_agent_with_mcp_tool.state["instances"][instance_id] assert len(instance_data["tool_history"]) == 1 tool_entry = instance_data["tool_history"][0] assert tool_entry["tool_call_id"] == "call_123" @@ -206,53 +225,11 @@ def start_math_server_http(): async def get_agent_tools_from_http(): from dapr_agents.tool.mcp import MCPClient - try: - client = MCPClient() - await client.connect_streamable_http( - server_name="local", url="http://localhost:8000/mcp/" - ) - return client.get_all_tools() - except Exception: - # Fallback to a mocked tool list if server is unavailable - from dapr_agents.tool.base import AgentTool - from unittest.mock import AsyncMock, Mock - import json - - # Minimal mock MCP tool and session mirroring the add tool - mcp_tool = Mock() - mcp_tool.name = "add" - mcp_tool.description = "Add two numbers" - mcp_tool.inputSchema = { - "type": "object", - "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}}, - "required": ["a", "b"], - } - - async def fake_call_tool(*args, **kwargs): - a = int(kwargs.get("a", 0)) - b = int(kwargs.get("b", 0)) - if not ("a" in kwargs and "b" in kwargs) and args: - try: - data = ( - json.loads(args[-1]) if isinstance(args[-1], str) else args[-1] - ) - a = ( - int(getattr(data, "get", lambda k, d=None: d)("a", a)) - if hasattr(data, "get") - else a - ) - b = ( - int(getattr(data, "get", lambda k, d=None: d)("b", b)) - if hasattr(data, "get") - else b - ) - except Exception: - pass - return str(a + b) - - session = Mock() - session.call_tool = AsyncMock(side_effect=fake_call_tool) - return [AgentTool.from_mcp(mcp_tool, session=session)] + client = MCPClient() + await client.connect_streamable_http( + server_name="local", url="http://localhost:8000/mcp/" + ) + return client.get_all_tools() @pytest.mark.asyncio @@ -276,19 +253,18 @@ async def test_add_tool_with_real_server_http(start_math_server_http): async def test_durable_agent_with_real_server_http(start_math_server_http): agent_tools = await get_agent_tools_from_http() from dapr_agents.tool.executor import AgentToolExecutor - from dapr_agents.agents.memory_store import MemoryStore tool_executor = AgentToolExecutor(tools=agent_tools) - memory_store = MemoryStore(name="teststatestore") agent = DurableAgent( name="TestDurableAgent", role="Math Assistant", goal="Help humans do math", instructions=["Test math instructions"], tools=agent_tools, - memory_store=memory_store, state=DurableAgentWorkflowState().model_dump(), + state_store_name="teststatestore", message_bus_name="testpubsub", + agents_registry_store_name="testregistry", ) agent.__pydantic_private__["_tool_executor"] = tool_executor instance_id = "test-instance-456" @@ -304,7 +280,7 @@ async def test_durable_agent_with_real_server_http(start_math_server_http): "end_time": None, "trace_context": None, } - agent.memory_store._current_state["instances"][instance_id] = workflow_entry + agent.state["instances"][instance_id] = workflow_entry # Print available tool names tool_names = [t.name for t in agent.tool_executor.tools] print("Available tool names (integration test):", tool_names) @@ -317,7 +293,7 @@ async def test_durable_agent_with_real_server_http(start_math_server_http): "function": {"name": tool_name, "arguments": '{"a": 2, "b": 2}'}, } await agent.run_tool(tool_call, instance_id, "2024-01-01T00:00:00Z") - instance_data = agent.memory_store._current_state["instances"][instance_id] + instance_data = agent.state["instances"][instance_id] assert len(instance_data["tool_history"]) == 1 tool_entry = instance_data["tool_history"][0] assert tool_entry["tool_call_id"] == "call_456" diff --git a/tests/agents/mocks/memory.py b/tests/agents/mocks/memory.py index 99e158b0..f54df2d9 100644 --- a/tests/agents/mocks/memory.py +++ b/tests/agents/mocks/memory.py @@ -27,3 +27,6 @@ def add_interaction( self, user_message: BaseMessage, assistant_message: BaseMessage ): pass + + def reset_memory(self): + pass diff --git a/tests/agents/test_base.py b/tests/agents/test_base.py index 75ae41af..d7aa1d83 100644 --- a/tests/agents/test_base.py +++ b/tests/agents/test_base.py @@ -20,10 +20,6 @@ def run(self, input_data): """Implementation of abstract method for testing.""" return f"Processed: {input_data}" - def get_chat_history(self, task=None): - """Implementation of abstract method for testing.""" - return self.memory_store.get_messages() - class TestAgentBaseClass: """Test cases for AgentBase class.""" @@ -81,6 +77,7 @@ def test_agent_creation_with_all_fields(self, basic_agent): assert basic_agent.instructions == ["Test instruction 1", "Test instruction 2"] assert basic_agent.max_iterations == 10 assert basic_agent.template_format == "jinja2" + assert isinstance(basic_agent.memory, ConversationListMemory) assert basic_agent.llm is not None def test_agent_creation_with_minimal_fields(self, minimal_agent): @@ -187,13 +184,14 @@ def test_get_last_message_empty_memory(self, basic_agent): assert basic_agent.get_last_message() is None def test_get_last_message_with_memory(self, basic_agent): - """Test getting last message from storage with content.""" + """Test getting last message from memory with content.""" # Use a dictionary as the mock message mock_message = {"foo": "bar"} - # Add a message to storage directly - basic_agent.memory_store._in_memory_messages = [mock_message] - result = basic_agent.get_last_message() - assert result == {"foo": "bar"} + with patch.object( + ConversationListMemory, "get_messages", return_value=[mock_message] + ): + result = basic_agent.get_last_message() + assert result == {"foo": "bar"} def test_get_last_user_message(self, basic_agent): """Test getting last user message from message list.""" @@ -218,6 +216,12 @@ def test_get_last_user_message_no_user_messages(self, basic_agent): result = basic_agent.get_last_user_message(messages) assert result is None + def test_reset_memory(self, basic_agent): + """Test memory reset.""" + with patch.object(type(basic_agent.memory), "reset_memory") as mock_reset: + basic_agent.reset_memory() + mock_reset.assert_called_once() + def test_pre_fill_prompt_template(self, basic_agent): """Test pre-filling prompt template with variables.""" # Store original template for comparison @@ -252,26 +256,33 @@ def test_pre_fill_prompt_template_without_template(self, mock_llm_client): agent.pre_fill_prompt_template(custom_var="test_value") def test_chat_history_with_vector_memory_and_task(self): - """Test chat history retrieval with storage.""" + """Test chat history retrieval with vector memory and task.""" + from tests.agents.mocks.vectorstore import MockVectorStore + from tests.agents.mocks.memory import DummyVectorMemory + + mock_vector_store = MockVectorStore() mock_llm = MockLLMClient() - agent = TestAgentBase(llm=mock_llm) + memory = DummyVectorMemory(mock_vector_store) + agent = TestAgentBase(memory=memory, llm=mock_llm) - # Add messages to storage directly - agent.memory_store._in_memory_messages = [Mock(), Mock()] + # Access chat_history as a property result = agent.chat_history assert isinstance(result, list) - if len(result) > 0: - assert isinstance(result[0], Mock) + assert isinstance(result[0], Mock) def test_chat_history_with_regular_memory(self, mock_llm_client): - """Test chat history retrieval with storage.""" - agent = TestAgentBase(llm=mock_llm_client) - - # Add a message to storage directly - agent.memory_store._in_memory_messages = [Mock(spec=MessageContent)] - result = agent.chat_history - assert isinstance(result, list) - assert isinstance(result[0], Mock) + """Test chat history retrieval with regular memory.""" + memory = ConversationListMemory() + agent = TestAgentBase(memory=memory, llm=mock_llm_client) + + with patch.object( + ConversationListMemory, + "get_messages", + return_value=[Mock(spec=MessageContent)], + ): + result = agent.chat_history + assert isinstance(result, list) + assert isinstance(result[0], Mock) def test_prefill_agent_attributes_missing_fields_warns( self, mock_llm_client, caplog @@ -313,13 +324,13 @@ def test_validate_llm_openai_without_api_key(self, monkeypatch): ): TestAgentBase(llm=OpenAIChatClient()) - def test_validate_storage_failure(self, mock_llm_client): - """Test validation fails when storage initialization fails.""" + def test_validate_memory_failure(self, mock_llm_client): + """Test validation fails when memory initialization fails.""" with patch( - "dapr_agents.agents.memory_store.MemoryStore.__init__", - side_effect=Exception("Storage error"), + "dapr_agents.memory.ConversationListMemory.__new__", + side_effect=Exception("Memory error"), ): - with pytest.raises(Exception, match="Storage error"): + with pytest.raises(Exception, match="Memory error"): TestAgentBase(llm=mock_llm_client) def test_signal_handler_setup(self, basic_agent): diff --git a/tests/workflow/orchestrators/test_random.py b/tests/workflow/orchestrators/test_random.py index bf115cb3..c438dd2e 100644 --- a/tests/workflow/orchestrators/test_random.py +++ b/tests/workflow/orchestrators/test_random.py @@ -2,7 +2,6 @@ import pytest from unittest.mock import MagicMock, patch from dapr_agents.workflow.orchestrators import RandomOrchestrator -from dapr_agents.agents.memory_store import MemoryStore @pytest.fixture @@ -11,8 +10,8 @@ def orchestrator_config(): return { "name": "test_orchestrator", "message_bus_name": "test-message-bus", - "memory_store": MemoryStore.model_construct(name="test-state-store"), - "registry_store": "test-registry-store", + "state_store_name": "test-state-store", + "agents_registry_store_name": "test-registry-store", } @@ -22,7 +21,12 @@ def test_random_orchestrator_initialization(orchestrator_config): "dapr_agents.workflow.orchestrators.random.OrchestratorWorkflowBase.model_post_init" ) as mock_init, patch( "dapr_agents.workflow.agentic.AgenticWorkflow.model_post_init" - ), patch("dapr_agents.workflow.agentic.AgenticWorkflow._dapr_client") as mockclient: + ), patch( + "dapr_agents.workflow.agentic.AgenticWorkflow._state_store_client" + ) as mock_state_store, patch( + "dapr_agents.workflow.agentic.AgenticWorkflow._dapr_client" + ) as mockclient: + mock_state_store.return_value = MagicMock() mockclient.return_value = MagicMock() orchestrator = RandomOrchestrator(**orchestrator_config) assert orchestrator.name == "test_orchestrator" @@ -36,8 +40,11 @@ async def test_process_input(orchestrator_config): with patch( "dapr_agents.workflow.orchestrators.random.OrchestratorWorkflowBase.model_post_init" ), patch("dapr_agents.workflow.agentic.AgenticWorkflow.model_post_init"), patch( + "dapr_agents.workflow.agentic.AgenticWorkflow._state_store_client" + ) as mock_state_store, patch( "dapr_agents.workflow.agentic.AgenticWorkflow._dapr_client" ) as mockclient: + mock_state_store.return_value = MagicMock() mockclient.return_value = MagicMock() orchestrator = RandomOrchestrator(**orchestrator_config) task = "test task" @@ -53,17 +60,18 @@ def test_select_random_speaker(orchestrator_config): with patch( "dapr_agents.workflow.orchestrators.random.OrchestratorWorkflowBase.model_post_init" ), patch("dapr_agents.workflow.agentic.AgenticWorkflow.model_post_init"), patch( + "dapr_agents.workflow.agentic.AgenticWorkflow._state_store_client" + ) as mock_state_store, patch( "dapr_agents.workflow.agentic.AgenticWorkflow._dapr_client" + ) as mockclient, patch.object( + RandomOrchestrator, + "get_agents_metadata", + return_value={"agent1": {"name": "agent1"}, "agent2": {"name": "agent2"}}, ): + mock_state_store.return_value = MagicMock() + mockclient.return_value = MagicMock() orchestrator = RandomOrchestrator(**orchestrator_config) - with patch.object( - RandomOrchestrator, - "get_agents_metadata", - return_value={ - "agent1": {"name": "agent1"}, - "agent2": {"name": "agent2"}, - }, - ): - speaker = orchestrator.select_random_speaker() + + speaker = orchestrator.select_random_speaker() assert speaker in ["agent1", "agent2"] assert orchestrator.current_speaker == speaker