diff --git a/dapr_agents/__init__.py b/dapr_agents/__init__.py index 5ed603a2..b6bc9a54 100644 --- a/dapr_agents/__init__.py +++ b/dapr_agents/__init__.py @@ -1,5 +1,5 @@ -from dapr_agents.agents.agent import Agent -from dapr_agents.agents.durableagent import DurableAgent +from dapr_agents.agents.standalone import Agent +from dapr_agents.agents.durable import DurableAgent from dapr_agents.executors import DockerCodeExecutor, LocalCodeExecutor from dapr_agents.llm.dapr import DaprChatClient from dapr_agents.llm.elevenlabs import ElevenLabsSpeechClient @@ -11,12 +11,10 @@ OpenAIEmbeddingClient, ) from dapr_agents.tool import AgentTool, tool -from dapr_agents.workflow import ( - AgenticWorkflow, +from dapr_agents.agents.orchestrators import ( LLMOrchestrator, RandomOrchestrator, RoundRobinOrchestrator, - WorkflowApp, ) __all__ = [ @@ -38,5 +36,4 @@ "LLMOrchestrator", "RandomOrchestrator", "RoundRobinOrchestrator", - "WorkflowApp", ] diff --git a/dapr_agents/agents/__init__.py b/dapr_agents/agents/__init__.py index 0c7b8b70..e76def36 100644 --- a/dapr_agents/agents/__init__.py +++ b/dapr_agents/agents/__init__.py @@ -1,5 +1,5 @@ -from .agent.agent import Agent from .base import AgentBase -from .durableagent.agent import DurableAgent +from .standalone import Agent +from .durable import DurableAgent __all__ = ["AgentBase", "Agent", "DurableAgent"] diff --git a/dapr_agents/agents/agent/__init__.py b/dapr_agents/agents/agent/__init__.py deleted file mode 100644 index a6df24ad..00000000 --- a/dapr_agents/agents/agent/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .agent import Agent - -__all__ = ["Agent"] diff --git a/dapr_agents/agents/agent/agent.py b/dapr_agents/agents/agent/agent.py deleted file mode 100644 index 8beb1791..00000000 --- a/dapr_agents/agents/agent/agent.py +++ /dev/null @@ -1,269 +0,0 @@ -import asyncio -import logging -from typing import Any, Dict, List, Optional, Union - -from dapr_agents.agents.base import AgentBase -from dapr_agents.types import ( - AgentError, - ToolCall, - ToolExecutionRecord, - ToolMessage, - UserMessage, - LLMChatResponse, -) - -logger = logging.getLogger(__name__) - - -class Agent(AgentBase): - """ - Agent that manages tool calls and conversations using a language model. - It integrates tools and processes them based on user inputs and task orchestration. - """ - - async def run(self, input_data: Optional[Union[str, Dict[str, Any]]] = None) -> Any: - """ - Runs the agent with the given input, supporting graceful shutdown. - Uses the _race helper to handle shutdown and cancellation cleanly. - - Args: - input_data (Optional[Union[str, Dict[str, Any]]]): Input for the agent, can be a string or dict. - Returns: - Any: The result of agent execution, or None if shutdown is requested. - """ - try: - return await self._race(self._run_agent(input_data)) - except asyncio.CancelledError: - logger.info("Agent execution was cancelled.") - return None - except Exception as e: - logger.error(f"Error during agent execution: {e}") - raise - - async def _race(self, coro) -> Optional[Any]: - """ - Runs the given coroutine and races it against the agent's shutdown event. - If shutdown is triggered, cancels the task and returns None. - - Args: - coro: The coroutine to run (e.g., _run_agent(input_data)). - Returns: - Optional[Any]: The result of the coroutine, or None if shutdown is triggered. - """ - task = asyncio.create_task(coro) - shutdown_task = asyncio.create_task(self._shutdown_event.wait()) - done, pending = await asyncio.wait( - [task, shutdown_task], - return_when=asyncio.FIRST_COMPLETED, - ) - for p in pending: - p.cancel() - if self._shutdown_event.is_set(): - logger.info("Shutdown requested during execution. Cancelling agent.") - task.cancel() - return None - return await task - - async def _run_agent( - self, input_data: Optional[Union[str, Dict[str, Any]]] = None - ) -> Any: - """ - Internal method for running the agent logic. - Formats messages, updates memory, and drives the conversation loop. - - Args: - input_data (Optional[Union[str, Dict[str, Any]]]): Input for the agent, can be a string or dict. - Returns: - Any: The result of the agent's conversation loop. - """ - logger.debug( - f"Agent run started with input: {input_data if input_data else 'Using memory context'}" - ) - - # Construct messages using only input_data; chat history handled internally - messages: List[Dict[str, Any]] = self.construct_messages(input_data or {}) - user_message = self.get_last_user_message(messages) - # Always work with a copy of the user message for safety - user_message_copy: Optional[Dict[str, Any]] = ( - dict(user_message) if user_message else None - ) - - if input_data and user_message_copy: - # Add the new user message to memory only if input_data is provided and user message exists - user_msg = UserMessage(content=user_message_copy.get("content", "")) - self.memory.add_message(user_msg) - - # Always print the last user message for context, even if no input_data is provided - if user_message_copy is not None: - # Ensure keys are str for mypy - self.text_formatter.print_message( - {str(k): v for k, v in user_message_copy.items()} - ) - - # Process conversation iterations and return the result - return await self.conversation(messages) - - async def execute_tools(self, tool_calls: List[ToolCall]) -> List[ToolMessage]: - """ - Executes a batch of tool calls in parallel, bounded by max_concurrent, using asyncio.gather. - Each tool call is executed asynchronously using run_tool, and results are appended to the persistent audit log (tool_history). - If any tool call fails, the error is propagated and other tasks continue unless you set return_exceptions=True. - - Args: - tool_calls (List[ToolCall]): List of tool calls returned by the LLM to execute in this batch. - max_concurrent (int, optional): Maximum number of concurrent tool executions (default: 5). - - Returns: - List[ToolMessage]: Results for this batch of tool calls, in the same order as input. - - Raises: - AgentError: If any tool execution fails. - """ - # Limiting concurrency to avoid overwhelming downstream systems - max_concurrent = 10 - semaphore = asyncio.Semaphore(max_concurrent) - - async def run_and_record(tool_call: ToolCall) -> ToolMessage: - """ - Executes a single tool call, respecting the concurrency limit. - Appends the result to the persistent audit log. - If the function name is missing, returns a ToolMessage with error status and raises AgentError. - """ - async with semaphore: - function_name = tool_call.function.name - tool_id = tool_call.id - function_args = tool_call.function.arguments_dict - - if not function_name: - error_msg = f"Tool call missing function name: {tool_call}" - logger.error(error_msg) - # Return a ToolExecutionRecord with error status and raise AgentError - tool_execution_record = ToolExecutionRecord( - tool_call_id="", - tool_name="", - tool_args={}, - execution_result=error_msg, - ) - self.tool_history.append(tool_execution_record) - raise AgentError(error_msg) - - try: - logger.debug( - f"Executing {function_name} with arguments {function_args}" - ) - result = await self.run_tool(function_name, **function_args) - result_str = str(result) if result is not None else "" - tool_message = ToolMessage( - tool_call_id=tool_id, - name=function_name, - content=result_str, - ) - # Print the tool message for visibility - self.text_formatter.print_message(tool_message) - # Add tool message to memory - self.memory.add_message(tool_message) - # Append tool message to the persistent audit log - tool_execution_record = ToolExecutionRecord( - tool_call_id=tool_id, - tool_name=function_name, - tool_args=function_args, - execution_result=result_str, - ) - self.tool_history.append(tool_execution_record) - return tool_message - except Exception as e: - logger.error(f"Error executing tool {function_name}: {e}") - raise AgentError( - f"Error executing tool '{function_name}': {e}" - ) from e - - # Run all tool calls concurrently, but bounded by max_concurrent - return await asyncio.gather(*(run_and_record(tc) for tc in tool_calls)) - - async def conversation(self, messages: List[Dict[str, Any]]) -> Any: - """ - Drives the agent conversation iteratively until a final answer or max iterations is reached. - Handles tool calls, updates memory, and returns the final assistant message. - Tool results are localized per iteration; persistent audit log is kept for all tool executions. - - Args: - messages (List[Dict[str, Any]]): Initial conversation messages. - Returns: - Any: The final assistant message or None if max iterations reached. - Raises: - AgentError: On chat failure or tool issues. - """ - final_reply = None - for turn in range(1, self.max_iterations + 1): - logger.info(f"Iteration {turn}/{self.max_iterations} started.") - try: - # Generate response using the LLM - response: LLMChatResponse = self.llm.generate( - messages=messages, - tools=self.get_llm_tools(), - **( - {"tool_choice": self.tool_choice} - if self.tool_choice is not None - else {} - ), - ) - # Get the first candidate from the response - response_message = response.get_message() - # Check if the response contains an assistant message - if response_message is None: - raise AgentError("LLM returned no assistant message") - else: - assistant = response_message - self.text_formatter.print_message(assistant) - self.memory.add_message(assistant) - - # Handle tool calls response - if assistant is not None and assistant.has_tool_calls(): - tool_calls = assistant.get_tool_calls() - if tool_calls: - messages.append(assistant.model_dump()) - tool_msgs = await self.execute_tools(tool_calls) - messages.extend([tm.model_dump() for tm in tool_msgs]) - if turn == self.max_iterations: - final_reply = assistant - logger.info("Reached max turns after tool calls; stopping.") - break - continue - - # No tool calls => done - final_reply = assistant - break - - except Exception as e: - logger.error(f"Error on turn {turn}: {e}") - raise AgentError(f"Failed during chat generation: {e}") from e - - # Post-loop - if final_reply is None: - logger.warning("No reply generated; hitting max iterations.") - return None - - logger.info(f"Agent conversation completed after {turn} turns.") - return final_reply - - async def run_tool(self, tool_name: str, *args, **kwargs) -> Any: - """ - Executes a single registered tool by name, handling both sync and async tools. - Used for atomic tool execution, either directly or as part of a batch in execute_tools. - - Args: - tool_name (str): Name of the tool to run. - *args: Positional arguments for the tool. - **kwargs: Keyword arguments for the tool. - - Returns: - Any: Result from the tool execution. - - Raises: - AgentError: If the tool is not found or execution fails. - """ - try: - return await self.tool_executor.run_tool(tool_name, *args, **kwargs) - except Exception as e: - logger.error(f"Agent failed to run tool '{tool_name}': {e}") - raise AgentError(f"Failed to run tool '{tool_name}': {e}") from e diff --git a/dapr_agents/agents/base.py b/dapr_agents/agents/base.py index e7bfe864..0e6c3997 100644 --- a/dapr_agents/agents/base.py +++ b/dapr_agents/agents/base.py @@ -1,557 +1,664 @@ -from dapr_agents.memory import ( - MemoryBase, - ConversationListMemory, - ConversationVectorMemory, -) -from dapr_agents.agents.utils.text_printer import ColorTextFormatter -from dapr_agents.types import MessagePlaceHolder, BaseMessage, ToolExecutionRecord -from dapr_agents.tool.executor import AgentToolExecutor -from dapr_agents.prompt.base import PromptTemplateBase -from dapr_agents.prompt import ChatPromptTemplate -from dapr_agents.tool.base import AgentTool -import re -from datetime import datetime -import logging +from __future__ import annotations + import asyncio -import signal -from abc import ABC, abstractmethod -from typing import ( - List, - Optional, - Dict, - Any, - Union, - Callable, - Literal, - ClassVar, +import logging +from datetime import datetime, timezone +from typing import Any, Dict, Iterable, List, Optional, Sequence, Union, Coroutine + +from dapr_agents.agents.components import AgentComponents +from dapr_agents.agents.configs import ( + AgentMemoryConfig, + AgentPubSubConfig, + AgentRegistryConfig, + AgentStateConfig, + AgentExecutionConfig, ) -from pydantic import BaseModel, Field, PrivateAttr, model_validator, ConfigDict +from dapr_agents.agents.prompting import AgentProfileConfig, PromptingAgentBase +from dapr_agents.agents.utils.text_printer import ColorTextFormatter from dapr_agents.llm.chat import ChatClientBase from dapr_agents.llm.utils.defaults import get_default_llm +from dapr_agents.memory import ConversationDaprStateMemory, ConversationListMemory +from dapr_agents.prompt.base import PromptTemplateBase +from dapr_agents.storage.daprstores.stateservice import StateStoreError +from dapr_agents.tool.base import AgentTool +from dapr_agents.tool.executor import AgentToolExecutor +from dapr_agents.types import AssistantMessage, ToolExecutionRecord, UserMessage logger = logging.getLogger(__name__) -class AgentBase(BaseModel, ABC): - """ - Base class for agents that interact with language models and manage tools for task execution. - - Args: - name: Agent name - role: Agent role - goal: Agent goal - instructions: List of instructions - tools: List of tools - llm: LLM client - memory: Memory instance +class AgentBase(AgentComponents): """ + Base class for agent behavior. - name: str = Field( - default="Dapr Agent", - description="The agent's name, defaulting to the role if not provided.", - ) - role: Optional[str] = Field( - default="Assistant", - description="The agent's role in the interaction (e.g., 'Weather Expert').", - ) - goal: Optional[str] = Field( - default="Help humans", - description="The agent's main objective (e.g., 'Provide Weather information').", - ) - # TODO: add a background/backstory field that would be useful for the agent to know about it's context/background for it's role. - instructions: Optional[List[str]] = Field( - default=None, description="Instructions guiding the agent's tasks." - ) - system_prompt: Optional[str] = Field( - default=None, - description="A custom system prompt, overriding name, role, goal, and instructions.", - ) - llm: Optional[ChatClientBase] = Field( - default=None, - description="Language model client for generating responses.", - ) - prompt_template: Optional[PromptTemplateBase] = Field( - default=None, description="The prompt template for the agent." - ) - # TODO: we need to add RBAC to tools to define what users and/or agents can use what tool(s). - tools: List[Union[AgentTool, Callable]] = Field( - default_factory=list, - description="Tools available for the agent to assist with tasks.", - ) - tool_choice: Optional[str] = Field( - default=None, - description="Strategy for selecting tools ('auto', 'required', 'none'). Defaults to 'auto' if tools are provided.", - ) - tool_history: List[ToolExecutionRecord] = Field( - default_factory=list, description="Executed tool calls during the conversation." - ) - # TODO: add a forceFinalAnswer field in case maxIterations is near/reached. Or do we have a conclusion baked in by default? Do we want this to derive a conclusion by default? - max_iterations: int = Field( - default=10, description="Max iterations for conversation cycles." - ) - # TODO(@Sicoyle): Rename this to make clearer - memory: MemoryBase = Field( - default_factory=ConversationListMemory, - description="Handles long-term conversation history (for all workflow instance-ids within the same session) and context storage.", - ) - # TODO: we should have a system_template, prompt_template, and response_template, or better separation here. - # If we have something like a customer service agent, we want diff templates for different types of interactions. - # In future, we could also have a way to dynamically change the template based on the context of the interaction. - template_format: Literal["f-string", "jinja2"] = Field( - default="jinja2", - description="The format used for rendering the prompt template.", - ) - - DEFAULT_SYSTEM_PROMPT: ClassVar[str] - """Default f-string template; placeholders will be swapped to Jinja if needed.""" - DEFAULT_SYSTEM_PROMPT = """ -# Today's date is: {date} - -## Name -Your name is {name}. - -## Role -Your role is {role}. - -## Goal -{goal}. - -## Instructions -{instructions}. -""".strip() - - _tool_executor: AgentToolExecutor = PrivateAttr() - _text_formatter: ColorTextFormatter = PrivateAttr( - default_factory=ColorTextFormatter - ) - - model_config = ConfigDict(arbitrary_types_allowed=True) - - @model_validator(mode="before") - def set_name_from_role(cls, values: dict): - # Set name to role if name is not provided - if not values.get("name") and values.get("role"): - values["name"] = values["role"] - return values - - @model_validator(mode="after") - def validate_llm(self): - """Validate that LLM is properly configured.""" - if hasattr(self, "llm"): - if self.llm is None: - logger.warning("LLM client is None, some functionality may be limited.") - else: - try: - # Validate LLM is properly configured by accessing it as this is required to be set. - _ = self.llm - except Exception as e: - logger.error(f"Failed to initialize LLM: {e}") - self.llm = None + Responsibilities: + - Profile/prompt wiring (system prompt, instructions, style, template). + - LLM client wiring. + - Tool exposure and execution adapter. + - Conversation memory management (configurable; defaults provided). - return self + Infrastructure (pub/sub, durable state, registry) is provided by `AgentComponents`. + """ - def model_post_init(self, __context: Any) -> None: - """ - Post-initialization hook for AgentBase. - Sets up the prompt template using a centralized helper, ensuring agent and LLM client reference the same template. - Also validates and pre-fills the template, and sets up graceful shutdown. + def __init__( + self, + *, + # Profile / prompt + profile_config: Optional[AgentProfileConfig] = None, + name: Optional[str] = None, + role: Optional[str] = None, + goal: Optional[str] = None, + instructions: Optional[Iterable[str]] = None, + style_guidelines: Optional[Iterable[str]] = None, + system_prompt: Optional[str] = None, + prompt_template: Optional[PromptTemplateBase] = None, + # Components (infrastructure) + pubsub_config: Optional[AgentPubSubConfig] = None, + state_config: Optional[AgentStateConfig] = None, + registry_config: Optional[AgentRegistryConfig] = None, + base_metadata: Optional[Dict[str, Any]] = None, + max_etag_attempts: int = 10, + # Memory / runtime + memory_config: Optional[AgentMemoryConfig] = None, + llm: Optional[ChatClientBase] = None, + tools: Optional[Iterable[Any]] = None, + # Metadata + agent_metadata: Optional[Dict[str, Any]] = None, + # Execution + execution_config: Optional[AgentExecutionConfig] = None, + ) -> None: + """ + Initialize an agent with behavior + infrastructure. Args: - __context (Any): Context passed from Pydantic's model initialization. - """ - self._tool_executor = AgentToolExecutor(tools=self.tools) - - # Set tool_choice to 'auto' if tools are provided, otherwise None - if self.tool_choice is None: - self.tool_choice = "auto" if self.tools else None - - # Initialize LLM if not provided - if self.llm is None: - self.llm = get_default_llm() + profile_config: Base profile config (name/role/goal/prompts). Optional if + individual fields are provided below. + name: Agent name (required if `profile_config` is omitted). + role: Agent role (e.g., "Assistant"). + goal: High-level agent objective. + instructions: Additional instruction strings for the prompt. + style_guidelines: Style directives for the prompt. + system_prompt: System prompt override. + prompt_template: Optional explicit prompt template instance. + + pubsub_config: Pub/Sub config used by `AgentComponents`. + state_config: Durable state config used by `AgentComponents`. + registry_config: Team registry config used by `AgentComponents`. + execution_config: Execution dials for the agent run. + base_metadata: Default Dapr state metadata used by `AgentComponents`. + max_etag_attempts: Concurrency retry count for registry mutations. + + memory_config: Memory backend configuration. If omitted and a state store + is configured, a Dapr-backed conversation memory is created by default. + llm: Chat client. Defaults to `get_default_llm()`. + tools: Optional tool callables or `AgentTool` instances. + + agent_metadata: Extra metadata to store in the registry. + """ + # Resolve and validate profile (ensures non-empty name). + resolved_profile = self._build_profile( + base_profile=profile_config, + name=name, + role=role, + goal=goal, + instructions=instructions, + style_guidelines=style_guidelines, + system_prompt=system_prompt, + ) + self.profile_config = resolved_profile + self.name = resolved_profile.name # type: ignore[assignment] + + # Wire infrastructure via AgentComponents. + super().__init__( + name=self.name, + pubsub_config=pubsub_config, + state_config=state_config, + registry_config=registry_config, + base_metadata=base_metadata, + max_etag_attempts=max_etag_attempts, + ) - # Centralize prompt template selection logic - self.prompt_template = self._initialize_prompt_template() - # Ensure LLM client and agent both reference the same template - if self.llm is not None: + # ----------------------------- + # Memory wiring + # ----------------------------- + self._memory_config = memory_config or AgentMemoryConfig() + if self._memory_config.store is None and state_config is not None: + # Auto-provision a Dapr-backed memory if we have a state store. + self._memory_config.store = ConversationDaprStateMemory( # type: ignore[union-attr] + store_name=state_config.store.store_name, + session_id=f"{self.name}-session", + ) + self.memory = self._memory_config.store or ConversationListMemory() + + # ----------------------------- + # Prompting helper + # ----------------------------- + self.prompting_helper = PromptingAgentBase( + name=self.name, + role=resolved_profile.role or "Assistant", + goal=resolved_profile.goal or "Help users accomplish their tasks.", + instructions=list(resolved_profile.instructions), + style_guidelines=list(resolved_profile.style_guidelines), + system_prompt=resolved_profile.system_prompt, + template_format=resolved_profile.template_format, + include_chat_history=True, + prompt_template=prompt_template, + profile_config=resolved_profile, + ) + # Keep profile config synchronized with helper defaults. + if self.profile_config.name is None: + self.profile_config.name = self.prompting_helper.name + if self.profile_config.role is None: + self.profile_config.role = self.prompting_helper.role + if self.profile_config.goal is None: + self.profile_config.goal = self.prompting_helper.goal + + self.prompt_template = self.prompting_helper.prompt_template + self._text_formatter = self.prompting_helper.text_formatter + + # ----------------------------- + # LLM wiring + # ----------------------------- + self.llm: ChatClientBase = llm or get_default_llm() + if self.llm: self.llm.prompt_template = self.prompt_template - self._validate_prompt_template() - self.prefill_agent_attributes() + # ----------------------------- + # Tools + # ----------------------------- + self.tools: List[Any] = list(tools or []) + self.tool_executor = AgentToolExecutor(tools=list(self.tools)) + self.tool_history: List[ToolExecutionRecord] = [] + + # ----------------------------- + # Execution config + # ----------------------------- + self.execution_config = execution_config or AgentExecutionConfig() + try: + self.execution_config.max_iterations = max( + 1, int(self.execution_config.max_iterations) + ) + except Exception: + self.execution_config.max_iterations = 10 + self.execution_config.tool_choice = self.execution_config.tool_choice or "auto" - # Set up graceful shutdown - self._shutdown_event = asyncio.Event() - self._setup_signal_handlers() + # ----------------------------- + # Load durable state (from AgentComponents) + # ----------------------------- + try: + self.load_state() + except Exception: # noqa: BLE001 + logger.warning("Agent failed to load persisted state; starting fresh.") + + # ----------------------------- + # Agent metadata & registry registration (from AgentComponents) + # ----------------------------- + base_meta: Dict[str, Any] = { + "name": self.name, + "orchestrator": False, + "role": self.prompting_helper.role, + "goal": self.prompting_helper.goal, + "instructions": list(self.prompting_helper.instructions), + } + if self.pubsub_config is not None: + base_meta["topic_name"] = self.agent_topic_name + base_meta["pubsub_name"] = self.message_bus_name + + merged_meta = {**base_meta, **(agent_metadata or {})} + self.agent_metadata = merged_meta + if self.registry_state is not None: + try: + self.register_agentic_system(metadata=merged_meta) + except StateStoreError: + logger.warning( + "Could not register agent metadata; registry unavailable." + ) + else: + logger.debug( + "Registry configuration not provided; skipping agent registration." + ) - super().model_post_init(__context) + # ------------------------------------------------------------------ + # Presentation helpers + # ------------------------------------------------------------------ + @property + def text_formatter(self) -> ColorTextFormatter: + """Formatter used for human-friendly console output.""" + return self._text_formatter + + @text_formatter.setter + def text_formatter(self, formatter: ColorTextFormatter) -> None: + """Override the default text formatter and keep the helper in sync.""" + self._text_formatter = formatter + if hasattr(self, "prompting_helper"): + self.prompting_helper._text_formatter = formatter - def _initialize_prompt_template(self) -> PromptTemplateBase: + def print_interaction( + self, source_agent_name: str, target_agent_name: str, message: str + ) -> None: """ - Determines which prompt template to use for the agent: - 1. If the user supplied one, use it. - 2. Else if the LLM client already has one, adopt that. - 3. Else generate a system_prompt and ChatPromptTemplate from agent attributes. + Print a formatted interaction between two agents. - Returns: - PromptTemplateBase: The selected or constructed prompt template. + Args: + source_agent_name: Sender name. + target_agent_name: Recipient name. + message: Message content. + """ + separator = "-" * 80 + parts = [ + (source_agent_name, "dapr_agents_mustard"), + (" -> ", "dapr_agents_teal"), + (f"{target_agent_name}\n\n", "dapr_agents_mustard"), + (message + "\n\n", None), + (separator + "\n", "dapr_agents_teal"), + ] + self._text_formatter.print_colored_text(parts) + + # ------------------------------------------------------------------ + # Prompting & memory utilities + # ------------------------------------------------------------------ + def build_initial_messages( + self, + user_input: Optional[Union[str, Dict[str, Any]]] = None, + **extra_variables: Any, + ) -> List[Dict[str, Any]]: """ - # 1) User provided one? - if self.prompt_template: - logger.debug("🛠️ Using provided agent.prompt_template") - return self.prompt_template - - # 2) LLM client has one? - if ( - self.llm - and hasattr(self.llm, "prompt_template") - and self.llm.prompt_template - ): - logger.debug("🔄 Syncing from llm.prompt_template") - return self.llm.prompt_template + Build the initial message list for an LLM call. - # 3) Build from system_prompt or attributes - if not self.system_prompt: - logger.debug("⚙️ Constructing system_prompt from attributes") - self.system_prompt = self.construct_system_prompt() + Args: + user_input: Optional user message or structured payload. + **extra_variables: Extra template variables for the prompt template. - logger.debug("⚙️ Building ChatPromptTemplate from system_prompt") - return self.construct_prompt_template() + Returns: + List of message dictionaries ready for an LLM chat API. + """ + return self.prompting_helper.build_initial_messages( + user_input, + chat_history=self.get_chat_history() + if self.prompting_helper.include_chat_history + else None, + **extra_variables, + ) - def _collect_template_attrs(self) -> tuple[Dict[str, str], List[str]]: + def get_chat_history(self) -> List[Dict[str, Any]]: """ - Collect agent attributes for prompt template pre-filling and warn about unused ones. - - valid: attributes set on self and declared in prompt_template.input_variables. - - unused: attributes set on self but not present in the template. + Retrieve the conversation history from the configured memory backend. + Returns: - (valid, unused): Tuple of dict of valid attrs and list of unused attr names. + A list of message-like dictionaries in normalized form. """ - attrs = ["name", "role", "goal", "instructions"] - valid: Dict[str, str] = {} - unused: List[str] = [] - if not self.prompt_template or not hasattr( - self.prompt_template, "input_variables" - ): - return valid, attrs # No template, all attrs are unused - original = set(self.prompt_template.input_variables) - - for attr in attrs: - val = getattr(self, attr, None) - if val is None: - continue - if attr in original: - # Only join instructions if it's a list and the template expects it - if attr == "instructions" and isinstance(val, list): - valid[attr] = "\n".join(val) - else: - valid[attr] = str(val) - else: - unused.append(attr) - return valid, unused - - def _setup_signal_handlers(self): - """Set up signal handlers for graceful shutdown""" try: - signal.signal(signal.SIGINT, self._signal_handler) - signal.signal(signal.SIGTERM, self._signal_handler) - except (OSError, ValueError): - # TODO: test this bc signal handlers may not work in all environments (e.g., Windows) - pass - - def _signal_handler(self, signum, frame): - """Handle interrupt signals gracefully""" - print(f"\nReceived signal {signum}. Shutting down gracefully...") - self._shutdown_event.set() - - def _validate_prompt_template(self) -> None: - """ - Ensures chat_history is always available, injects any declared attributes, - and warns if the user set attributes that aren't in the template. - """ - if not self.prompt_template: - return - - # Always make chat_history available - vars_set = set(self.prompt_template.input_variables) | {"chat_history"} - - # Inject any attributes the template declares - valid_attrs, unused_attrs = self._collect_template_attrs() - vars_set |= set(valid_attrs.keys()) - self.prompt_template.input_variables = list(vars_set) - - if unused_attrs: - logger.warning( - "Agent attributes set but not referenced in prompt_template: " - f"{', '.join(unused_attrs)}. Consider adding them to input_variables." - ) + history = self.memory.get_messages() + except Exception as exc: # noqa: BLE001 + logger.warning("Memory get_messages failed: %s", exc) + return [] + + normalized: List[Dict[str, Any]] = [] + for entry in history: + if hasattr(entry, "model_dump"): + normalized.append(entry.model_dump()) + elif isinstance(entry, dict): + normalized.append(dict(entry)) + return normalized + + def reset_memory(self) -> None: + """Clear all stored conversation messages.""" + if self.memory: + self.memory.reset_memory() - @property - def tool_executor(self) -> AgentToolExecutor: - """Returns the client to execute and manage tools, ensuring it's accessible but read-only.""" - return self._tool_executor - - @property - def text_formatter(self) -> ColorTextFormatter: - """Returns the text formatter for the agent.""" - return self._text_formatter + def get_last_message(self) -> Optional[Dict[str, Any]]: + """Return the last message stored in memory, if any.""" + history = self.get_chat_history() + return dict(history[-1]) if history else None - def get_chat_history(self, task: Optional[str] = None) -> List[Dict[str, Any]]: + def get_last_user_message( + self, messages: Sequence[Dict[str, Any]] + ) -> Optional[Dict[str, Any]]: """ - Retrieves the chat history from memory as a list of dictionaries. + Return the most recent message authored by the user from a sequence. Args: - task (Optional[str]): The task or query provided by the user (used for vector search). + messages: Message sequence from which to extract the last user message. Returns: - List[Dict[str, Any]]: The chat history as dictionaries. + The last user message as a dict, or None if not present. """ - if isinstance(self.memory, ConversationVectorMemory) and task: - if ( - hasattr(self.memory.vector_store, "embedding_function") - and self.memory.vector_store.embedding_function - and hasattr( - self.memory.vector_store.embedding_function, "embed_documents" - ) - ): - query_embeddings = self.memory.vector_store.embedding_function.embed( - task - ) - messages = self.memory.get_messages(query_embeddings=query_embeddings) - else: - messages = self.memory.get_messages() - else: - messages = self.memory.get_messages() - return messages + match = self._get_last_user_message(messages) + if not match: + return None + result = dict(match) + content = result.get("content") + if isinstance(content, str): + result["content"] = content.strip() + return result - @property - def chat_history(self) -> List[Dict[str, Any]]: + def get_last_message_if_user( + self, messages: Sequence[Dict[str, Any]] + ) -> Optional[Dict[str, Any]]: """ - Returns the full chat history as a list of dictionaries. + Return the last message only if it is authored by the user. + + Args: + messages: Message sequence. Returns: - List[Dict[str, Any]]: The chat history. + The last message as a dict if its role is 'user'; otherwise None. """ - return self.get_chat_history() + if messages and messages[-1].get("role") == "user": + msg = dict(messages[-1]) + content = msg.get("content") + if isinstance(content, str): + msg["content"] = content.strip() + return msg + return None - @abstractmethod - def run(self, input_data: Union[str, Dict[str, Any]]) -> Any: + def get_llm_tools(self) -> List[Union[AgentTool, Dict[str, Any]]]: """ - Executes the agent's main logic based on provided inputs. + Convert configured tools into LLM-friendly tool specs. - Args: - inputs (Dict[str, Any]): A dictionary with dynamic input values for task execution. + Returns: + List of `AgentTool` or tool-spec dicts. """ - pass + llm_tools: List[Union[AgentTool, Dict[str, Any]]] = [] + for tool in self.tools: + if isinstance(tool, AgentTool): + llm_tools.append(tool) + elif callable(tool): + try: + llm_tools.append(AgentTool.from_func(tool)) + except Exception as exc: # noqa: BLE001 + logger.warning("Failed to convert callable to AgentTool: %s", exc) + return llm_tools - def prefill_agent_attributes(self) -> None: + def _build_profile( + self, + *, + base_profile: Optional[AgentProfileConfig], + name: Optional[str], + role: Optional[str], + goal: Optional[str], + instructions: Optional[Iterable[str]], + style_guidelines: Optional[Iterable[str]], + system_prompt: Optional[str], + ) -> AgentProfileConfig: """ - Pre-fill prompt_template with agent attributes if specified in `input_variables`. - Uses _collect_template_attrs to avoid duplicate logic and ensure consistency. - """ - if not self.prompt_template: - return - - # Re-use our helper to split valid vs. unused - valid_attrs, unused_attrs = self._collect_template_attrs() + Construct a concrete AgentProfileConfig from a base profile and field overrides. - if unused_attrs: - logger.warning( - "Agent attributes set but not used in prompt_template: " - f"{', '.join(unused_attrs)}. Consider adding them to input_variables." - ) + Args: + base_profile: Optional starting profile to clone (avoids mutating the caller’s). + name: Name override. + role: Role override. + goal: Goal/mission override. + instructions: Additional instruction strings. + style_guidelines: Prompt style directives. + system_prompt: System prompt override. - if valid_attrs: - self.prompt_template = self.prompt_template.pre_fill_variables( - **valid_attrs + Returns: + A fully-populated AgentProfileConfig with a non-empty name. + + Raises: + ValueError: If the resulting profile has an empty name. + """ + # Clone the base profile to avoid external side effects. + if base_profile is not None: + profile = AgentProfileConfig( + name=base_profile.name, + role=base_profile.role, + goal=base_profile.goal, + instructions=list(base_profile.instructions), + style_guidelines=list(base_profile.style_guidelines), + system_prompt=base_profile.system_prompt, + template_format=base_profile.template_format, + modules=tuple(base_profile.modules), + module_overrides=dict(base_profile.module_overrides), ) - logger.debug(f"Pre-filled template with: {list(valid_attrs.keys())}") else: - logger.debug("No prompt_template variables needed pre-filling.") + profile = AgentProfileConfig() + + # Apply field-level overrides when provided. + if name is not None: + profile.name = name + if role is not None: + profile.role = role + if goal is not None: + profile.goal = goal + if instructions is not None: + profile.instructions = list(instructions) + if style_guidelines is not None: + profile.style_guidelines = list(style_guidelines) + if system_prompt is not None: + profile.system_prompt = system_prompt + + # Durable agents require a concrete name for state/memory/registry keys. + if not profile.name or not profile.name.strip(): + raise ValueError( + "Durable agents require a non-empty name " + "(provide name= or profile_config.name)." + ) - def construct_system_prompt(self) -> str: - """ - Build the system prompt for the agent using a single template string. - - Fills in the current date. - - Leaves placeholders for name, role, goal, and instructions as variables (instructions only if set). - - Converts placeholders to Jinja2 syntax if requested. + return profile - Returns: - str: The formatted system prompt string. + # ------------------------------------------------------------------ + # Internal utilities + # ------------------------------------------------------------------ + @staticmethod + def _run_asyncio_task(coro: Coroutine[Any, Any, Any]) -> None: """ - # Only fill in the date; leave all other placeholders as variables - instructions_placeholder = "{instructions}" if self.instructions else "" - filled = self.DEFAULT_SYSTEM_PROMPT.format( - date=datetime.now().strftime("%B %d, %Y"), - name="{name}", - role="{role}", - goal="{goal}", - instructions=instructions_placeholder, - ) + Execute an async coroutine from a synchronous context, creating a fresh loop if needed. - # If using Jinja2, swap braces for all placeholders - if self.template_format == "jinja2": - # Replace every {foo} with {{foo}} - return re.sub(r"\{(\w+)\}", r"{{\1}}", filled) + Args: + coro: The coroutine to execute. + """ + try: + asyncio.get_running_loop() + except RuntimeError: + asyncio.run(coro) else: - return filled + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(coro) + finally: + loop.close() - def construct_prompt_template(self) -> ChatPromptTemplate: + @staticmethod + def _serialize_message(message: Any) -> Dict[str, Any]: """ - Constructs a ChatPromptTemplate that includes the system prompt and a placeholder for chat history. - Ensures that the template is flexible and adaptable to dynamically handle pre-filled variables. + Convert a message-like object into a plain dict for history persistence. + + Args: + message: Pydantic model, dict, or object exposing `model_dump`. Returns: - ChatPromptTemplate: A formatted prompt template for the agent. - """ - # Construct the system prompt if not provided - system_prompt = self.system_prompt or self.construct_system_prompt() - - # Create the template with placeholders for system message and chat history - return ChatPromptTemplate.from_messages( - messages=[ - ("system", system_prompt), - MessagePlaceHolder(variable_name="chat_history"), - ], - template_format=self.template_format, + Normalized dictionary representation. + + Raises: + TypeError: When the input type is unsupported. + """ + if hasattr(message, "model_dump"): + return message.model_dump() + if isinstance(message, dict): + return dict(message) + if hasattr(message, "__dict__"): + return dict(message.__dict__) + raise TypeError( + f"Unsupported message type for serialization: {type(message)!r}" ) - def construct_messages( - self, input_data: Union[str, Dict[str, Any]] + def _get_last_user_message( + self, messages: Sequence[Dict[str, Any]] + ) -> Optional[Dict[str, Any]]: + """Find the last user-role message from the given sequence.""" + for message in reversed(messages): + if message.get("role") == "user": + return message + return None + + # ------------------------------------------------------------------ + # State-aware message helpers (use AgentComponents' state model) + # ------------------------------------------------------------------ + def _construct_messages_with_instance_history( + self, instance_id: str ) -> List[Dict[str, Any]]: """ - Constructs and formats initial messages based on input type, passing chat_history as a list, without mutating self.prompt_template. + Build a conversation history combining persistent memory and per-instance messages. Args: - input_data (Union[str, Dict[str, Any]]): User input, either as a string or dictionary. + instance_id: Workflow instance identifier. Returns: - List[Dict[str, Any]]: List of formatted messages, including the user message if input_data is a string. + Combined message history excluding system messages from instance timeline. """ - if not self.prompt_template: - raise ValueError( - "Prompt template must be initialized before constructing messages." - ) + container = self._get_entry_container() + entry = container.get(instance_id) if container else None - chat_history = self.get_chat_history() # List[Dict[str, Any]] + instance_messages: List[Dict[str, Any]] = [] + if entry and hasattr(entry, "messages"): + for msg in getattr(entry, "messages"): + serialized = self._serialize_message(msg) + if serialized.get("role") != "system": + instance_messages.append(serialized) - if isinstance(input_data, str): - formatted_messages = self.prompt_template.format_prompt( - chat_history=chat_history - ) - if isinstance(formatted_messages, list): - user_message = {"role": "user", "content": input_data} - return formatted_messages + [user_message] - else: - return [ - {"role": "system", "content": formatted_messages}, - {"role": "user", "content": input_data}, - ] - - elif isinstance(input_data, dict): - input_vars = dict(input_data) - if "chat_history" not in input_vars: - input_vars["chat_history"] = chat_history - formatted_messages = self.prompt_template.format_prompt(**input_vars) - if isinstance(formatted_messages, list): - return formatted_messages - else: - return [{"role": "system", "content": formatted_messages}] - - else: - raise ValueError("Input data must be either a string or dictionary.") - - def reset_memory(self): - """Clears all messages stored in the agent's memory.""" - self.memory.reset_memory() + persistent_memory: List[Dict[str, Any]] = [] + try: + for msg in self.memory.get_messages(): + try: + persistent_memory.append(self._serialize_message(msg)) + except TypeError: + logger.debug( + "Unsupported memory message type %s; skipping.", type(msg) + ) + except Exception: # noqa: BLE001 + logger.debug("Unable to load persistent memory.", exc_info=True) + + history: List[Dict[str, Any]] = [] + history.extend(persistent_memory) + history.extend(instance_messages) + return history + + def _sync_system_messages_with_state( + self, + instance_id: str, + all_messages: Sequence[Dict[str, Any]], + ) -> None: + """ + Persist the latest set of system messages into the instance state. - def get_last_message(self) -> Optional[Dict[str, Any]]: + Args: + instance_id: Workflow instance id. + all_messages: Complete message list to scan for system-role messages. """ - Retrieves the last message from the chat history. + # Delegate to AgentComponents logic. + self.sync_system_messages(instance_id=instance_id, all_messages=all_messages) - Returns: - Optional[Dict[str, Any]]: The last message in the history as a dictionary, or None if none exist. + def _process_user_message( + self, + instance_id: str, + task: Optional[str], + user_message_copy: Optional[Dict[str, Any]], + ) -> None: """ - chat_history = self.get_chat_history() - if chat_history: - last_msg = chat_history[-1] - if isinstance(last_msg, BaseMessage): - return last_msg.model_dump() - return last_msg - return None + Append a user message into the instance timeline and memory, and persist state. - def get_last_user_message( - self, messages: List[Dict[str, Any]] - ) -> Optional[Dict[str, Any]]: + Args: + instance_id: Workflow instance id. + task: Optional task string; if missing, no-op. + user_message_copy: Message dict to append. """ - Retrieves the last user message in a list of messages, returning a copy with trimmed content. + if not task or not user_message_copy: + return - Args: - messages (List[Dict[str, Any]]): List of formatted messages to search. + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if entry is None or not hasattr(entry, "messages"): + return - Returns: - Optional[Dict[str, Any]]: The last user message (copy) with trimmed content, or None if no user message exists. - """ - # Iterate in reverse to find the most recent 'user' role message - for message in reversed(messages): - if message.get("role") == "user": - # Return a copy with trimmed content - msg_copy = dict(message) - msg_copy["content"] = msg_copy["content"].strip() - return msg_copy - return None + # Use configured coercer / message model + message_model = ( + self._message_coercer(user_message_copy) # type: ignore[attr-defined] + if getattr(self, "_message_coercer", None) + else self._message_dict_to_message_model(user_message_copy) + ) + entry.messages.append(message_model) # type: ignore[attr-defined] + if hasattr(entry, "last_message"): + entry.last_message = message_model # type: ignore[attr-defined] - def get_last_message_if_user( - self, messages: List[Dict[str, Any]] - ) -> Optional[Dict[str, Any]]: - """ - Returns the last message only if it is a user message; otherwise, returns None. + session_id = getattr(getattr(self, "memory", None), "session_id", None) + if session_id is not None and hasattr(entry, "session_id"): + entry.session_id = str(session_id) # type: ignore[attr-defined] - Args: - messages (List[Dict[str, Any]]): List of formatted messages to check. + self.memory.add_message( + UserMessage(content=user_message_copy.get("content", "")) + ) + self.save_state() - Returns: - Optional[Dict[str, Any]]: The last message (copy) with trimmed content if it is a user message, else None. + def _save_assistant_message( + self, instance_id: str, assistant_message: Dict[str, Any] + ) -> None: """ - if messages and messages[-1].get("role") == "user": - msg_copy = dict(messages[-1]) - msg_copy["content"] = msg_copy["content"].strip() - return msg_copy - return None + Append an assistant message into the instance timeline and memory, and persist state. - def get_llm_tools(self) -> List[Union[AgentTool, Dict[str, Any]]]: + Args: + instance_id: Workflow instance id. + assistant_message: Assistant message dict (will be tagged with agent name). """ - Converts tools to the format expected by LLM clients. + assistant_message["name"] = self.name - Returns: - List[Union[AgentTool, Dict[str, Any]]]: Tools in LLM-compatible format. - """ - llm_tools: List[Union[AgentTool, Dict[str, Any]]] = [] - for tool in self.tools: - if isinstance(tool, AgentTool): - llm_tools.append(tool) - elif callable(tool): - try: - agent_tool = AgentTool.from_func(tool) - llm_tools.append(agent_tool) - except Exception as e: - logger.warning(f"Failed to convert callable to AgentTool: {e}") - continue - return llm_tools + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if entry is None or not hasattr(entry, "messages"): + return - def pre_fill_prompt_template(self, **kwargs: Union[str, Callable[[], str]]) -> None: + message_id = assistant_message.get("id") + if message_id and any( + getattr(msg, "id", None) == message_id for msg in getattr(entry, "messages") + ): + return + + message_model = ( + self._message_coercer(assistant_message) # type: ignore[attr-defined] + if getattr(self, "_message_coercer", None) + else self._message_dict_to_message_model(assistant_message) + ) + entry.messages.append(message_model) # type: ignore[attr-defined] + if hasattr(entry, "last_message"): + entry.last_message = message_model # type: ignore[attr-defined] + + self.memory.add_message(AssistantMessage(**assistant_message)) + self.save_state() + + # ------------------------------------------------------------------ + # Small convenience wrappers + # ------------------------------------------------------------------ + def list_team_agents( + self, *, team: Optional[str] = None, include_self: bool = True + ) -> Dict[str, Any]: """ - Pre-fills the prompt template with specified variables, updating input variables if applicable. + Convenience wrapper over `get_agents_metadata`. Args: - **kwargs: Variables to pre-fill in the prompt template. These can be strings or callables - that return strings. + team: Team override. + include_self: If True, include this agent in the results. - Notes: - - Existing pre-filled variables will be overwritten by matching keys in `kwargs`. - - This method does not affect the `chat_history` which is dynamically updated. + Returns: + Mapping of agent name to metadata. """ - if not self.prompt_template: - raise ValueError( - "Prompt template must be initialized before pre-filling variables." - ) + return self.get_agents_metadata( + exclude_self=not include_self, + exclude_orchestrator=False, + team=team, + ) - self.prompt_template = self.prompt_template.pre_fill_variables(**kwargs) - logger.debug(f"Pre-filled prompt template with variables: {kwargs.keys()}") + # ------------------------------------------------------------------ + # Misc helpers + # ------------------------------------------------------------------ + @staticmethod + def _coerce_datetime(value: Optional[Any]) -> datetime: + """Coerce strings/None to a timezone-aware UTC datetime.""" + if isinstance(value, datetime): + return value + if isinstance(value, str): + try: + return datetime.fromisoformat(value) + except ValueError: + pass + return datetime.now(timezone.utc) diff --git a/dapr_agents/agents/components.py b/dapr_agents/agents/components.py new file mode 100644 index 00000000..1d785576 --- /dev/null +++ b/dapr_agents/agents/components.py @@ -0,0 +1,732 @@ +from __future__ import annotations + +import logging +import random +import time +from datetime import datetime, timezone +from typing import Any, Callable, Dict, Optional, Sequence, Type + +from dapr.clients.grpc._state import Concurrency, Consistency, StateOptions +from pydantic import BaseModel, ValidationError + +from dapr_agents.agents.configs import ( + AgentPubSubConfig, + AgentRegistryConfig, + AgentStateConfig, +) +from dapr_agents.agents.schemas import ( + AgentWorkflowEntry, + AgentWorkflowMessage, + AgentWorkflowState, +) +from dapr_agents.storage.daprstores.stateservice import StateStoreError +from dapr_agents.types.workflow import DaprWorkflowStatus + +logger = logging.getLogger(__name__) + + +class AgentComponents: + """ + Thin infrastructure layer for agents/orchestrators. + + Handles: + - Pub/Sub plumbing (topic names, bus name). + - Durable workflow state (load/save, instance bootstrapping). + - Team registry operations (list/register/mutate with optimistic concurrency). + + Higher-level concerns (prompting, memory, tools) should remain outside this class. + """ + + def __init__( + self, + *, + name: str, + pubsub_config: Optional[AgentPubSubConfig] = None, + state_config: Optional[AgentStateConfig] = None, + registry_config: Optional[AgentRegistryConfig] = None, + base_metadata: Optional[Dict[str, Any]] = None, + max_etag_attempts: int = 10, + ) -> None: + """ + Initialize component wiring. + + Args: + name: Logical agent name; used for keys/topics when not overridden. + pubsub_config: Dapr pub/sub configuration for this agent. + state_config: Durable state (Dapr state store, key overrides, defaults, model customization). + registry_config: Agent registry backing store and team settings. + base_metadata: Base metadata for Dapr state operations. + max_etag_attempts: Max optimistic-concurrency retries on registry mutations. + """ + self.name = name + + # ----------------------------- + # Pub/Sub configuration (copy) + # ----------------------------- + self._pubsub_config: Optional[AgentPubSubConfig] = None + if pubsub_config is not None: + # Copy only what we need to avoid accidental external mutation. + self._pubsub_config = AgentPubSubConfig( + pubsub_name=pubsub_config.pubsub_name, + agent_topic=pubsub_config.agent_topic or name, + broadcast_topic=pubsub_config.broadcast_topic, + ) + + # ----------------------------- + # State configuration and model (flexible) + # ----------------------------- + self._state_config = state_config + self.state_store = state_config.store if state_config else None + override_state_key = state_config.state_key if state_config else None + self.state_key = override_state_key or f"{self.name}:workflow_state" + + # Customization points (classes + hooks) + self._state_model_cls: Type[BaseModel] = ( + state_config.state_model_cls if state_config else AgentWorkflowState + ) + self._message_model_cls: Type[BaseModel] = ( + state_config.message_model_cls if state_config else AgentWorkflowMessage + ) + self._entry_factory: Optional[Callable[..., Any]] = ( + state_config.entry_factory if state_config else None + ) + self._message_coercer: Optional[Callable[[Dict[str, Any]], Any]] = ( + state_config.message_coercer if state_config else None + ) + self._entry_container_getter: Optional[ + Callable[[BaseModel], Optional[dict]] + ] = ( + getattr(state_config, "entry_container_getter", None) + if state_config + else None + ) + + # Seed the default model from config or empty instance + if state_config and state_config.default_state is not None: + default_state_model = self._state_model_cls.model_validate( + state_config.default_state + ) + else: + default_state_model = self._state_model_cls() + self._state_default_model: BaseModel = default_state_model + self._state_model: BaseModel = self._state_default_model.model_copy(deep=True) + + # ----------------------------- + # Registry configuration + # ----------------------------- + self._registry_config = registry_config + self.registry_state = registry_config.store if registry_config else None + self._registry_prefix = "agents:" + self._registry_team_override = ( + registry_config.team_name + if registry_config and registry_config.team_name + else "default" + ) + + # ----------------------------- + # Dapr save options & metadata + # ----------------------------- + self._save_options = StateOptions( + concurrency=Concurrency.first_write, + consistency=Consistency.strong, + ) + self._base_metadata = dict(base_metadata or {"contentType": "application/json"}) + self._max_etag_attempts = max_etag_attempts + + # ------------------------------------------------------------------ + # Pub/Sub helpers + # ------------------------------------------------------------------ + @property + def pubsub_config(self) -> Optional[AgentPubSubConfig]: + """Return the configured pub/sub settings, if any.""" + return self._pubsub_config + + @property + def message_bus_name(self) -> str: + """Return the Dapr pub/sub component name (bus).""" + if not self._pubsub_config: + raise RuntimeError("No pubsub configuration available for this agent.") + return self._pubsub_config.pubsub_name + + @property + def agent_topic_name(self) -> str: + """Return the per-agent topic name.""" + if not self._pubsub_config: + raise RuntimeError("No pubsub configuration available for this agent.") + return self._pubsub_config.agent_topic or self.name + + @property + def broadcast_topic_name(self) -> Optional[str]: + """Return the broadcast topic name, if one was configured.""" + if not self._pubsub_config: + return None + return self._pubsub_config.broadcast_topic + + # ------------------------------------------------------------------ + # State helpers + # ------------------------------------------------------------------ + @property + def workflow_state(self) -> BaseModel: + """Return the in-memory workflow state model (customizable model).""" + return self._state_model + + @property + def state(self) -> Dict[str, Any]: + """Return the workflow state as a JSON-serializable dict.""" + return self._state_model.model_dump(mode="json") + + def load_state(self) -> None: + """ + Load the durable workflow state snapshot into memory. + + If no state store is configured, resets the in-memory model to defaults. + """ + if not self.state_store: + logger.debug("No state store configured; using in-memory state only.") + self._state_model = self._initial_state_model() + return + + snapshot = self.state_store.load( + key=self.state_key, + default=self._initial_state(), + ) + try: + if isinstance(snapshot, dict): + self._state_model = self._state_model_cls.model_validate(snapshot) + else: + raise TypeError(f"Unexpected state snapshot type {type(snapshot)}") + except (ValidationError, TypeError) as exc: + logger.warning( + "Invalid workflow state encountered (%s); resetting to defaults.", exc + ) + self._state_model = self._initial_state_model() + + def save_state(self) -> None: + """ + Persist the current workflow state with optimistic concurrency. + + No-op when no state store is configured. Uses load_with_etag + save(etag=...) + with a short retry loop to avoid lost updates under contention. + """ + if not self.state_store: + logger.debug("No state store configured; skipping state persistence.") + return + + key = self.state_key + meta = self._state_metadata_for_key(key) + attempts = max(1, min(self._max_etag_attempts, 10)) + + # Ensure the state document exists so we can get a concrete ETag. + try: + current, etag = self.state_store.load_with_etag( + key=key, + default=self._initial_state(), + state_metadata=meta, + ) + if etag is None: + # Initialize to get an etag + self.state_store.save( + key=key, + value=current if isinstance(current, dict) else self.state, + etag=None, + state_metadata=meta, + state_options=self._save_options, + ) + except Exception: + logger.exception("Failed to initialize state document for key '%s'.", key) + # Best-effort attempt to proceed; if this fails below, we'll log again. + + for attempt in range(1, attempts + 1): + try: + _, etag = self.state_store.load_with_etag( + key=key, + default=self._initial_state(), + state_metadata=meta, + ) + self.state_store.save( + key=key, + value=self.state, + etag=etag, + state_metadata=meta, + state_options=self._save_options, + ) + return + except Exception as exc: # noqa: BLE001 + logger.warning( + "Conflict during workflow state save (attempt %d/%d) for '%s': %s", + attempt, + attempts, + key, + exc, + ) + if attempt == attempts: + logger.exception( + "Failed to persist agent state after %d attempts.", attempts + ) + return + time.sleep(min(0.25 * attempt, 1.0) * (1 + random.uniform(0, 0.25))) + + def _initial_state(self) -> Dict[str, Any]: + """Return a deep-copied default state as a plain dict.""" + return self._state_default_model.model_copy(deep=True).model_dump(mode="json") + + def _initial_state_model(self) -> BaseModel: + """Return a deep-copied default state model.""" + return self._state_default_model.model_copy(deep=True) + + def ensure_instance_exists( + self, + *, + instance_id: str, + input_value: Any, + triggering_workflow_instance_id: Optional[str], + time: Optional[datetime] = None, + ) -> None: + """ + Ensure a workflow instance entry exists in the state model. + + Uses a pluggable `entry_factory` when provided. If absent, falls back to a + best-effort default that assumes an `instances` dict on the root model. + + Args: + instance_id: Unique workflow instance identifier. + input_value: Input payload used to start the workflow. + triggering_workflow_instance_id: Parent workflow instance id, if any. + time: Optional start time (defaults to now, UTC). + + Raises: + RuntimeError: If a custom entry factory raises and is not handled. + """ + container = self._get_entry_container() + if container is None: + # No instances concept; nothing to do. + return + if instance_id in container: + return + + start_time = self._coerce_datetime(time) + + if self._entry_factory is not None: + entry = self._entry_factory( + instance_id=instance_id, + input_value=input_value, + triggering_workflow_instance_id=triggering_workflow_instance_id, + start_time=start_time, + ) + else: + # Default (legacy) AgentWorkflowEntry-compatible record + entry = AgentWorkflowEntry( + input_value=str(input_value), + workflow_instance_id=instance_id, + triggering_workflow_instance_id=triggering_workflow_instance_id, + workflow_name=None, + session_id=None, + start_time=start_time, + status=DaprWorkflowStatus.RUNNING.value, + ) + container[instance_id] = entry + + def sync_system_messages( + self, + instance_id: str, + all_messages: Sequence[Dict[str, Any]], + ) -> None: + """ + Synchronize system messages into the workflow state for a given instance. + + Uses `message_coercer` or `message_model_cls` to construct message entries. + + Args: + instance_id: Workflow instance identifier. + all_messages: Full (system/user/assistant) list; only 'system' are synced. + """ + container = self._get_entry_container() + if container is None: + return + entry = container.get(instance_id) + if entry is None: + return + + system_messages = [m for m in all_messages if m.get("role") == "system"] + if not system_messages: + return + + existing = list(getattr(entry, "system_messages", []) or []) + existing_sig = [ + (getattr(m, "content", None), getattr(m, "name", None)) for m in existing + ] + new_sig = [(m.get("content"), m.get("name")) for m in system_messages] + if existing_sig == new_sig: + return + + # Build new models + if self._message_coercer: + new_models = [self._message_coercer(m) for m in system_messages] + else: + new_models = [ + self._message_dict_to_message_model(m) for m in system_messages + ] + + # Assign back if the field exists; otherwise, skip + if hasattr(entry, "system_messages"): + entry.system_messages = new_models # type: ignore[attr-defined] + + # De-duplicate in entry.messages if that field exists + if hasattr(entry, "messages"): + filtered = [ + m + for m in getattr(entry, "messages") + if getattr(m, "role", None) != "system" + ] + entry.messages = filtered # type: ignore[attr-defined] + # Fix last_message if applicable + if ( + getattr(entry, "last_message", None) is not None + and getattr(entry.last_message, "role", None) == "system" + ): + non_system = [ + m + for m in getattr(entry, "messages") + if getattr(m, "role", None) != "system" + ] + entry.last_message = non_system[-1] if non_system else None # type: ignore[attr-defined] + + def _message_dict_to_message_model(self, message: Dict[str, Any]) -> Any: + """ + Convert a dict into the configured message model. + + Falls back to returning the raw dict if instantiation fails (to avoid hard + failures with custom models). Logs a warning the first time a shape mismatch + is observed to help with debugging template drift. + """ + allowed = { + "role", + "content", + "name", + "tool_calls", + "function_call", + "tool_call_id", + "id", + } + payload = {k: message[k] for k in allowed if k in message} + payload.setdefault("role", "system") + payload.setdefault("content", "") + + try: + return self._message_model_cls(**payload) + except Exception as exc: # noqa: BLE001 + # Keep noisy logs under control by summarizing the mismatch. + try: + role = payload.get("role") + name = payload.get("name") + msg_id = payload.get("id") or payload.get("tool_call_id") + logger.warning( + "Message coercion failed for role=%r name=%r id=%r with %s; keeping raw payload.", + role, + name, + msg_id, + type(exc).__name__, + ) + except Exception: + # Don't let logging fail the fallback + pass + return payload + + # ------------------------------------------------------------------ + # Registry helpers + # ------------------------------------------------------------------ + def register_agentic_system( + self, + *, + metadata: Optional[Dict[str, Any]] = None, + team: Optional[str] = None, + ) -> None: + """ + Upsert this agent's metadata in the team registry. + + Args: + metadata: Additional metadata to store for this agent. + team: Team override; falls back to configured default team. + """ + if not self.registry_state: + logger.debug( + "No registry configured; skipping registration for %s", self.name + ) + return + + payload = dict(metadata or {}) + payload.setdefault("name", self.name) + payload.setdefault("team", self._effective_team(team)) + + if self._pubsub_config is not None: + payload.setdefault("topic_name", self.agent_topic_name) + payload.setdefault("pubsub_name", self.message_bus_name) + if self.broadcast_topic_name: + payload.setdefault("broadcast_topic", self.broadcast_topic_name) + + self._upsert_agent_entry( + team=self._effective_team(team), + agent_name=self.name, + agent_metadata=payload, + ) + + def deregister_agentic_system(self, *, team: Optional[str] = None) -> None: + """ + Remove this agent from the team registry. + + Args: + team: Team override; falls back to configured default team. + """ + if not self.registry_state: + return + self._remove_agent_entry(team=self._effective_team(team), agent_name=self.name) + + def get_agents_metadata( + self, + *, + exclude_self: bool = True, + exclude_orchestrator: bool = False, + team: Optional[str] = None, + ) -> Dict[str, Any]: + """ + Load and optionally filter all agents registered for a team. + + Args: + exclude_self: If True, omit this agent from results. + exclude_orchestrator: If True, omit agents with orchestrator=True. + team: Team override; falls back to configured default team. + + Returns: + Mapping of agent name to metadata. + + Raises: + RuntimeError: When no registry is configured or load fails. + """ + if not self.registry_state: + raise RuntimeError("registry_state must be provided to use agent registry") + + key = self._team_registry_key(team) + try: + agents_metadata = self.registry_state.load( + key=key, + default={}, + state_metadata=self._state_metadata_for_key(key), + ) + if not agents_metadata: + logger.info("No agents found in registry key '%s'.", key) + return {} + + filtered = { + name: meta + for name, meta in agents_metadata.items() + if not (exclude_self and name == self.name) + and not (exclude_orchestrator and meta.get("orchestrator", False)) + } + return filtered + except Exception as exc: # noqa: BLE001 + logger.error("Failed to retrieve agents metadata: %s", exc, exc_info=True) + raise RuntimeError(f"Error retrieving agents metadata: {str(exc)}") from exc + + def _mutate_registry_entry( + self, + *, + team: Optional[str], + mutator: Callable[[Dict[str, Any]], Optional[Dict[str, Any]]], + max_attempts: Optional[int] = None, + ) -> None: + """ + Apply a mutation to the team registry with optimistic concurrency. + + Args: + team: Team identifier. + mutator: Function that returns the updated registry dict (or None for no-op). + max_attempts: Override for concurrency retries; defaults to init value. + + Raises: + StateStoreError: If the mutation fails after retries due to contention. + """ + if not self.registry_state: + raise RuntimeError( + "registry_state must be provided to mutate the agent registry" + ) + + key = self._team_registry_key(team) + meta = self._state_metadata_for_key(key) + attempts = max_attempts or self._max_etag_attempts + + self._ensure_registry_initialized(key=key, meta=meta) + + for attempt in range(1, attempts + 1): + try: + current, etag = self.registry_state.load_with_etag( + key=key, + default={}, + state_metadata=meta, + ) + if not isinstance(current, dict): + current = {} + + updated = mutator(dict(current)) + if updated is None: + return + + self.registry_state.save( + key=key, + value=updated, + etag=etag, + state_metadata=meta, + state_options=self._save_options, + ) + return + except Exception as exc: # noqa: BLE001 + logger.warning( + "Conflict during registry mutation (attempt %d/%d) for '%s': %s", + attempt, + attempts, + key, + exc, + ) + if attempt == attempts: + raise StateStoreError( + f"Failed to mutate agent registry key '{key}' after {attempts} attempts." + ) from exc + # Jittered backoff to reduce thundering herd during contention. + time.sleep(min(1.0 * attempt, 3.0) * (1 + random.uniform(0, 0.25))) + + def _upsert_agent_entry( + self, + *, + team: Optional[str], + agent_name: str, + agent_metadata: Dict[str, Any], + max_attempts: Optional[int] = None, + ) -> None: + """ + Insert/update a single agent record in the team registry. + + Args: + team: Team identifier. + agent_name: Agent name (key). + agent_metadata: Metadata value to write. + max_attempts: Override retry attempts. + """ + + def mutator(current: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if current.get(agent_name) == agent_metadata: + return None + current[agent_name] = agent_metadata + return current + + self._mutate_registry_entry( + team=team, + mutator=mutator, + max_attempts=max_attempts, + ) + + def _remove_agent_entry( + self, + *, + team: Optional[str], + agent_name: str, + max_attempts: Optional[int] = None, + ) -> None: + """ + Delete a single agent record from the team registry. + + Args: + team: Team identifier. + agent_name: Agent name (key). + max_attempts: Override retry attempts. + """ + + def mutator(current: Dict[str, Any]) -> Optional[Dict[str, Any]]: + if agent_name not in current: + return None + del current[agent_name] + return current + + self._mutate_registry_entry( + team=team, + mutator=mutator, + max_attempts=max_attempts, + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _effective_team(self, team: Optional[str] = None) -> str: + """Compute the effective team name from override or defaults.""" + return team or self._registry_team_override or "default" + + def effective_team(self, team: Optional[str] = None) -> str: + """Public alias for _effective_team, useful in templates/callers.""" + return self._effective_team(team) + + def _team_registry_key(self, team: Optional[str] = None) -> str: + """Return the registry document key for a team.""" + return f"{self._registry_prefix}{self._effective_team(team)}" + + def _state_metadata_for_key(self, key: str) -> Dict[str, str]: + """Return Dapr state metadata including partition key.""" + meta = dict(self._base_metadata) + meta["partitionKey"] = key + return meta + + def _ensure_registry_initialized(self, *, key: str, meta: Dict[str, str]) -> None: + """ + Ensure a registry document exists to create an ETag for concurrency control. + + Args: + key: Registry document key. + meta: Dapr state metadata to use for the operation. + """ + current, etag = self.registry_state.load_with_etag( # type: ignore[union-attr] + key=key, + default={}, + state_metadata=meta, + ) + if etag is None: + self.registry_state.save( # type: ignore[union-attr] + key=key, + value={}, + etag=None, + state_metadata=meta, + state_options=self._save_options, + ) + + def _get_entry_container(self) -> Optional[dict]: + """ + Return the container mapping for workflow entries, if any. + + Returns: + A mutable mapping (e.g., dict) of instance_id -> entry, or None if + the underlying state model does not expose such a container. + + Notes: + Prefer a caller-provided hook via `AgentStateConfig.entry_container_getter`. + Falls back to `model.instances` for legacy/default shapes. + """ + if self._entry_container_getter: + return self._entry_container_getter(self._state_model) + return getattr(self._state_model, "instances", None) + + @staticmethod + def _coerce_datetime(value: Optional[Any]) -> datetime: + """ + Coerce strings/None to a timezone-aware UTC datetime. + + Args: + value: Source value (datetime | str | None). + + Returns: + A timezone-aware UTC datetime. If a naive datetime is provided, UTC is assumed. + """ + if isinstance(value, datetime): + return value if value.tzinfo else value.replace(tzinfo=timezone.utc) + if isinstance(value, str): + try: + dt = datetime.fromisoformat(value) + return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc) + except ValueError: + pass + return datetime.now(timezone.utc) diff --git a/dapr_agents/agents/configs.py b/dapr_agents/agents/configs.py new file mode 100644 index 00000000..788ee3f0 --- /dev/null +++ b/dapr_agents/agents/configs.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, MutableMapping, Optional, Sequence, Type + +from pydantic import BaseModel + +from dapr_agents.agents.schemas import AgentWorkflowMessage, AgentWorkflowState +from dapr_agents.memory import ConversationListMemory, MemoryBase +from dapr_agents.storage.daprstores.stateservice import StateStoreService + +_JINJA_PLACEHOLDER_PATTERN = re.compile(r"(? str: + return _JINJA_PLACEHOLDER_PATTERN.sub(r"{{\1}}", text) + + +# Type hooks for state customization +EntryFactory = Callable[..., Any] +MessageCoercer = Callable[[Dict[str, Any]], Any] +EntryContainerGetter = Callable[[BaseModel], Optional[MutableMapping[str, Any]]] + + +@dataclass +class AgentStateConfig: + """Configuration for agent state persistence and model customization. + + Attributes: + store: Backing state store service. + default_state: Default state payload or model. If a dict, validated by `state_model_cls`. + state_key: Optional key override for the durable state entry. + state_model_cls: Root Pydantic model class for the state (default: AgentWorkflowState). + message_model_cls: Pydantic model class for workflow/system messages (default: AgentWorkflowMessage). + entry_factory: Factory used by AgentComponents.ensure_instance_exists(...) to create an entry. + message_coercer: Function to convert a raw dict into a message model instance. + entry_container_getter: Function to extract the instance container (e.g., `model.instances`) from the root model. + """ + + store: "StateStoreService" + default_state: Optional[Dict[str, Any] | BaseModel] = None + state_key: Optional[str] = None + + state_model_cls: Type[BaseModel] = AgentWorkflowState + message_model_cls: Type[BaseModel] = AgentWorkflowMessage + entry_factory: Optional[EntryFactory] = None + message_coercer: Optional[MessageCoercer] = None + entry_container_getter: Optional[EntryContainerGetter] = None + + def __post_init__(self) -> None: + # Defensive checks (optional but helpful during misconfigurations) + if not issubclass(self.state_model_cls, BaseModel): + raise TypeError("state_model_cls must be a subclass of pydantic.BaseModel") + if not issubclass(self.message_model_cls, BaseModel): + raise TypeError( + "message_model_cls must be a subclass of pydantic.BaseModel" + ) + + # Normalize default_state against the selected state_model_cls + Model = self.state_model_cls + if self.default_state is None: + self.default_state = Model().model_dump(mode="json") + else: + if isinstance(self.default_state, BaseModel): + self.default_state = self.default_state.model_dump(mode="json") + else: + self.default_state = Model.model_validate( + self.default_state + ).model_dump(mode="json") + + +@dataclass +class AgentRegistryConfig: + """Configuration for agent registry storage.""" + + store: StateStoreService + team_name: Optional[str] = None + + +@dataclass +class AgentMemoryConfig: + """Configuration wrapper for agent memory selection.""" + + store: MemoryBase = field(default_factory=ConversationListMemory) + + +@dataclass +class AgentPubSubConfig: + """Declarative pub/sub configuration for durable agents. + + Attributes: + pubsub_name: Name of the Dapr pub/sub component to use for all agent traffic. + agent_topic: Primary topic for direct messages to the agent. Defaults to ``name``. + broadcast_topic: Optional topic shared by a team for broadcast messages. + """ + + pubsub_name: str + agent_topic: Optional[str] = None + broadcast_topic: Optional[str] = None + + +@dataclass +class PromptSection: + """Reusable block for composing a structured system prompt.""" + + title: str + lines: List[str] = field(default_factory=list) + + def render(self, template_format: str) -> str: + if not self.lines: + return "" + header = self.title.strip() + body = "\n".join(f"- {line.strip()}" for line in self.lines if line.strip()) + section = f"{header}:\n{body}".strip() + return ( + _ensure_jinja_placeholders(section) + if template_format == "jinja2" + else section + ) + + +@dataclass +class AgentProfileConfig: + """ + High-level persona description for an agent. + + Mirrors common fields in OpenAI Agents SDK while remaining lightweight. + """ + + name: Optional[str] = None + role: Optional[str] = None + goal: Optional[str] = None + instructions: List[str] = field(default_factory=list) + style_guidelines: List[str] = field(default_factory=list) + system_prompt: Optional[str] = None + template_format: str = "jinja2" + modules: Sequence[str] = field(default_factory=tuple) + module_overrides: Dict[str, PromptSection] = field(default_factory=dict) + + +@dataclass +class AgentExecutionConfig: + """ + Dials to configure the agent execution. + """ + + # TODO: add a forceFinalAnswer field in case maxIterations is near/reached. Or do we have a conclusion baked in by default? Do we want this to derive a conclusion by default? + # TODO: add stop_at_tokens + max_iterations: int = 10 + tool_choice: Optional[str] = "auto" diff --git a/dapr_agents/agents/durable.py b/dapr_agents/agents/durable.py new file mode 100644 index 00000000..31da1e1d --- /dev/null +++ b/dapr_agents/agents/durable.py @@ -0,0 +1,729 @@ +from __future__ import annotations + +import json +import logging +from typing import Any, Dict, Iterable, Optional + +import dapr.ext.workflow as wf + +from dapr_agents.agents.base import AgentBase +from dapr_agents.agents.configs import ( + AgentMemoryConfig, + AgentPubSubConfig, + AgentRegistryConfig, + AgentStateConfig, + AgentExecutionConfig, +) +from dapr_agents.agents.prompting import AgentProfileConfig +from dapr_agents.agents.schemas import ( + AgentTaskResponse, + BroadcastMessage, + TriggerAction, +) +from dapr_agents.llm.chat import ChatClientBase +from dapr_agents.prompt.base import PromptTemplateBase +from dapr_agents.types import ( + AgentError, + LLMChatResponse, + ToolExecutionRecord, + ToolMessage, + UserMessage, +) +from dapr_agents.types.workflow import DaprWorkflowStatus +from dapr_agents.workflow.decorators.routers import message_router +from dapr_agents.workflow.utils.pubsub import broadcast_message, send_message_to_agent + +logger = logging.getLogger(__name__) + + +class DurableAgent(AgentBase): + """ + Workflow-native durable agent runtime on top of AgentBase. + + Overview: + Wires your AgentBase behavior into Dapr Workflows for durable, pub/sub-driven runs. + Leverages flexible state models, message coercers, and registry/metadata from Components. + + """ + + def __init__( + self, + *, + # Profile / prompt + profile_config: Optional[AgentProfileConfig] = None, + name: Optional[str] = None, + role: Optional[str] = None, + goal: Optional[str] = None, + instructions: Optional[Iterable[str]] = None, + style_guidelines: Optional[Iterable[str]] = None, + system_prompt: Optional[str] = None, + prompt_template: Optional[PromptTemplateBase] = None, + # Infrastructure + pubsub_config: AgentPubSubConfig, + state_config: Optional[AgentStateConfig] = None, + registry_config: Optional[AgentRegistryConfig] = None, + # Memory / runtime + memory_config: Optional[AgentMemoryConfig] = None, + llm: Optional[ChatClientBase] = None, + tools: Optional[Iterable[Any]] = None, + # Behavior / execution + execution_config: Optional[AgentExecutionConfig] = None, + # Misc + agent_metadata: Optional[Dict[str, Any]] = None, + runtime: Optional[wf.WorkflowRuntime] = None, + ) -> None: + """ + Initialize behavior, infrastructure, and workflow runtime. + + Args: + profile_config: High-level profile (can be overridden by explicit fields). + name: Agent name (required if not in `profile_config`). + role: Agent role/persona label. + goal: High-level objective for prompting context. + instructions: Extra instruction lines for the system prompt. + style_guidelines: Style directives for the system prompt. + system_prompt: System prompt override. + prompt_template: Optional explicit prompt template instance. + + pubsub_config: Dapr Pub/Sub configuration for triggers/broadcasts. + state_config: Durable state configuration and model customization. + registry_config: Team registry configuration. + execution_config: Execution dials for the agent run. + + memory_config: Conversation memory config; defaults to in-memory, or Dapr state-backed if available. + llm: Chat client; defaults to `get_default_llm()`. + tools: Optional tool callables or `AgentTool` instances. + + agent_metadata: Extra metadata to publish to the registry. + runtime: Optional pre-existing workflow runtime to attach to. + """ + super().__init__( + pubsub_config=pubsub_config, + profile_config=profile_config, + name=name, + role=role, + goal=goal, + instructions=instructions, + style_guidelines=style_guidelines, + system_prompt=system_prompt, + state_config=state_config, + memory_config=memory_config, + registry_config=registry_config, + execution_config=execution_config, + agent_metadata=agent_metadata, + llm=llm, + tools=tools, + prompt_template=prompt_template, + ) + + self._runtime: wf.WorkflowRuntime = runtime or wf.WorkflowRuntime() + self._runtime_owned = runtime is None + self._registered = False + self._started = False + + # ------------------------------------------------------------------ + # Runtime accessors + # ------------------------------------------------------------------ + @property + def runtime(self) -> wf.WorkflowRuntime: + """Return the underlying workflow runtime.""" + return self._runtime + + @property + def is_started(self) -> bool: + """Return True when the workflow runtime has been started.""" + return self._started + + # ------------------------------------------------------------------ + # Workflows / Activities + # ------------------------------------------------------------------ + @message_router(message_model=TriggerAction) + def agent_workflow(self, ctx: wf.DaprWorkflowContext, message: dict): + """ + Primary workflow loop reacting to `TriggerAction` pub/sub messages. + + Args: + ctx: Dapr workflow context injected by the runtime. + message: Trigger payload; may include task string and metadata. + + Returns: + Final assistant message as a dict. + + Raises: + AgentError: If the loop finishes without producing a final response. + """ + task = message.get("task") + metadata = message.get("_message_metadata", {}) or {} + + # Propagate OTel/parent workflow relations if present. + otel_span_context = message.get("_otel_span_context") + if "workflow_instance_id" in message: + metadata["triggering_workflow_instance_id"] = message[ + "workflow_instance_id" + ] + + trigger_instance_id = metadata.get("triggering_workflow_instance_id") + source = metadata.get("source") or "direct" + + # Ensure we have the latest durable state for this turn. + self.load_state() + + # Bootstrap instance entry (flexible to non-`instances` models). + self.ensure_instance_exists( + instance_id=ctx.instance_id, + input_value=task or "Triggered without input.", + triggering_workflow_instance_id=trigger_instance_id, + time=ctx.current_utc_datetime, + ) + + if not ctx.is_replaying: + logger.info("Initial message from %s -> %s", source, self.name) + + # Record initial entry via activity to keep deterministic/replay-friendly I/O. + yield ctx.call_activity( + self.record_initial_entry, + input={ + "instance_id": ctx.instance_id, + "input_value": task or "Triggered without input.", + "source": source, + "triggering_workflow_instance_id": trigger_instance_id, + "start_time": ctx.current_utc_datetime.isoformat(), + "trace_context": otel_span_context, + }, + ) + + final_message: Dict[str, Any] = {} + turn = 0 + + try: + for turn in range(1, self.execution_config.max_iterations + 1): + if not ctx.is_replaying: + logger.debug( + "Agent %s turn %d/%d (instance=%s)", + self.name, + turn, + self.execution_config.max_iterations, + ctx.instance_id, + ) + + assistant_response: Dict[str, Any] = yield ctx.call_activity( + self.call_llm, + input={ + "task": task, + "instance_id": ctx.instance_id, + "time": ctx.current_utc_datetime.isoformat(), + }, + ) + + tool_calls = assistant_response.get("tool_calls") or [] + if tool_calls: + if not ctx.is_replaying: + logger.debug( + "Agent %s executing %d tool call(s) on turn %d", + self.name, + len(tool_calls), + turn, + ) + parallel = [ + ctx.call_activity( + self.run_tool, + input={ + "tool_call": tc, + "instance_id": ctx.instance_id, + "time": ctx.current_utc_datetime.isoformat(), + "order": idx, + }, + ) + for idx, tc in enumerate(tool_calls) + ] + yield ctx.when_all(parallel) + task = None # prepare for next turn + continue + + final_message = assistant_response + if not ctx.is_replaying: + logger.debug( + "Agent %s produced final response on turn %d (instance=%s)", + self.name, + turn, + ctx.instance_id, + ) + break + else: + raise AgentError("Workflow ended without generating a final response.") + + except Exception as exc: # noqa: BLE001 + logger.exception("Agent %s workflow failed: %s", self.name, exc) + final_message = {"role": "assistant", "content": f"Error: {str(exc)}"} + + # Optionally broadcast the final message to the team. + if self.broadcast_topic_name: + yield ctx.call_activity( + self.broadcast_message_to_agents, + input={"message": final_message}, + ) + + # Optionally send a direct response back to the trigger origin. + if source and trigger_instance_id: + yield ctx.call_activity( + self.send_response_back, + input={ + "response": final_message, + "target_agent": source, + "target_instance_id": trigger_instance_id, + }, + ) + + # Finalize the workflow entry in durable state. + yield ctx.call_activity( + self.finalize_workflow, + input={ + "instance_id": ctx.instance_id, + "final_output": final_message.get("content", ""), + "end_time": ctx.current_utc_datetime.isoformat(), + "triggering_workflow_instance_id": trigger_instance_id, + }, + ) + + if not ctx.is_replaying: + verdict = ( + "max_iterations_reached" + if turn == self.execution_config.max_iterations + else "completed" + ) + logger.info( + "Workflow %s finalized for agent %s with verdict=%s", + ctx.instance_id, + self.name, + verdict, + ) + + return final_message + + @message_router(message_model=BroadcastMessage, broadcast=True) + def broadcast_listener(self, ctx: wf.DaprWorkflowContext, message: dict) -> None: + """ + Handle broadcast messages sent by other agents and store them in memory. + + Args: + ctx: Dapr workflow context (unused). + message: Broadcast payload containing content and metadata. + """ + metadata = message.get("_message_metadata", {}) or {} + source = metadata.get("source") or "unknown" + message_content = message.get("content", "") + if source == self.name: + logger.debug("Agent %s ignoring self-originated broadcast.", self.name) + return + + logger.info("Agent %s received broadcast from %s", self.name, source) + logger.debug("Full broadcast message: %s", message) + # Store as a user message from the broadcasting agent (kept in persistent memory). + self.memory.add_message( + UserMessage(name=source, content=message_content, role="user") + ) + + # ------------------------------------------------------------------ + # Activities + # ------------------------------------------------------------------ + def record_initial_entry( + self, ctx: wf.WorkflowActivityContext, payload: Dict[str, Any] + ) -> None: + """ + Record the initial entry for a workflow instance. + + Args: + payload: Keys: + - instance_id: Workflow instance id. + - input_value: Initial input value. + - source: Trigger source string. + - triggering_workflow_instance_id: Optional parent workflow id. + - start_time: ISO8601 datetime string. + - trace_context: Optional tracing context. + """ + instance_id = payload.get("instance_id") + trace_context = payload.get("trace_context") + input_value = payload.get("input_value", "Triggered without input.") + source = payload.get("source", "direct") + triggering_instance = payload.get("triggering_workflow_instance_id") + start_time = self._coerce_datetime(payload.get("start_time")) + + # Ensure instance exists in durable state + self.ensure_instance_exists( + instance_id=instance_id, + input_value=input_value, + triggering_workflow_instance_id=triggering_instance, + time=start_time, + ) + + # Use flexible container accessor (supports custom state layouts) + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if entry is None: + return + + entry.input_value = input_value + entry.source = source + entry.triggering_workflow_instance_id = triggering_instance + entry.start_time = start_time + entry.trace_context = trace_context + + session_id = getattr(self.memory, "session_id", None) + if session_id is not None and hasattr(entry, "session_id"): + entry.session_id = str(session_id) + + entry.status = DaprWorkflowStatus.RUNNING.value + self.save_state() + + def call_llm( + self, ctx: wf.WorkflowActivityContext, payload: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Ask the LLM to generate the next assistant message. + + Args: + payload: Must contain 'instance_id'; may include 'task' and 'time'. + + Returns: + Assistant message as a dict. + + Raises: + AgentError: If the LLM call fails or yields no message. + """ + instance_id = payload.get("instance_id") + task = payload.get("task") + + chat_history = self._construct_messages_with_instance_history(instance_id) + messages = self.prompting_helper.build_initial_messages( + user_input=task, + chat_history=chat_history, + ) + + # Sync current system messages into per-instance state + self._sync_system_messages_with_state(instance_id, messages) + + # Persist the user's turn (if any) into the instance timeline + memory + user_message = self._get_last_user_message(messages) + user_copy = dict(user_message) if user_message else None + self._process_user_message(instance_id, task, user_copy) + + if user_copy is not None: + self.text_formatter.print_message({str(k): v for k, v in user_copy.items()}) + + try: + response: LLMChatResponse = self.llm.generate( + messages=messages, + tools=self.get_llm_tools(), + tool_choice=self.execution_config.tool_choice, + ) + except Exception as exc: # noqa: BLE001 + logger.exception("LLM generate failed: %s", exc) + raise AgentError(str(exc)) from exc + + assistant_message = response.get_message() + if assistant_message is None: + raise AgentError("LLM returned no assistant message.") + + as_dict = assistant_message.model_dump() + self._save_assistant_message(instance_id, as_dict) + self.text_formatter.print_message(as_dict) + self.save_state() + return as_dict + + async def run_tool( + self, ctx: wf.WorkflowActivityContext, payload: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Execute a single tool call and persist results to state/memory. + + Args: + payload: Keys 'tool_call', 'instance_id', 'time', 'order'. + + Returns: + Tool execution record as a dict. + + Raises: + AgentError: If tool arguments contain invalid JSON. + """ + tool_call = payload.get("tool_call", {}) + instance_id = payload.get("instance_id") + fn_name = tool_call["function"]["name"] + raw_args = tool_call["function"].get("arguments", "") + + try: + args = json.loads(raw_args) if raw_args else {} + except json.JSONDecodeError as exc: + raise AgentError(f"Invalid JSON in tool args: {exc}") from exc + + result = await self.tool_executor.run_tool(fn_name, **args) + + # Safe serialization: JSON if possible, else string fallback. + if isinstance(result, str): + serialized_result = result + else: + try: + serialized_result = json.dumps(result) + except Exception: # noqa: BLE001 + serialized_result = str(result) + + tool_result = { + "tool_call_id": tool_call["id"], + "tool_name": fn_name, + "tool_args": args, + "execution_result": serialized_result, + } + history_entry = ToolExecutionRecord(**tool_result) + + # Build the tool message for both memory and (optionally) per-instance timeline + tool_message = ToolMessage( + tool_call_id=tool_result["tool_call_id"], + name=tool_result["tool_name"], + content=tool_result["execution_result"], + role="tool", + ) + agent_message = { + "id": tool_message.tool_call_id, + "role": "tool", + "name": tool_message.name, + "content": tool_message.content, + } + + # Append to durable state if the entry/timeline exists (with de-dupe) + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if entry is not None and hasattr(entry, "messages"): + # Skip if this tool_call_id already recorded + try: + existing_ids = { + getattr(m, "id", None) or getattr(m, "tool_call_id", None) + for m in getattr(entry, "messages") + } + except Exception: + existing_ids = set() + if agent_message["id"] not in existing_ids: + tool_message_model = ( + self._message_coercer(agent_message) + if getattr(self, "_message_coercer", None) + else self._message_dict_to_message_model(agent_message) + ) + entry.messages.append(tool_message_model) + if hasattr(entry, "tool_history"): + entry.tool_history.append(history_entry) + if hasattr(entry, "last_message"): + entry.last_message = tool_message_model + + # Always persist to memory + in-process tool history + self.memory.add_message(tool_message) + self.tool_history.append(history_entry) + self.save_state() + return tool_result + + def broadcast_message_to_agents( + self, ctx: wf.WorkflowActivityContext, payload: Dict[str, Any] + ) -> None: + """ + Broadcast a message to all agents via pub/sub (if a broadcast topic is set). + + Args: + payload: Dict containing 'message' (assistant/user-like dict). + """ + message = payload.get("message", {}) + if not isinstance(message, dict) or not self.broadcast_topic_name: + logger.debug( + "Skipping broadcast because payload is invalid or topic is unset." + ) + return + + try: + agents_metadata = self.get_agents_metadata( + exclude_self=False, exclude_orchestrator=False + ) + except Exception: # noqa: BLE001 + logger.exception("Unable to load agents metadata; broadcast aborted.") + return + + message["role"] = "user" + message["name"] = self.name + response_message = BroadcastMessage(**message) + + async def _broadcast() -> None: + await broadcast_message( + message=response_message, + broadcast_topic=self.broadcast_topic_name, + message_bus=self.message_bus_name, + source=self.name, + agents_metadata=agents_metadata, + ) + + try: + self._run_asyncio_task(_broadcast()) + except Exception: # noqa: BLE001 + logger.exception("Failed to publish broadcast message.") + + def send_response_back( + self, ctx: wf.WorkflowActivityContext, payload: Dict[str, Any] + ) -> None: + """ + Send the final response back to the triggering agent. + + Args: + payload: Dict containing 'response', 'target_agent', 'target_instance_id'. + """ + response = payload.get("response", {}) + target_agent = payload.get("target_agent", "") + target_instance_id = payload.get("target_instance_id", "") + if not target_agent or not target_instance_id: + logger.debug( + "Target agent or instance missing; skipping response publication." + ) + return + + response["role"] = "user" + response["name"] = self.name + response["workflow_instance_id"] = target_instance_id + agent_response = AgentTaskResponse(**response) + + agents_metadata = self.get_agents_metadata() + + try: + self._run_asyncio_task( + send_message_to_agent( + source=self.name, + target_agent=target_agent, + message=agent_response, + agents_metadata=agents_metadata, + ) + ) + except Exception: # noqa: BLE001 + logger.exception("Failed to publish response to %s", target_agent) + + def finalize_workflow( + self, ctx: wf.WorkflowActivityContext, payload: Dict[str, Any] + ) -> None: + """ + Finalize a workflow instance by setting status, output, and end time. + + Args: + payload: Dict with 'instance_id', 'final_output', 'end_time', + and optional 'triggering_workflow_instance_id'. + """ + instance_id = payload.get("instance_id") + final_output = payload.get("final_output", "") + end_time = payload.get("end_time", "") + triggering_workflow_instance_id = payload.get("triggering_workflow_instance_id") + + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if not entry: + return + + entry.status = ( + DaprWorkflowStatus.COMPLETED.value + if final_output + else DaprWorkflowStatus.FAILED.value + ) + entry.end_time = self._coerce_datetime(end_time) + if hasattr(entry, "output"): + entry.output = final_output or "" + entry.triggering_workflow_instance_id = triggering_workflow_instance_id + self.save_state() + + # ------------------------------------------------------------------ + # Runtime control + # ------------------------------------------------------------------ + def start( + self, + runtime: Optional[wf.WorkflowRuntime] = None, + *, + auto_register: bool = True, + ) -> None: + """ + Start the workflow runtime and register this agent's components. + + Behavior: + • If a runtime is provided, attach to it (we still consider it not owned). + • Register workflows once (if not already). + • Always attempt to start the runtime; treat start() as idempotent: + - If it's already running, swallow/log the exception and continue. + • We only call shutdown() later if we own the runtime. + """ + if self._started: + raise RuntimeError("Agent has already been started.") + + if runtime is not None: + self._runtime = runtime + self._runtime_owned = False + self._registered = False + logger.info( + "Attached injected WorkflowRuntime (owned=%s).", self._runtime_owned + ) + + if auto_register and not self._registered: + self.register_workflows(self._runtime) + self._registered = True + logger.info( + "Registered workflows/activities on WorkflowRuntime for agent '%s'.", + self.name, + ) + + # Always try to start; treat as idempotent. + try: + self._runtime.start() + logger.info( + "WorkflowRuntime started for agent '%s' (owned=%s).", + self.name, + self._runtime_owned, + ) + except Exception as exc: # noqa: BLE001 + # Most common benign case: runtime already running + logger.warning( + "WorkflowRuntime.start() raised for agent '%s' (likely already running): %s", + self.name, + exc, + exc_info=True, + ) + + self._started = True + + def stop(self) -> None: + """Stop the workflow runtime if it is owned by this instance.""" + if not self._started: + return + + if self._runtime_owned: + try: + self._runtime.shutdown() + except Exception: # noqa: BLE001 + logger.debug( + "Error while shutting down workflow runtime", exc_info=True + ) + + self._started = False + + def register(self, runtime: wf.WorkflowRuntime) -> None: + """ + Register workflows and activities on a provided runtime. + + Args: + runtime: An externally-managed workflow runtime to register with. + """ + self._runtime = runtime + self._runtime_owned = False + self.register_workflows(runtime) + self._registered = True + + def register_workflows(self, runtime: wf.WorkflowRuntime) -> None: + """ + Register workflows/activities for this agent. + + Args: + runtime: The Dapr workflow runtime to register with. + """ + runtime.register_workflow(self.agent_workflow) + runtime.register_workflow(self.broadcast_listener) + runtime.register_activity(self.record_initial_entry) + runtime.register_activity(self.call_llm) + runtime.register_activity(self.run_tool) + runtime.register_activity(self.broadcast_message_to_agents) + runtime.register_activity(self.send_response_back) + runtime.register_activity(self.finalize_workflow) diff --git a/dapr_agents/agents/durableagent/__init__.py b/dapr_agents/agents/durableagent/__init__.py deleted file mode 100644 index 2d80e605..00000000 --- a/dapr_agents/agents/durableagent/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .agent import DurableAgent - -__all__ = ["DurableAgent"] diff --git a/dapr_agents/agents/durableagent/agent.py b/dapr_agents/agents/durableagent/agent.py deleted file mode 100644 index ab5ce41f..00000000 --- a/dapr_agents/agents/durableagent/agent.py +++ /dev/null @@ -1,1015 +0,0 @@ -import asyncio -import json -import logging -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Union - -from dapr.ext.workflow import DaprWorkflowContext # type: ignore -from pydantic import Field, model_validator - -from dapr_agents.agents.base import AgentBase -from dapr_agents.agents.durableagent.state import DurableAgentWorkflowState -from dapr_agents.types import ( - AgentError, - AssistantMessage, - LLMChatResponse, - ToolExecutionRecord, - ToolMessage, - UserMessage, -) -from dapr_agents.types.workflow import DaprWorkflowStatus -from dapr_agents.workflow.agentic import AgenticWorkflow -from dapr_agents.workflow.decorators import message_router, task, workflow -from dapr_agents.memory import ConversationDaprStateMemory - -from .schemas import ( - AgentTaskResponse, - BroadcastMessage, - InternalTriggerAction, - TriggerAction, -) -from .state import ( - DurableAgentMessage, - DurableAgentWorkflowEntry, -) - -logger = logging.getLogger(__name__) - - -# TODO(@Sicoyle): Clear up the lines between DurableAgent and AgentWorkflow -class DurableAgent(AgenticWorkflow, AgentBase): - """ - A conversational AI agent that responds to user messages, engages in discussions, - and dynamically utilizes external tools when needed. - - The DurableAgent follows an agentic workflow, iterating on responses based on - contextual understanding, reasoning, and tool-assisted execution. It ensures - meaningful interactions by selecting the right tools, generating relevant responses, - and refining outputs through iterative feedback loops. - """ - - agent_topic_name: Optional[str] = Field( - default=None, - description="The topic name dedicated to this specific agent, derived from the agent's name if not provided.", - ) - agent_metadata: Optional[Dict[str, Any]] = Field( - default=None, - description="Metadata about the agent, including name, role, goal, instructions, and topic name.", - ) - workflow_instance_id: Optional[str] = Field( - default=None, - description="The current workflow instance ID for this agent.", - ) - memory: Optional[ConversationDaprStateMemory] = Field( - default=None, - description="Persistent memory with session-based state hydration.", - ) - - @model_validator(mode="before") - def set_agent_and_topic_name(cls, values: dict): - # Set name to role if name is not provided - if not values.get("name") and values.get("role"): - values["name"] = values["role"] - - # Derive agent_topic_name from agent name - if not values.get("agent_topic_name") and values.get("name"): - values["agent_topic_name"] = values["name"] - - return values - - def model_post_init(self, __context: Any) -> None: - """Initializes the workflow with agentic execution capabilities.""" - # Call AgenticWorkflow's model_post_init first to initialize state store and other dependencies - # This will properly load state from storage if it exists - super().model_post_init(__context) - - # Name of main Workflow - # TODO: can this be configurable or dynamic? Would that make sense? - self._workflow_name = "AgenticWorkflow" - - # Initialize state structure if it doesn't exist - if not self.state: - self.state = {"instances": {}} - - if self.memory is not None: - self.memory = ConversationDaprStateMemory( - store_name=self.memory.store_name, - session_id=f"{self.name or 'agent'}_session", - ) - logger.info(f"Initialized memory with store name: {self.memory.store_name}") - - # Load the current workflow instance ID from state using session_id - logger.debug(f"State after loading: {self.state}") - if self.state and self.state.get("instances"): - logger.debug(f"Found {len(self.state['instances'])} instances in state") - current_session_id = ( - self.memory.session_id - if self.memory - else f"{self.name}_default_session" - ) - for instance_id, instance_data in self.state["instances"].items(): - stored_workflow_name = instance_data.get("workflow_name") - stored_session_id = instance_data.get("session_id") - logger.debug( - f"Instance {instance_id}: workflow_name={stored_workflow_name}, session_id={stored_session_id}, current_workflow_name={self._workflow_name}, current_session_id={current_session_id}" - ) - if ( - stored_workflow_name == self._workflow_name - and stored_session_id == current_session_id - ): - self.workflow_instance_id = instance_id - logger.debug( - f"Loaded current workflow instance ID from state using session_id: {instance_id}" - ) - break - else: - logger.debug("No instances found in state or state is empty") - - # Sync workflow state with Dapr runtime after loading - # This ensures our database reflects the actual state of resumed workflows - self._sync_workflow_state_after_startup() - - # Register the agentic system - self._agent_metadata = { - "name": self.name, - "role": self.role, - "goal": self.goal, - "instructions": self.instructions, - "topic_name": self.agent_topic_name, - "pubsub_name": self.message_bus_name, - "orchestrator": False, - } - - self.register_agentic_system() - - async def run(self, input_data: Union[str, Dict[str, Any]]) -> Any: - """ - Fire up the workflow, wait for it to complete, then return the final serialized_output. - Dapr automatically handles resuming any incomplete workflows when the runtime starts. - - Args: - input_data (Union[str, Dict[str, Any]]): The input for the workflow. Can be a string (task) or a dict. - Returns: - Any: The final output from the workflow execution. - """ - logger.debug(f"DurableAgent.run() called with input: {input_data}") - - # Set up signal handlers for graceful shutdown when using run() method - self.setup_signal_handlers() - - # Prepare input payload for workflow - if isinstance(input_data, dict): - input_payload = input_data - else: - input_payload = {"task": input_data} - - try: - result = await self.run_and_monitor_workflow_async( - workflow=self._workflow_name, - input=input_payload, - ) - return result - except asyncio.CancelledError: - logger.warning("Workflow execution was cancelled") - raise - - @message_router - @workflow(name="AgenticWorkflow") - def tool_calling_workflow(self, ctx: DaprWorkflowContext, message: TriggerAction): - """ - Executes a tool-calling workflow, determining the task source (either an agent or an external user). - This uses Dapr Workflows to run the agent in a ReAct-style loop until it generates a final answer or reaches max iterations, - calling tools as needed. - - Args: - ctx (DaprWorkflowContext): The workflow context for the current execution, providing state and control methods. - message (TriggerAction): The trigger message containing the task, iteration, and metadata for workflow execution. - - Returns: - Dict[str, Any]: The final response message when the workflow completes, or None if continuing to the next iteration. - """ - # Step 1: pull out task + metadata + span context from workflow input through .start, .run(), pubsub invocation - if isinstance(message, dict): - task = message.get("task", None) - metadata = message.get("_message_metadata", {}) or {} - # Extract OpenTelemetry span context if present - otel_span_context = message.get("_otel_span_context", None) - # Extract workflow_instance_id from TriggerAction if present from orchestrator - if "workflow_instance_id" in message: - metadata["triggering_workflow_instance_id"] = message[ - "workflow_instance_id" - ] - else: # This is for if triggered by an orchestrator - task = getattr(message, "task", None) - metadata = getattr(message, "_message_metadata", {}) or {} - # Extract OpenTelemetry span context if present - otel_span_context = getattr(message, "_otel_span_context", None) - # Extract workflow_instance_id from TriggerAction if present from orchestrator - if hasattr(message, "workflow_instance_id"): - metadata["triggering_workflow_instance_id"] = getattr( - message, "workflow_instance_id" - ) - # Extract source from TriggerAction if present from orchestrator - if hasattr(message, "source"): - metadata["source"] = getattr(message, "source") - - triggering_workflow_instance_id = metadata.get( - "triggering_workflow_instance_id" - ) - source = self.get_source_or_default(metadata.get("source")) - - # Store workflow instance ID for observability layer to use - # The observability layer will handle AGENT span creation for resumed workflows - if otel_span_context: - # New workflow - store the provided span context (observability layer handles this) - from dapr_agents.observability.context_storage import store_workflow_context - - instance_context_key = f"__workflow_context_{ctx.instance_id}__" - store_workflow_context(instance_context_key, otel_span_context) - - # Load the latest state from database to ensure we have up-to-date instance data - self.load_state() - - if not ctx.is_replaying: - logger.debug(f"Initial message from {source} -> {self.name}") - - yield ctx.call_activity( - self.record_initial_entry, - input={ - "instance_id": ctx.instance_id, - "input": task or "Triggered without input.", - "source": source, - "triggering_workflow_instance_id": triggering_workflow_instance_id, - "start_time": ctx.current_utc_datetime.isoformat(), - "trace_context": otel_span_context, - }, - ) - - try: - for turn in range(1, self.max_iterations + 1): - if not ctx.is_replaying: - logger.debug( - f"Workflow turn {turn}/{self.max_iterations} (Instance ID: {ctx.instance_id})" - ) - - # Generate Response with LLM and atomically save the assistant's response message - response_message: dict = yield ctx.call_activity( - self.call_llm, - input={ - "task": task, - "instance_id": ctx.instance_id, - "time": ctx.current_utc_datetime.isoformat(), - }, - ) - - # Handle tool calls response - tool_calls = response_message.get("tool_calls") or [] - if tool_calls: - if not ctx.is_replaying: - logger.debug( - f"Turn {turn}: executing {len(tool_calls)} tool call(s)" - ) - # fan‑out parallel tool executions - parallel = [ - ctx.call_activity( - self.run_tool, - input={ - "tool_call": tc, - "instance_id": ctx.instance_id, - "time": ctx.current_utc_datetime.isoformat(), - "execution_order": i, # Add ordering information - }, - ) - for i, tc in enumerate(tool_calls) - ] - yield self.when_all(parallel) - - # Prepare for next turn: clear task so that call_llm() uses memory/history - task = None - continue # bump to next turn - - # No tool calls → this is your final answer - break # exit loop - else: - raise AgentError("Workflow ended without producing a final response") - - except Exception as e: - logger.exception("Workflow error", exc_info=e) - err_msg = { - "role": "assistant", - "content": f"⚠️ Unexpected error: {e}", - } - self._save_assistant_message(ctx.instance_id, err_msg) - - # Get the last message from state (this will be the final response) - final_msg = self._get_last_message_from_state(ctx.instance_id) - if not final_msg: - final_msg = {"role": "assistant", "content": "No response generated"} - - # Broadcast the final response if a broadcast topic is set - if self.broadcast_topic_name: - yield ctx.call_activity( - self.broadcast_message_to_agents, - input={"message": final_msg}, - ) - - # Respond to source agent if available - if source and triggering_workflow_instance_id: - yield ctx.call_activity( - self.send_response_back, - input={ - "response": final_msg, - "target_agent": source, - "target_instance_id": triggering_workflow_instance_id, - }, - ) - - yield ctx.call_activity( - self.finalize_workflow, - input={ - "instance_id": ctx.instance_id, - "final_output": final_msg.get("content", ""), - "time": ctx.current_utc_datetime.isoformat(), - "triggering_workflow_instance_id": triggering_workflow_instance_id, - }, - ) - - # Set verdict for the workflow instance - if not ctx.is_replaying: - verdict = ( - "max_iterations_reached" if turn == self.max_iterations else "completed" - ) - logger.info(f"Workflow {ctx.instance_id} finalized: {verdict}") - - # Return the final response message - return final_msg - - @message_router - @workflow(name="AgenticWorkflow") - def internal_trigger_workflow( - self, ctx: DaprWorkflowContext, message: InternalTriggerAction - ): - """ - Handles InternalTriggerAction messages by treating them the same as TriggerAction. - This prevents self-triggering loops while allowing orchestrators to trigger agents. - - Args: - ctx (DaprWorkflowContext): The workflow context for the current execution. - message (InternalTriggerAction): The internal trigger message from an orchestrator. - - Returns: - Dict[str, Any]: The final response message when the workflow completes. - """ - # Convert InternalTriggerAction to TriggerAction format and delegate to the main workflow - trigger_message = TriggerAction( - task=message.task, - workflow_instance_id=message.workflow_instance_id, - source="orchestrator", # Default source for internal triggers - ) - return self.tool_calling_workflow(ctx, trigger_message) - - def get_source_or_default(self, source: str): - # Set default source if not provided (for direct run() calls) - if not source: - source = "direct" - return source - - @task - def record_initial_entry( - self, - instance_id: str, - input: str, - source: Optional[str], - triggering_workflow_instance_id: Optional[str], - start_time: str, # required to be passed in using the workflow context for deterministic timestamp - output: str = "", - trace_context: Optional[Dict[str, Any]] = None, - ): - """ - Records the initial workflow entry for a new workflow instance. - Args: - instance_id (str): The workflow instance ID. - input (str): The input task for the workflow. - source (Optional[str]): The source of the workflow trigger. - triggering_workflow_instance_id (Optional[str]): The workflow instance ID of the triggering workflow. - output (str): The output for the workflow entry (default: ""). - start_time (Optional[str]): The start time in ISO format (default: None, will use current time). - trace_context (Optional[Dict[str, Any]]): OpenTelemetry trace context for workflow resumption. - """ - # Convert datetime to string for JSON serialization - if start_time: - if isinstance(start_time, str): - start_time_str = start_time - else: - start_time_str = start_time.isoformat() - else: - start_time_str = datetime.now(timezone.utc).isoformat() - - entry = { - "input": input, - "source": source, - "workflow_instance_id": instance_id, - "triggering_workflow_instance_id": triggering_workflow_instance_id, - "workflow_name": self._workflow_name, - "session_id": self.memory.session_id - if self.memory - else f"{self.name}_default_session", - "start_time": start_time_str, - "trace_context": trace_context, - "status": DaprWorkflowStatus.RUNNING.value, - "messages": [], - "tool_history": [], - "end_time": None, - } - if "instances" not in self.state: - self.state["instances"] = {} - self.state["instances"][instance_id] = entry - - # Note: This is only really needed bc of the in-memory storage solutions. - # With persistent storage, this is not needed as we rehydrate the conversation state from the database upon app restart. - def _ensure_instance_exists( - self, - instance_id: str, - input: str, - triggering_workflow_instance_id: Optional[str] = None, - time: Optional[datetime] = None, - ) -> None: - """Ensure the instance entry exists in the state.""" - if instance_id not in self.state.get("instances", {}): - if "instances" not in self.state: - self.state["instances"] = {} - - # Handle time parameter - it might be a datetime object or a string - if time: - if isinstance(time, str): - start_time = time - else: - start_time = time.isoformat() - else: - start_time = datetime.now(timezone.utc).isoformat() - - self.state["instances"][instance_id] = { - "input": input, - "start_time": start_time, - "source": "user_input", - "workflow_instance_id": instance_id, - "triggering_workflow_instance_id": triggering_workflow_instance_id, - "workflow_name": self._workflow_name, - "session_id": self.memory.session_id - if self.memory - else f"{self.name}_default_session", - "messages": [], - "tool_history": [], - "status": DaprWorkflowStatus.RUNNING.value, - "end_time": None, - "trace_context": None, - } - - def _process_user_message( - self, - instance_id: str, - task: Optional[Union[str, Dict[str, Any]]], - user_message_copy: Optional[Dict[str, Any]], - ) -> None: - """Process and save user message to memory and state.""" - if not (task and user_message_copy): - return - - user_msg = UserMessage(content=user_message_copy.get("content", "")) - if self.memory: - self.memory.add_message(user_msg) - - msg_object = DurableAgentMessage(**user_message_copy) - inst = self.state["instances"][instance_id] - inst["messages"].append(msg_object.model_dump(mode="json")) - inst["last_message"] = msg_object.model_dump(mode="json") - self.save_state() - - def _call_llm(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]: - """Generate LLM response and return the assistant message.""" - response: LLMChatResponse = self.llm.generate( - messages=messages, - tools=self.get_llm_tools(), - **( - {"tool_choice": self.tool_choice} - if self.tool_choice is not None - else {} - ), - ) - response_message = response.get_message() - if response_message is None: - raise AgentError("LLM returned no assistant message") - - return response_message.model_dump() - - def _save_assistant_message( - self, instance_id: str, assistant_message: Dict[str, Any] - ) -> None: - """Save assistant message to state with idempotency check.""" - assistant_message["name"] = self.name - agent_msg = DurableAgentMessage(**assistant_message) - - inst = self.state["instances"][instance_id] - messages_list = inst["messages"] - - # Check for duplicate by message ID (idempotent for workflow replay) - message_exists = any(msg.get("id") == agent_msg.id for msg in messages_list) - if not message_exists: - messages_list.append(agent_msg.model_dump(mode="json")) - inst["last_message"] = agent_msg.model_dump(mode="json") - if self.memory: - self.memory.add_message(AssistantMessage(**assistant_message)) - self.save_state() - - def _print_llm_interaction_messages( - self, - user_message_copy: Optional[Dict[str, Any]], - assistant_message: Dict[str, Any], - ) -> None: - """Print user and assistant messages for context.""" - # Print user message - if user_message_copy is not None: - self.text_formatter.print_message( - {str(k): v for k, v in user_message_copy.items()} - ) - - # Print assistant message - self.text_formatter.print_message(assistant_message) - - @task - async def call_llm( - self, - instance_id: str, - time: datetime, - task: Optional[Union[str, Dict[str, Any]]] = None, - ) -> Dict[str, Any]: - """ - Ask the LLM for the assistant's next message. - - Args: - instance_id (str): The workflow instance ID. - time (str): The time of the message. - task: The user's query for this turn (either a string or a dict), - or None if this is a follow-up iteration. - - Returns: - A plain dict of the LLM's response (choices, finish_reason, etc). - Pydantic models are `.model_dump()`-ed; any other object is coerced via `dict()`. - """ - # Construct messages using instance-specific chat history instead of global memory - # This ensures proper message sequence for tool calls and ensures formatting/structure - messages: List[Dict[str, Any]] = self._construct_messages_with_instance_history( - instance_id, task or {} - ) - user_message = self.get_last_message_if_user(messages) - - # Always work with a copy of the user message for safety - user_message_copy: Optional[Dict[str, Any]] = ( - dict(user_message) if user_message else None - ) - - self._ensure_instance_exists( - instance_id, task or "No input provided", time=time - ) - self._process_user_message(instance_id, task, user_message_copy) - - # Generate LLM response and atomically save assistant message - try: - assistant_message = self._call_llm(messages) - self._save_assistant_message(instance_id, assistant_message) - self._print_llm_interaction_messages(user_message_copy, assistant_message) - - return assistant_message - except Exception as e: - error_type = type(e).__name__ - error_msg = str(e) - - logger.exception( - f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}" - ) - logger.exception(f"Task: {task}") - logger.exception(f"Messages count: {len(messages)}") - logger.exception(f"Tools available: {len(self.get_llm_tools())}") - logger.exception("Full error details:", exc_info=True) - - raise AgentError( - f"LLM generation failed in workflow {instance_id}: {error_type} - {error_msg}" - ) from e - - @task - def _create_tool_message_objects(self, tool_result: Dict[str, Any]) -> tuple: - """ - Create ToolMessage and DurableAgentMessage objects from tool result. - - Args: - tool_result: Dictionary containing tool execution details - - Returns: - Tuple of (tool_msg, agent_msg, tool_history_entry) - """ - tool_msg = ToolMessage( - tool_call_id=tool_result["tool_call_id"], - name=tool_result["tool_name"], - content=tool_result["execution_result"], - role="tool", - ) - agent_msg = DurableAgentMessage(**tool_msg.model_dump()) - tool_history_entry = ToolExecutionRecord(**tool_result) - - return tool_msg, agent_msg, tool_history_entry - - def _append_tool_message_to_instance( - self, - instance_id: str, - agent_msg: DurableAgentMessage, - tool_history_entry: ToolExecutionRecord, - ) -> None: - """ - Append tool message and history to the instance state. - - Args: - instance_id: The workflow instance ID - agent_msg: The DurableAgentMessage object - tool_history_entry: The ToolExecutionRecord object - """ - wf_instance = self.state["instances"][instance_id] - - # Check if message already exists (idempotent operation for workflow replay) - wf_messages = wf_instance["messages"] - - # Check for duplicate by message ID (idempotent for workflow replay) - message_exists = any(msg.get("id") == agent_msg.id for msg in wf_messages) - if not message_exists: - wf_messages.append(agent_msg.model_dump(mode="json")) - - # Check for duplicate tool history entry by tool_call_id - tool_history = wf_instance["tool_history"] - - tool_exists = any( - th.get("tool_call_id") == tool_history_entry.tool_call_id - for th in tool_history - ) - if not tool_exists: - tool_history.append(tool_history_entry.model_dump(mode="json")) - - def _update_agent_memory_and_history( - self, tool_message: ToolMessage, tool_history_entry: ToolExecutionRecord - ) -> None: - """ - Update agent's memory and tool history. - - Args: - tool_message: The ToolMessage object - tool_history_entry: The ToolExecutionRecord object - """ - # Update tool history and memory of agent (only if new) - # Note: Memory updates are handled at workflow level to avoid replay issues - self.tool_history.append(tool_history_entry) - # Add the tool message to the agent's memory if enabled - if self.memory: - self.memory.add_message(tool_message) - - def _get_last_message_from_state( - self, instance_id: str - ) -> Optional[Dict[str, Any]]: - """ - Get the last message from the instance state. - - Args: - instance_id: The workflow instance ID - - Returns: - The last message dict or None if not found - """ - instance_data = self.state.get("instances", {}).get(instance_id) - if instance_data is not None: - return instance_data.get("last_message") - return None - - @task - async def run_tool( - self, - tool_call: Dict[str, Any], - instance_id: str, - time: datetime, - execution_order: int = 0, - ) -> Dict[str, Any]: - """ - Executes a tool call atomically by invoking the specified function with the provided arguments - and immediately persisting the result to the agent's state and memory. - - Args: - tool_call (Dict[str, Any]): A dictionary containing tool execution details, including the function name and arguments. - instance_id (str): The workflow instance ID for state persistence. - time (str): The current time for state persistence. - - Returns: - Dict[str, Any]: A dictionary containing the tool call ID, function name, function arguments - - Raises: - AgentError: If the tool call is malformed or execution fails. - """ - # Extract function name and raw args - fn_name = tool_call["function"]["name"] - raw_args = tool_call["function"].get("arguments", "") - - # Parse JSON arguments (or empty dict) - try: - args = json.loads(raw_args) if raw_args else {} - except json.JSONDecodeError as e: - raise AgentError(f"Invalid JSON in tool args: {e}") - - # Run the tool - logger.debug(f"Executing tool '{fn_name}' with args: {args}") - try: - result = await self.tool_executor.run_tool(fn_name, **args) - except Exception as e: - logger.exception(f"Error executing tool '{fn_name}': {e}", exc_info=True) - raise AgentError(f"Error executing tool '{fn_name}': {e}") from e - - # Create the tool result payload - tool_result = { - "tool_call_id": tool_call["id"], - "tool_name": fn_name, - "tool_args": args, - "execution_result": str(result) if result is not None else "", - } - - # Atomically persist the tool execution result - # Get existing input or use placeholder - existing_input = ( - self.state["instances"][instance_id]["input"] - if instance_id in self.state.get("instances", {}) - else "Tool execution" - ) - self._ensure_instance_exists(instance_id, existing_input, time=time) - tool_msg, agent_msg, tool_history_entry = self._create_tool_message_objects( - tool_result - ) - self._append_tool_message_to_instance( - instance_id, agent_msg, tool_history_entry - ) - self._update_agent_memory_and_history(tool_msg, tool_history_entry) - self.save_state() - self.text_formatter.print_message(tool_msg) - - return tool_result - - @task - async def broadcast_message_to_agents(self, message: Dict[str, Any]): - """ - Broadcasts it to all registered agents. - - Args: - message (Dict[str, Any]): A message to append to the workflow state and broadcast to all agents. - """ - # Format message for broadcasting - message["role"] = "user" - message["name"] = self.name - response_message = BroadcastMessage(**message) - - # Broadcast message to all agents - await self.broadcast_message(message=response_message) - - @task - async def send_response_back( - self, response: Dict[str, Any], target_agent: str, target_instance_id: str - ): - """ - Sends a task response back to a target agent within a workflow. - - Args: - response (Dict[str, Any]): The response payload to be sent. - target_agent (str): The name of the agent that should receive the response. - target_instance_id (str): The workflow instance ID associated with the response. - - Raises: - ValidationError: If the response does not match the expected structure for `AgentTaskResponse`. - """ - response["role"] = "user" - response["name"] = self.name - response["workflow_instance_id"] = target_instance_id - agent_response = AgentTaskResponse(**response) - - # Send the message to the target agent - await self.send_message_to_agent(name=target_agent, message=agent_response) - - # TODO: add metrics on workflow run in future here? - @task - def finalize_workflow( - self, - instance_id: str, - final_output: str, - time: str, - triggering_workflow_instance_id: Optional[str] = None, - ) -> None: - """ - Record the final output and end_time in the workflow state. - """ - # Ensure the instance entry exists - existing_input = ( - self.state["instances"][instance_id]["input"] - if instance_id in self.state.get("instances", {}) - else "Workflow completion" - ) - self._ensure_instance_exists( - instance_id, existing_input, triggering_workflow_instance_id, time - ) - instance = self.state["instances"][instance_id] - instance["output"] = final_output - # Convert time to string for JSON serialization - if time: - if isinstance(time, str): - instance["end_time"] = time - else: - instance["end_time"] = time.isoformat() - else: - instance["end_time"] = datetime.now(timezone.utc).isoformat() - instance["status"] = DaprWorkflowStatus.COMPLETED.value # Mark as completed - logger.info(f"Workflow {instance_id} completed successfully") - self.save_state() - - @message_router(broadcast=True) - async def process_broadcast_message(self, message: BroadcastMessage): - """ - Processes a broadcast message by filtering out messages from the same agent, - storing valid messages in memory, and triggering the agent's workflow if needed. - - Args: - message (BroadcastMessage): The received broadcast message. - - Returns: - None: The function updates the agent's memory and triggers a workflow. - """ - try: - # Extract metadata safely from message["_message_metadata"] - metadata = getattr(message, "_message_metadata", {}) - - if not isinstance(metadata, dict) or not metadata: - logger.warning( - f"{self.name} received a broadcast message with missing or invalid metadata. Ignoring." - ) - return - - source = metadata.get("source", "unknown_source") - message_type = metadata.get("type", "unknown_type") - message_content = getattr(message, "content", "No Data") - logger.info( - f"{self.name} received broadcast message of type '{message_type}' from '{source}'." - ) - # Ignore messages sent by this agent - if source == self.name: - logger.debug( - f"{self.name} ignored its own broadcast message of type '{message_type}'." - ) - return - # Log and process the valid broadcast message - logger.debug( - f"{self.name} processing broadcast message from '{source}'. Content: {message_content}" - ) - # Store the message in local memory if enabled - if self.memory: - self.memory.add_message(message) - - # Define DurableAgentMessage object for state persistence - agent_msg = DurableAgentMessage(**message.model_dump()) - - # Persist to global chat history - if "chat_history" not in self.state: - self.state["chat_history"] = [] - self.state["chat_history"].append(agent_msg.model_dump(mode="json")) - # Save the state after processing the broadcast message - self.save_state() - - # Trigger agent workflow to respond to the broadcast message - workflow_instance_id = metadata.get("workflow_instance_id") - if workflow_instance_id: - # Create a TriggerAction to start the agent's workflow - trigger_message = TriggerAction( - task=message.content, workflow_instance_id=workflow_instance_id - ) - trigger_message._message_metadata = { - "source": metadata.get("source", "unknown"), - "type": "BroadcastMessage", - "workflow_instance_id": workflow_instance_id, - } - - # Start the agent's workflow - await self.run_and_monitor_workflow_async( - workflow="AgenticWorkflow", input=trigger_message - ) - - except Exception as e: - logger.error(f"Error processing broadcast message: {e}", exc_info=True) - - # TODO: we need to better design context history management. Context engineering is important, - # and too much context can derail the agent. - def _construct_messages_with_instance_history( - self, instance_id: str, input_data: Union[str, Dict[str, Any]] - ) -> List[Dict[str, Any]]: - """ - Construct messages using instance-specific chat history and persistent memory. - This ensures proper message sequence for tool calls and maintains conversation - history across workflow executions using the session_id. - - Args: - instance_id: The workflow instance ID - input_data: User input, either as a string or dictionary - - Returns: - List of formatted messages with proper sequence - """ - additional_context_messages: List[Dict[str, Any]] = [] - if not self.prompt_template: - raise ValueError( - "Prompt template must be initialized before constructing messages." - ) - - # Get instance-specific chat history - if self.state is None: - logger.warning( - f"Agent state is None for instance {instance_id}, initializing empty state" - ) - self.state = {} - - instance_data = self.state.get("instances", {}).get(instance_id) - if instance_data is not None: - instance_messages = instance_data.get("messages", []) - else: - instance_messages = [] - - # Get messages from persistent memory (session-based, cross-workflow) if memory is enabled - persistent_memory_messages = [] - if self.memory: - try: - persistent_memory_messages = self.memory.get_messages() - logger.info( - f"Retrieved {len(persistent_memory_messages)} messages for session {self.memory.session_id}" - ) - except Exception as e: - logger.warning(f"Failed to retrieve persistent memory: {e}") - else: - logger.debug("Memory is disabled, skipping persistent memory retrieval") - - # Get long-term memory from workflow state (for broadcast messages and persistent context) - long_term_memory_data = self.state.get("chat_history", []) - long_term_memory_messages = [] - for msg in long_term_memory_data: - if isinstance(msg, dict): - long_term_memory_messages.append(msg) - elif hasattr(msg, "model_dump"): - long_term_memory_messages.append(msg.model_dump()) - - # Build chat history with proper context and order - chat_history = [] - - # First add persistent memory and long-term memory as user messages for context - # This ensures we have cross-workflow context but doesn't interfere with tool state order - for msg in persistent_memory_messages + long_term_memory_messages: - msg_dict = msg.model_dump() if hasattr(msg, "model_dump") else dict(msg) - if msg_dict in chat_history: - continue - # TODO: We need to properly design session-based memory. - # Convert tool-related messages to user messages to avoid conversation order issues - if msg_dict.get("role") in ["tool", "assistant"] and ( - msg_dict.get("tool_calls") or msg_dict.get("tool_call_id") - ): - msg_dict = { - "role": "user", - "content": f"[Previous {msg_dict['role']} message: {msg_dict.get('content', '')}]", - } - chat_history.append(msg_dict) - - # Then add instance messages in their original form to maintain tool state - for msg in instance_messages: - msg_dict = msg.model_dump() if hasattr(msg, "model_dump") else dict(msg) - if msg_dict in chat_history: - continue - chat_history.append(msg_dict) - - # Add additional context memory last (for broadcast-triggered workflows) - chat_history.extend(additional_context_messages) - - if isinstance(input_data, str): - formatted_messages = self.prompt_template.format_prompt( - chat_history=chat_history - ) - if isinstance(formatted_messages, list): - user_message = {"role": "user", "content": input_data} - return formatted_messages + [user_message] - else: - return [ - {"role": "system", "content": formatted_messages}, - {"role": "user", "content": input_data}, - ] - elif isinstance(input_data, dict): - input_vars = dict(input_data) - if "chat_history" not in input_vars: - input_vars["chat_history"] = chat_history - formatted_messages = self.prompt_template.format_prompt(**input_vars) - if isinstance(formatted_messages, list): - return formatted_messages - else: - return [{"role": "system", "content": formatted_messages}] - else: - raise ValueError("Input data must be either a string or dictionary.") diff --git a/dapr_agents/agents/durableagent/schemas.py b/dapr_agents/agents/durableagent/schemas.py deleted file mode 100644 index 2ff99a89..00000000 --- a/dapr_agents/agents/durableagent/schemas.py +++ /dev/null @@ -1,48 +0,0 @@ -from dapr_agents.types.message import BaseMessage -from pydantic import BaseModel, Field -from typing import Optional - - -class BroadcastMessage(BaseMessage): - """ - Represents a broadcast message from an agent. - """ - - -class AgentTaskResponse(BaseMessage): - """ - Represents a response message from an agent after completing a task. - """ - - workflow_instance_id: Optional[str] = Field( - default=None, description="Dapr workflow instance id from source if available" - ) - - -class TriggerAction(BaseModel): - """ - Represents a message used to trigger an agent's activity within the workflow. - """ - - task: Optional[str] = Field( - None, - description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.", - ) - workflow_instance_id: Optional[str] = Field( - default=None, description="Dapr workflow instance id from source if available" - ) - - -class InternalTriggerAction(BaseModel): - """ - Represents an internal message used by orchestrators to trigger agents. - This prevents self-triggering loops. - """ - - task: Optional[str] = Field( - None, - description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.", - ) - workflow_instance_id: Optional[str] = Field( - default=None, description="Dapr workflow instance id from source if available" - ) diff --git a/dapr_agents/workflow/orchestrators/__init__.py b/dapr_agents/agents/orchestrators/__init__.py similarity index 73% rename from dapr_agents/workflow/orchestrators/__init__.py rename to dapr_agents/agents/orchestrators/__init__.py index efc9e5e9..88b64da2 100644 --- a/dapr_agents/workflow/orchestrators/__init__.py +++ b/dapr_agents/agents/orchestrators/__init__.py @@ -1,10 +1,10 @@ -from .base import OrchestratorWorkflowBase -from .llm import LLMOrchestrator +from .base import OrchestratorBase from .random import RandomOrchestrator from .roundrobin import RoundRobinOrchestrator +from .llm import LLMOrchestrator __all__ = [ - "OrchestratorWorkflowBase", + "OrchestratorBase", "LLMOrchestrator", "RandomOrchestrator", "RoundRobinOrchestrator", diff --git a/dapr_agents/agents/orchestrators/base.py b/dapr_agents/agents/orchestrators/base.py new file mode 100644 index 00000000..e77b07bb --- /dev/null +++ b/dapr_agents/agents/orchestrators/base.py @@ -0,0 +1,324 @@ +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Coroutine +from typing import Any, Dict, Optional + +import dapr.ext.workflow as wf + +from dapr_agents.agents.components import AgentComponents +from dapr_agents.agents.configs import ( + AgentPubSubConfig, + AgentRegistryConfig, + AgentStateConfig, +) +from dapr_agents.agents.utils.text_printer import ColorTextFormatter + +logger = logging.getLogger(__name__) + + +class OrchestratorBase(AgentComponents): + """ + Workflow-native orchestrator base built on AgentComponents. + + Overview: + Manages workflow runtime lifecycle (register/start/stop), optional + self-registration in the agent registry (marked as orchestrator), + console helpers for readable interactions, and small utilities like + raising workflow events. + """ + + def __init__( + self, + *, + name: str, + pubsub_config: Optional[AgentPubSubConfig] = None, + state_config: Optional[AgentStateConfig] = None, + registry_config: Optional[AgentRegistryConfig] = None, + agent_metadata: Optional[Dict[str, Any]] = None, + runtime: Optional[wf.WorkflowRuntime] = None, + workflow_client: Optional[wf.DaprWorkflowClient] = None, + ) -> None: + """ + Initialize the orchestrator base. + + Args: + name: Orchestrator name. + pubsub_config: Pub/Sub settings used to address agents via topics. + state_config: Durable state settings (if the orchestrator persists anything). + registry_config: Agent registry configuration for discovery. + agent_metadata: Extra metadata to store in the registry; ``orchestrator=True`` + is enforced automatically. + runtime: Optional pre-existing workflow runtime to attach to. + workflow_client: Optional DaprWorkflowClient for dependency injection/testing. + """ + super().__init__( + name=name, + pubsub_config=pubsub_config, + state_config=state_config, + registry_config=registry_config, + ) + + # Ensure registry entry marks this as an orchestrator + meta = dict(agent_metadata or {}) + meta.setdefault("orchestrator", True) + if self.registry_state is not None: + try: + self.register_agentic_system(metadata=meta) + except Exception: # noqa: BLE001 + logger.warning( + "Could not register orchestrator in registry.", exc_info=True + ) + + # Runtime wiring + self._runtime: wf.WorkflowRuntime = runtime or wf.WorkflowRuntime() + self._runtime_owned = runtime is None + self._registered = False + self._started = False + self._workflow_client = workflow_client or wf.DaprWorkflowClient() + + # Presentation helper (console) + self._text_formatter = ColorTextFormatter() + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + @property + def runtime(self) -> wf.WorkflowRuntime: + """Return the underlying workflow runtime.""" + return self._runtime + + @property + def is_started(self) -> bool: + """Return True if the runtime has been started by this orchestrator.""" + return self._started + + def register(self, runtime: wf.WorkflowRuntime) -> None: + """ + Register workflows/activities onto an external runtime. + + Subclasses must implement `register_workflows(runtime)` to perform registrations. + """ + self._runtime = runtime + self._runtime_owned = False + self.register_workflows(runtime) + self._registered = True + + def start( + self, + runtime: Optional[wf.WorkflowRuntime] = None, + *, + auto_register: bool = True, + ) -> None: + """ + Start the workflow runtime and register workflows/activities if needed. + + Behavior: + • If a runtime is provided, attach to it (we still consider it not owned). + • Register workflows once (if not already). + • Always attempt to start the runtime; treat start() as idempotent: + - If it's already running, swallow/log the exception and continue. + • We only call shutdown() later if we own the runtime. + """ + if self._started: + raise RuntimeError("Orchestrator has already been started.") + + if runtime is not None: + self._runtime = runtime + self._runtime_owned = False + self._registered = False + logger.info( + "Attached injected WorkflowRuntime (owned=%s).", self._runtime_owned + ) + + if auto_register and not self._registered: + self.register_workflows(self._runtime) + self._registered = True + logger.info("Registered workflows/activities on WorkflowRuntime.") + + try: + self._runtime.start() + logger.info("WorkflowRuntime started (owned=%s).", self._runtime_owned) + except Exception as exc: # noqa: BLE001 + logger.warning( + "WorkflowRuntime.start() raised (likely already running): %s", + exc, + exc_info=True, + ) + + self._started = True + + def stop(self) -> None: + """Stop the workflow runtime if owned by this instance.""" + if not self._started: + return + + if self._runtime_owned: + try: + self._runtime.shutdown() + except Exception: # noqa: BLE001 + logger.debug( + "Error while shutting down orchestrator runtime", exc_info=True + ) + + self._started = False + + # ------------------------------------------------------------------ + # Registration hook + # ------------------------------------------------------------------ + def register_workflows( + self, runtime: wf.WorkflowRuntime + ) -> None: # pragma: no cover + """ + Hook for subclasses to register workflows and activities. + + Example: + runtime.register_workflow(self.my_workflow) + runtime.register_activity(self.my_activity) + """ + raise NotImplementedError + + # ------------------------------------------------------------------ + # Presentation helpers (console) + # ------------------------------------------------------------------ + @property + def text_formatter(self) -> ColorTextFormatter: + """Formatter used for human-friendly console output.""" + return self._text_formatter + + @text_formatter.setter + def text_formatter(self, formatter: ColorTextFormatter) -> None: + """Override the default text formatter.""" + self._text_formatter = formatter + + def print_interaction( + self, source_agent_name: str, target_agent_name: str, message: str + ) -> None: + """ + Print a formatted interaction between two agents. + + Args: + source_agent_name: Sender name. + target_agent_name: Recipient name. + message: Message content. + """ + separator = "-" * 80 + parts = [ + (source_agent_name, "dapr_agents_pink"), + (" -> ", "dapr_agents_teal"), + (f"{target_agent_name}\n\n", "dapr_agents_pink"), + (message + "\n\n", "dapr_agents_pink"), + (separator + "\n", "dapr_agents_teal"), + ] + self._text_formatter.print_colored_text(parts) + + # ------------------------------------------------------------------ + # Team/registry convenience + # ------------------------------------------------------------------ + def list_team_agents( + self, *, team: Optional[str] = None, include_self: bool = True + ) -> Dict[str, Any]: + """ + Convenience wrapper over `get_agents_metadata`. + + Args: + team: Team override. + include_self: If True, include this orchestrator/agent in the results. + + Returns: + Mapping of agent name to metadata. + """ + return self.get_agents_metadata( + exclude_self=not include_self, + exclude_orchestrator=False, + team=team, + ) + + # ------------------------------------------------------------------ + # Event helpers + # ------------------------------------------------------------------ + def raise_workflow_event( + self, instance_id: str, event_name: str, *, data: Any | None = None + ) -> None: + """ + Raise an external event for a running workflow instance. + + Args: + instance_id: Target workflow instance id. + event_name: Name of the event to raise. + data: Optional payload. If it is a Pydantic-like object with ``model_dump``, + it will be serialized to a dict. + + Raises: + RuntimeError: If raising the event fails. + """ + try: + payload = self._serialize_event_data(data) + logger.info( + "Raising workflow event '%s' for instance '%s'", event_name, instance_id + ) + self._workflow_client.raise_workflow_event( + instance_id=instance_id, + event_name=event_name, + data=payload, + ) + logger.info( + "Raised workflow event '%s' for instance '%s'", event_name, instance_id + ) + except Exception as exc: # noqa: BLE001 + logger.error( + "Failed to raise workflow event '%s' for instance '%s'. Data=%s Error=%s", + event_name, + instance_id, + data, + exc, + ) + raise RuntimeError( + f"Failed to raise workflow event '{event_name}' for instance '{instance_id}': {exc}" + ) from exc + + @staticmethod + def _serialize_event_data(data: Any | None) -> Any: + """ + Best-effort serialization for event data. + + Args: + data: Arbitrary event payload. + + Returns: + A JSON-serializable payload. Pydantic-like objects are converted via ``model_dump``. + """ + if data is None: + return None + if hasattr(data, "model_dump"): + try: + return data.model_dump() + except Exception: # noqa: BLE001 + return data # fallback; Dapr client will attempt serialization + return data + + # ------------------------------------------------------------------ + # Small async helper (shared by orchestrators) + # ------------------------------------------------------------------ + @staticmethod + def _run_asyncio_task(coro: Coroutine[Any, Any, Any]) -> Any: + """ + Execute an async coroutine from a sync context, creating a fresh loop if needed. + + Args: + coro: The coroutine to execute. + + Returns: + The result of the coroutine execution. + """ + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + else: + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() diff --git a/dapr_agents/workflow/orchestrators/llm/__init__.py b/dapr_agents/agents/orchestrators/llm/__init__.py similarity index 100% rename from dapr_agents/workflow/orchestrators/llm/__init__.py rename to dapr_agents/agents/orchestrators/llm/__init__.py diff --git a/dapr_agents/agents/orchestrators/llm/base.py b/dapr_agents/agents/orchestrators/llm/base.py new file mode 100644 index 00000000..13848a02 --- /dev/null +++ b/dapr_agents/agents/orchestrators/llm/base.py @@ -0,0 +1,818 @@ +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +import dapr.ext.workflow as wf + +from dapr_agents.agents.configs import ( + AgentMemoryConfig, + AgentPubSubConfig, + AgentRegistryConfig, + AgentExecutionConfig, +) +from dapr_agents.agents.orchestrators.base import OrchestratorBase +from dapr_agents.agents.orchestrators.llm.configs import LLMOrchestratorStateConfig +from dapr_agents.agents.utils.text_printer import ColorTextFormatter +from dapr_agents.llm.chat import ChatClientBase +from dapr_agents.llm.utils.defaults import get_default_llm +from dapr_agents.memory import ConversationDaprStateMemory, ConversationListMemory +from dapr_agents.types import AssistantMessage, UserMessage + +logger = logging.getLogger(__name__) + + +class LLMOrchestratorBase(OrchestratorBase): + """ + A base class for LLM-driven orchestrators. This class manages the memory, LLM client, + and workflow state, and provides mechanisms for handling state persistence, + compensation, and broadcasting messages across agents. + + Responsibilities: + • Memory management (Dapr state-backed or in-memory). + • Dependency injection for LLM clients. + • Managing durable workflow state (messages, plan, finalization). + • Broadcast messages and trigger actions via pub/sub and agent registry. + • Compensation utilities for maintaining state consistency on failures. + """ + + def __init__( + self, + *, + name: str = "LLMOrchestrator", + pubsub_config: Optional[AgentPubSubConfig] = None, + state_config: Optional[LLMOrchestratorStateConfig] = None, + registry_config: Optional[AgentRegistryConfig] = None, + execution_config: Optional[AgentExecutionConfig] = None, + agent_metadata: Optional[Dict[str, Any]] = None, + memory_config: Optional[AgentMemoryConfig] = None, + llm: Optional[ChatClientBase] = None, + runtime: Optional[wf.WorkflowRuntime] = None, + workflow_client: Optional[wf.DaprWorkflowClient] = None, + ) -> None: + """ + Initializes the LLMOrchestrator with the provided configurations. + + Args: + name (str): Logical orchestrator name. + pubsub_config (Optional[AgentPubSubConfig]): Dapr Pub/Sub configuration. + state_config (Optional[LLMOrchestratorStateConfig]): State configuration for the orchestrator. + registry_config (Optional[AgentRegistryConfig]): Configuration for agent/team registry. + agent_metadata (Optional[Dict[str, Any]]): Metadata to store alongside the registry entry. + memory_config (Optional[AgentMemoryConfig]): Memory configuration for the orchestrator. + llm (Optional[ChatClientBase]): LLM client instance. + runtime (Optional[wf.WorkflowRuntime]): Workflow runtime configuration. + workflow_client (Optional[wf.DaprWorkflowClient]): Dapr workflow client. + """ + super().__init__( + name=name, + pubsub_config=pubsub_config, + state_config=state_config, + registry_config=registry_config, + execution_config=execution_config, + agent_metadata=agent_metadata, + runtime=runtime, + workflow_client=workflow_client, + ) + + # Memory wiring setup + self._memory_config = memory_config or AgentMemoryConfig() + if self._memory_config.store is None and state_config is not None: + self._memory_config.store = ConversationDaprStateMemory( + store_name=state_config.store.store_name, + session_id=f"{self.name}-session", + ) + self.memory = self._memory_config.store or ConversationListMemory() + + # Console formatting + self._text_formatter = ColorTextFormatter() + + # LLM client initialization + self.llm = llm or get_default_llm() + + # Initialize state if not present + if not getattr(self, "state", None): + self.state = {"instances": {}} + else: + self.state.setdefault("instances", {}) + + @property + def text_formatter(self) -> ColorTextFormatter: + """Returns the text formatter used for console output.""" + return self._text_formatter + + @text_formatter.setter + def text_formatter(self, formatter: ColorTextFormatter) -> None: + """Sets a custom text formatter for console output.""" + self._text_formatter = formatter + + @staticmethod + def _utcnow() -> datetime: + """Returns the current UTC time as a timezone-aware datetime.""" + return datetime.now(timezone.utc) + + @staticmethod + def _serialize_message(message: Any) -> Dict[str, Any]: + """ + Serializes a message-like object into a dictionary for storage. + + Args: + message (Any): The message object to serialize. + + Returns: + Dict[str, Any]: The serialized message as a dictionary. + + Raises: + TypeError: If the message type cannot be serialized. + """ + if hasattr(message, "model_dump"): + return message.model_dump() # type: ignore[no-any-return] + if isinstance(message, dict): + return dict(message) + if hasattr(message, "__dict__"): + return dict(message.__dict__) + raise TypeError( + f"Unsupported message type for serialization: {type(message)!r}" + ) + + @staticmethod + def _convert_plan_objects_to_dicts(plan_objects: List[Any]) -> List[Dict[str, Any]]: + """ + Converts plan objects (Pydantic models or dictionaries) into dictionaries. + + Args: + plan_objects (List[Any]): A list of plan objects to convert. + + Returns: + List[Dict[str, Any]]: The converted plan objects as dictionaries. + """ + if not plan_objects: + return [] + return [ + obj.model_dump() if hasattr(obj, "model_dump") else dict(obj) + for obj in plan_objects + ] + + def _ensure_instance_row( + self, instance_id: str, *, input_text: Optional[str] = None + ) -> None: + """ + Ensures an entry exists for the workflow instance in the state. + + This delegates to ensure_instance_exists() which uses the entry_factory + configured in LLMOrchestratorStateConfig to create proper model instances. + + Args: + instance_id (str): The workflow instance ID. + input_text (Optional[str]): The initial input text (if any) for the workflow. + """ + container = self._get_entry_container() + logger.debug( + "_ensure_instance_row: container type=%s, instance_id=%s, exists=%s", + type(container).__name__ if container else None, + instance_id, + instance_id in container if container else False, + ) + if container and instance_id not in container: + logger.debug( + "_ensure_instance_row: Creating new instance via ensure_instance_exists" + ) + # Use the parent class method which properly handles entry_factory + self.ensure_instance_exists( + instance_id=instance_id, + input_value=input_text or "", + triggering_workflow_instance_id=None, + time=self._utcnow(), + ) + # Check what was created + entry = container.get(instance_id) + logger.debug( + "_ensure_instance_row: Created entry type=%s", + type(entry).__name__ if entry else None, + ) + + def update_workflow_state( + self, + *, + instance_id: str, + message: Optional[Dict[str, Any]] = None, + final_output: Optional[str] = None, + plan: Optional[List[Dict[str, Any]]] = None, + wf_time: Optional[str] = None, + ) -> None: + """ + Updates the state of the workflow for a given instance. + + Args: + instance_id (str): The workflow instance ID. + message (Optional[Dict[str, Any]]): A message to append to the history. + final_output (Optional[str]): Final output of the workflow. + plan (Optional[List[Dict[str, Any]]]): The current plan snapshot. + wf_time (Optional[str]): Workflow time (ISO 8601 string). + """ + self._ensure_instance_row(instance_id) + + container = self._get_entry_container() + if not container or instance_id not in container: + logger.error( + "Cannot update state - instance %s not found in container", instance_id + ) + return + + entry = container[instance_id] + logger.info( + "update_workflow_state: entry type=%s, hasattr(plan)=%s", + type(entry).__name__, + hasattr(entry, "plan"), + ) + + if plan is not None: + logger.info( + "Updating plan: entry type=%s, plan length=%d", + type(entry).__name__, + len(plan), + ) + if hasattr(entry, "plan"): + logger.debug("Entry is a Pydantic model, setting plan attribute") + from dapr_agents.agents.orchestrators.llm.state import PlanStep + + entry.plan = [ + PlanStep(**step_dict) if isinstance(step_dict, dict) else step_dict + for step_dict in plan + ] # type: ignore[attr-defined] + else: + # Fallback for dict-based state + logger.info("Entry is a dict, setting plan key") + entry["plan"] = plan + + if message is not None: + msg = self._serialize_message(message) + logger.info( + "Processing message: entry type=%s, hasattr(messages)=%s", + type(entry).__name__, + hasattr(entry, "messages"), + ) + + if hasattr(entry, "messages"): + # Entry is a Pydantic model - need to convert dict to message model + if self._message_coercer: + msg_model = self._message_coercer(msg) + else: + msg_model = self._message_dict_to_message_model(msg) + logger.debug("Message model type: %s", type(msg_model).__name__) + entry.messages.append(msg_model) # type: ignore[attr-defined] + entry.last_message = msg_model # type: ignore[attr-defined] + else: + # Fallback for dict-based state + logger.debug("Entry is dict, appending message dict directly") + if "messages" not in entry: + entry["messages"] = [] + entry["messages"].append(msg) + entry["last_message"] = msg + + try: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "assistant": + self.memory.add_message( + AssistantMessage(content=content, name=msg.get("name")) + ) + elif role == "user": + self.memory.add_message( + UserMessage(content=content, name=msg.get("name")) + ) + except Exception: + logger.info("Failed to mirror message into memory.", exc_info=True) + + if final_output is not None: + end_time_value = self._coerce_datetime(wf_time) + + if hasattr(entry, "output"): + entry.output = final_output # type: ignore[attr-defined] + entry.end_time = end_time_value # type: ignore[attr-defined] + else: + # Dict-based state fallback - store as ISO string + entry["output"] = final_output + entry["end_time"] = end_time_value.isoformat() + + self.save_state() + + async def rollback_workflow_initialization(self, instance_id: str) -> None: + """Clears a partially-created plan for an instance.""" + try: + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if entry: + if hasattr(entry, "plan"): + entry.plan = [] # type: ignore[attr-defined] + else: + entry["plan"] = [] + self.save_state() + logger.debug("Rolled back workflow initialization for %s", instance_id) + except Exception: + logger.exception("Failed to rollback workflow initialization.") + + async def rollback_agent_trigger( + self, instance_id: str, step_id: int, substep_id: Optional[float] + ) -> None: + """Reverts a step from `in_progress` back to `not_started`.""" + from dapr_agents.workflow.orchestrators.llm.utils import find_step_in_plan + + try: + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if entry: + plan = ( + getattr(entry, "plan", None) + if hasattr(entry, "plan") + else entry.get("plan", []) + ) + step_entry = find_step_in_plan(plan, step_id, substep_id) + if step_entry and step_entry.get("status") == "in_progress": + step_entry["status"] = "not_started" + self.update_workflow_state(instance_id=instance_id, plan=plan) + logger.debug( + "Rolled back agent trigger for %s (%s/%s)", + instance_id, + step_id, + substep_id, + ) + except Exception: + logger.exception("Failed to rollback agent trigger.") + + async def rollback_agent_response_processing( + self, instance_id: str, agent: str, step_id: int, substep_id: Optional[float] + ) -> None: + """Undo the last task history entry and revert `completed` -> `in_progress` if needed.""" + from dapr_agents.workflow.orchestrators.llm.utils import find_step_in_plan + + try: + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if entry: + hist = ( + getattr(entry, "task_history", None) + if hasattr(entry, "task_history") + else entry.get("task_history", []) + ) + for i in range(len(hist) - 1, -1, -1): + t = hist[i] + if ( + t.get("agent") == agent + and t.get("step") == step_id + and t.get("substep") == substep_id + ): + hist.pop(i) + break + plan = ( + getattr(entry, "plan", None) + if hasattr(entry, "plan") + else entry.get("plan", []) + ) + step_entry = find_step_in_plan(plan, step_id, substep_id) + if step_entry and step_entry.get("status") == "completed": + step_entry["status"] = "in_progress" + self.update_workflow_state(instance_id=instance_id, plan=plan) + logger.debug( + "Rolled back response processing for agent=%s step=%s substep=%s", + agent, + step_id, + substep_id, + ) + except Exception: + logger.exception("Failed to rollback agent response processing.") + + async def rollback_workflow_finalization(self, instance_id: str) -> None: + """Clear output and end time if finalization failed.""" + try: + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if entry: + if hasattr(entry, "output"): + entry.output = None # type: ignore[attr-defined] + entry.end_time = None # type: ignore[attr-defined] + else: + entry["output"] = None + entry["end_time"] = None + self.save_state() + logger.info("Rolled back workflow finalization for %s", instance_id) + except Exception: + logger.exception("Failed to rollback workflow finalization.") + + async def ensure_workflow_state_consistency(self, instance_id: str) -> None: + """Ensure that the instance row exists and contains the required keys.""" + try: + self._ensure_instance_row(instance_id) + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if entry: + if hasattr(entry, "plan"): + # Pydantic model - fields should already exist + pass + else: + # Dict-based state + entry.setdefault("plan", []) + entry.setdefault("messages", []) + entry.setdefault("task_history", []) + self.save_state() + except Exception: + logger.exception("Failed to ensure workflow state consistency.") + + async def compensate_failed_activity( + self, + *, + instance_id: str, + failed_activity: str, + activity_context: Dict[str, Any], + ) -> Dict[str, Any]: + """Generic compensator for multi-step activities.""" + actions: List[str] = [] + try: + if failed_activity == "initialize_workflow_with_plan": + await self.rollback_workflow_initialization(instance_id) + actions.append("cleared_partial_plan") + elif failed_activity == "execute_agent_task_with_progress_tracking": + await self.rollback_agent_trigger( + instance_id, + activity_context.get("step_id"), + activity_context.get("substep_id"), + ) + actions.append("reverted_step_status") + elif failed_activity == "process_agent_response_with_progress": + await self.rollback_agent_response_processing( + instance_id, + activity_context.get("agent"), + activity_context.get("step_id"), + activity_context.get("substep_id"), + ) + actions.append("reverted_response_processing") + elif failed_activity == "finalize_workflow_with_summary": + await self.rollback_workflow_finalization(instance_id) + actions.append("reverted_finalization") + + await self.ensure_workflow_state_consistency(instance_id) + return { + "status": "compensated", + "failed_activity": failed_activity, + "compensation_actions": actions, + } + except Exception as exc: + logger.exception("Compensation failed for %s", failed_activity) + return { + "status": "compensation_failed", + "failed_activity": failed_activity, + "error": str(exc), + } + + async def execute_with_compensation( + self, activity_coro, *, activity_name: str, instance_id: str, **kwargs: Any + ) -> Any: + """Execute an async activity and auto-compensate on failure.""" + try: + return await activity_coro + except Exception as exc: + logger.error("Activity %s failed: %s", activity_name, exc) + ctx = {"instance_id": instance_id, "error": str(exc), **kwargs} + result = await self.compensate_failed_activity( + instance_id=instance_id, + failed_activity=activity_name, + activity_context=ctx, + ) + if result.get("status") != "compensated": + logger.error("Compensation failed: %s", result) + raise + + async def update_task_history_internal( + self, + *, + instance_id: str, + agent: str, + step: int, + substep: Optional[float], + results: Dict[str, Any], + plan: List[Dict[str, Any]], + ) -> None: + """ + Append a task result to workflow messages and task history, then persist plan pointer. + + Args: + instance_id: Workflow instance id. + agent: Agent producing the results. + step: Plan step id. + substep: Plan substep id (if any). + results: Message-like result dict from the agent. + plan: Current plan snapshot (dicts). + + Raises: + ValueError: If the instance row does not exist (unexpected). + """ + from dapr_agents.agents.orchestrators.llm.state import TaskResult + + logger.debug( + "Updating task history for %s at step %s, substep %s (instance=%s)", + agent, + step, + substep, + instance_id, + ) + + # Store the agent's response in the message history + self.update_workflow_state(instance_id=instance_id, message=results) + + # Retrieve workflow entry from container + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if not entry: + msg = f"No workflow entry for instance {instance_id}" + raise ValueError(msg) + + # Create a TaskResult object + task_result = TaskResult( + agent=agent, + step=step, + substep=substep, + result=results.get("content", ""), + ) + + # Append the result to task history + if hasattr(entry, "task_history"): + if not hasattr(entry.task_history, "append"): + entry.task_history = [] # type: ignore[attr-defined] + # Store TaskResult model instance directly instead of dict to avoid serialization warnings + entry.task_history.append(task_result) # type: ignore[attr-defined] + else: + entry.setdefault("task_history", []).append( + task_result.model_dump(mode="json") + ) + + # Get current plan from entry + current_plan = ( + getattr(entry, "plan", None) + if hasattr(entry, "plan") + else entry.get("plan", plan) + ) + + # Persist state with updated plan + self.update_workflow_state(instance_id=instance_id, plan=current_plan) + + async def trigger_agent_internal( + self, + *, + instance_id: str, + name: str, + step: int, + substep: Optional[float], + instruction: str, + plan: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + """ + Mark the referenced (step/substep) as in_progress, persist, and send a directed trigger. + + Args: + instance_id: Workflow instance id. + name: Target agent name to trigger. + step: Plan step id. + substep: Plan substep id (if any). + instruction: Instruction to send the target agent. + plan: Current plan snapshot (dicts). + + Returns: + The updated plan (with status changes applied). + + Raises: + ValueError: If the (step, substep) reference does not exist in the plan. + RuntimeError: If sending the trigger fails due to pub/sub issues. + """ + from dapr_agents.agents.schemas import TriggerAction + from dapr_agents.workflow.orchestrators.llm.utils import ( + find_step_in_plan, + update_step_statuses, + ) + from dapr_agents.workflow.utils.pubsub import send_message_to_agent + + logger.info( + "Triggering agent %s for step %s/%s (instance=%s)", + name, + step, + substep, + instance_id, + ) + + # Ensure the step or substep exists + step_entry = find_step_in_plan(plan, step, substep) + if not step_entry: + if substep is not None: + msg = f"Substep {substep} in Step {step} not found in the current plan." + else: + msg = f"Step {step} not found in the current plan." + raise ValueError(msg) + + # Mark step or substep as "in_progress" + step_entry["status"] = "in_progress" + logger.debug("Marked step %s, substep %s as 'in_progress'", step, substep) + + # Apply global status updates to maintain consistency + updated_plan = update_step_statuses(plan) + + # Save updated plan state + self.update_workflow_state(instance_id=instance_id, plan=updated_plan) + + # Get agents metadata for pub/sub + agents_metadata = self.list_team_agents( + include_self=False, team=self.effective_team() + ) + + # Send message to agent with specific task instruction + trigger = TriggerAction(task=instruction, workflow_instance_id=instance_id) + + async def _send() -> None: + await send_message_to_agent( + source=self.name, + target_agent=name, + message=trigger, + agents_metadata=agents_metadata, + ) + + try: + import asyncio + + loop = asyncio.get_running_loop() + await _send() if loop.is_running() else asyncio.run(_send()) + except Exception: + logger.exception("Failed to send trigger to agent %s", name) + raise + + return updated_plan + + async def update_plan_internal( + self, + *, + instance_id: str, + plan: List[Dict[str, Any]], + status_updates: Optional[List[Dict[str, Any]]] = None, + plan_updates: Optional[List[Dict[str, Any]]] = None, + ) -> List[Dict[str, Any]]: + """ + Apply status/structure updates to the plan and persist them. + + Args: + instance_id: Workflow instance id. + plan: Current plan snapshot. + status_updates: Optional status updates, each with `step`, optional `substep`, + and `status` fields. + plan_updates: Optional structural updates (see `restructure_plan` utility). + + Returns: + The updated plan after applying changes. + + Raises: + ValueError: If a referenced step/substep is not found in the plan. + """ + from dapr_agents.workflow.orchestrators.llm.utils import ( + find_step_in_plan, + restructure_plan, + update_step_statuses, + ) + + logger.debug("Updating plan for instance %s", instance_id) + + # Validate and apply status updates. + if status_updates: + logger.info("Applying %d status update(s) to plan", len(status_updates)) + for u in status_updates: + step_id = u["step"] + sub_id = u.get("substep") + new_status = u["status"] + + logger.debug( + "Updating step %s/%s to status '%s'", + step_id, + sub_id, + new_status, + ) + target = find_step_in_plan(plan, step_id, sub_id) + if not target: + msg = f"Step {step_id}/{sub_id} not present in plan." + logger.error(msg) + raise ValueError(msg) + + # Apply status update + target["status"] = new_status + logger.debug( + "Successfully updated status of step %s/%s to '%s'", + step_id, + sub_id, + new_status, + ) + + # Apply structural updates while preserving substeps unless explicitly overridden. + if plan_updates: + logger.debug("Applying %d plan restructuring update(s)", len(plan_updates)) + plan = restructure_plan(plan, plan_updates) + + # Apply global consistency checks for statuses + plan = update_step_statuses(plan) + + # Persist the updated plan + self.update_workflow_state(instance_id=instance_id, plan=plan) + + logger.debug("Plan successfully updated for instance %s", instance_id) + return plan + + async def finish_workflow_internal( + self, + *, + instance_id: str, + plan: List[Dict[str, Any]], + step: int, + substep: Optional[float], + verdict: str, + summary: str, + wf_time: Optional[str], + ) -> None: + """ + Finalize workflow by updating statuses (if completed) and storing the summary. + + Args: + instance_id: Workflow instance id. + plan: Current plan snapshot. + step: Completed step id. + substep: Completed substep id (if any). + verdict: Outcome category (e.g., "completed", "failed", "max_iterations_reached"). + summary: Final summary content to persist. + wf_time: Workflow timestamp (ISO 8601 string) to set as end time if provided. + + Returns: + None + + Raises: + ValueError: If a completed step/substep reference is invalid. + """ + from dapr_agents.agents.orchestrators.llm.utils import find_step_in_plan + + logger.debug( + "Finalizing workflow for instance %s with verdict '%s'", + instance_id, + verdict, + ) + + status_updates: List[Dict[str, Any]] = [] + + if verdict == "completed": + # Find and validate the step or substep + step_entry = find_step_in_plan(plan, step, substep) + if not step_entry: + msg = f"Step {step}/{substep} not found in plan; cannot mark as completed." + logger.error(msg) + raise ValueError(msg) + + # Mark the step or substep as completed + step_entry["status"] = "completed" + status_updates.append( + {"step": step, "substep": substep, "status": "completed"} + ) + logger.debug("Marked step %s/%s as completed", step, substep) + + # If it's a substep, check if all sibling substeps are completed + if substep is not None: + parent_step = find_step_in_plan( + plan, step + ) # Get parent without substep + if parent_step: + # Ensure "substeps" is a valid list before iteration + substeps = parent_step.get("substeps", []) + if not isinstance(substeps, list): + substeps = [] + + all_substeps_completed = all( + ss.get("status") == "completed" for ss in substeps + ) + if all_substeps_completed: + parent_step["status"] = "completed" + status_updates.append({"step": step, "status": "completed"}) + logger.debug( + "All substeps of step %s completed; marked parent as completed", + step, + ) + + # Apply updates in one call if any status changes were made + if status_updates: + await self.update_plan_internal( + instance_id=instance_id, + plan=plan, + status_updates=status_updates, + ) + + # Store the final summary and verdict in workflow state + self.update_workflow_state( + instance_id=instance_id, + final_output=summary, + wf_time=wf_time, + ) + + logger.info( + "Workflow %s finalized with verdict '%s'", + instance_id, + verdict, + ) diff --git a/dapr_agents/agents/orchestrators/llm/configs.py b/dapr_agents/agents/orchestrators/llm/configs.py new file mode 100644 index 00000000..74c96531 --- /dev/null +++ b/dapr_agents/agents/orchestrators/llm/configs.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Dict, MutableMapping, Optional, Type + +from pydantic import BaseModel + +from dapr_agents.agents.configs import ( + AgentStateConfig, + EntryContainerGetter, + EntryFactory, + MessageCoercer, +) +from dapr_agents.storage.daprstores.stateservice import StateStoreService + +from dapr_agents.agents.orchestrators.llm.state import ( + LLMWorkflowEntry, + LLMWorkflowMessage, + LLMWorkflowState, +) + + +# ---------- helpers (module-private) ---------- + + +def _utcnow() -> datetime: + """Timezone-aware now in UTC.""" + return datetime.now(timezone.utc) + + +def _maybe_aware(dt: Optional[datetime]) -> datetime: + """Coerce naive datetimes to UTC.""" + if dt is None: + return _utcnow() + return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc) + + +def _default_entry_factory( + *, + instance_id: str, + input_value: Any, + triggering_workflow_instance_id: Optional[str], + start_time: Optional[datetime], +) -> LLMWorkflowEntry: + """ + Create a baseline LLM workflow entry with non-null collections. + Forward-compatible: sets optional ids only if model defines them. + """ + ts = _maybe_aware(start_time) + + # Only populate optional ids if the model actually defines those fields + opt: Dict[str, Any] = {} + fields = getattr(LLMWorkflowEntry, "model_fields", {}) + if "workflow_instance_id" in fields: + opt["workflow_instance_id"] = instance_id + if "triggering_workflow_instance_id" in fields: + opt["triggering_workflow_instance_id"] = triggering_workflow_instance_id + + return LLMWorkflowEntry( + input=str(input_value or ""), + output=None, + start_time=ts, + end_time=None, + messages=[], + last_message=None, + plan=[], # never None + task_history=[], # never None + **opt, + ) + + +def _default_message_coercer(raw: Dict[str, Any]) -> LLMWorkflowMessage: + """ + Coerce raw dicts into the LLM message model. + - Whitelists known fields + - Defaults role/content + - Accepts either datetime or ISO8601 string timestamps + """ + allowed = {"role", "content", "name", "id", "timestamp"} + payload = {k: raw[k] for k in allowed if k in raw} + + # sensible defaults + payload.setdefault("role", "system") + + content = payload.get("content", "") + payload["content"] = content if isinstance(content, str) else str(content) + + # timestamp: accept str or datetime and coerce to aware datetime + ts = payload.get("timestamp") + if isinstance(ts, str): + try: + dt = datetime.fromisoformat(ts) + payload["timestamp"] = dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc) + except ValueError: + payload["timestamp"] = _utcnow() + elif isinstance(ts, datetime): + payload["timestamp"] = _maybe_aware(ts) + else: + payload["timestamp"] = _utcnow() + + return LLMWorkflowMessage(**payload) + + +def _default_entry_container_getter( + model: BaseModel, +) -> Optional[MutableMapping[str, Any]]: + """Return the container that maps instance_id -> entry (if present).""" + return getattr(model, "instances", None) + + +# ---------- public config ---------- + + +@dataclass +class LLMOrchestratorStateConfig(AgentStateConfig): + """ + Drop-in state config for LLM orchestrators. + + Defaults: + • state_model_cls = LLMWorkflowState + • message_model_cls = LLMWorkflowMessage + • entry_factory = _default_entry_factory (non-null plan/history; optional ids) + • message_coercer = _default_message_coercer (defensive + UTC) + • entry_container_getter = instances + + Only `store` is required: + LLMOrchestratorStateConfig(store=StateStoreService(...)) + """ + + # required + store: StateStoreService = None # type: ignore[assignment] + + # baked-in LLM defaults + state_model_cls: Type[BaseModel] = LLMWorkflowState + message_model_cls: Type[BaseModel] = LLMWorkflowMessage + entry_factory: Optional[EntryFactory] = _default_entry_factory + message_coercer: Optional[MessageCoercer] = _default_message_coercer + entry_container_getter: Optional[ + EntryContainerGetter + ] = _default_entry_container_getter + + def __post_init__(self) -> None: + """ + Ensure the base normalization runs with our defaults. + - Validates model classes + - Normalizes/validates `default_state` against LLMWorkflowState + """ + super().__post_init__() diff --git a/dapr_agents/agents/orchestrators/llm/orchestrator.py b/dapr_agents/agents/orchestrators/llm/orchestrator.py new file mode 100644 index 00000000..dc652f4d --- /dev/null +++ b/dapr_agents/agents/orchestrators/llm/orchestrator.py @@ -0,0 +1,668 @@ +from __future__ import annotations + +import json +import logging +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional + +import dapr.ext.workflow as wf +from durabletask import task as dt_task + +from dapr_agents.agents.orchestrators.llm.base import LLMOrchestratorBase +from dapr_agents.agents.schemas import ( + AgentTaskResponse, + BroadcastMessage, + TriggerAction, +) +from dapr_agents.workflow.decorators.routers import message_router +from dapr_agents.agents.orchestrators.llm.prompts import ( + NEXT_STEP_PROMPT, + PROGRESS_CHECK_PROMPT, + SUMMARY_GENERATION_PROMPT, + TASK_INITIAL_PROMPT, + TASK_PLANNING_PROMPT, +) +from dapr_agents.agents.orchestrators.llm.schemas import ( + IterablePlanStep, + NextStep, + ProgressCheckOutput, + schemas, +) +from dapr_agents.agents.orchestrators.llm.state import PlanStep +from dapr_agents.agents.orchestrators.llm.utils import find_step_in_plan +from dapr_agents.workflow.utils.pubsub import broadcast_message + +logger = logging.getLogger(__name__) + + +class LLMOrchestrator(LLMOrchestratorBase): + """ + LLM-driven orchestrator that dynamically selects the next agent based on context and plan. + Interacts with agents in a multi-step workflow, using an LLM to decide the next step, + validates and triggers agents, and handles responses. Ensures steps are executed in order, + checks for progress, and finalizes the workflow with a summary. + """ + + def __init__( + self, + *, + name: str = "LLMOrchestrator", + timeout_seconds: int = 60, + **kwargs: Any, + ) -> None: + """ + Initializes the orchestrator with the provided configuration parameters. + + Args: + name (str): Logical name of the orchestrator. + timeout_seconds (int): Timeout duration for awaiting agent responses (in seconds). + """ + super().__init__(name=name, **kwargs) + self.timeout = max(1, int(timeout_seconds)) + + def register_workflows(self, runtime: wf.WorkflowRuntime) -> None: + """ + Registers workflows and activities to the provided Dapr WorkflowRuntime. + """ + runtime.register_workflow(self.llm_orchestrator_workflow) + runtime.register_workflow(self.route_agent_response) + runtime.register_activity(self._broadcast_activity) + runtime.register_activity(self._validate_next_step) + runtime.register_activity(self._get_available_agents) + runtime.register_activity(self._initialize_workflow_with_plan) + runtime.register_activity(self._generate_next_step) + runtime.register_activity(self._execute_agent_task_with_progress_tracking) + runtime.register_activity(self._process_agent_response_with_progress) + runtime.register_activity(self._finalize_workflow_with_summary) + + @message_router(message_model=TriggerAction) + def llm_orchestrator_workflow( + self, ctx: wf.DaprWorkflowContext, message: Dict[str, Any] + ): + """ + Orchestrates the workflow, handling up to `self.execution_config.max_iterations` turns using an LLM to choose the next step/agent. + """ + task_text: Optional[str] = message.get("task") + parent_id: Optional[str] = message.get("workflow_instance_id") + instance_id = ctx.instance_id + final_summary: Optional[str] = None + + # Ensure the instance exists in the state model + self.ensure_instance_exists( + instance_id=instance_id, + input_value=task_text or "", + triggering_workflow_instance_id=parent_id, + time=ctx.current_utc_datetime, + ) + + for turn in range(1, self.execution_config.max_iterations + 1): + if not ctx.is_replaying: + logger.info( + "LLM turn %d/%d (instance=%s)", + turn, + self.execution_config.max_iterations, + instance_id, + ) + + # Discover available agents + agents = yield ctx.call_activity(self._get_available_agents) + + # Turn 1: initialize plan & broadcast + if turn == 1: + init = yield ctx.call_activity( + self._initialize_workflow_with_plan, + input={ + "instance_id": instance_id, + "task": task_text or "", + "agents": agents, + "wf_time": ctx.current_utc_datetime.isoformat(), + }, + ) + plan = init["plan"] + if not ctx.is_replaying: + logger.info( + "Received plan from initialization with %d steps", len(plan) + ) + initial_message = init["message"] + + # Broadcast the initial plan to all agents + if not ctx.is_replaying: + logger.info( + "Broadcasting initial plan with %d steps to all agents", + len(plan), + ) + yield ctx.call_activity( + self._broadcast_activity, + input={"message": initial_message}, + ) + if not ctx.is_replaying: + logger.info("Initial plan broadcast completed") + else: + plan = list( + self.state.get("instances", {}).get(instance_id, {}).get("plan", []) + ) + if not ctx.is_replaying: + logger.info( + "Loaded plan from state with %d steps (turn %d)", + len(plan), + turn, + ) + + # Fallback: if plan is empty/None, try reading from state + if not plan: + plan = list( + self.state.get("instances", {}).get(instance_id, {}).get("plan", []) + ) + if not ctx.is_replaying: + logger.warning( + "Plan was empty, fallback loaded %d steps from state", len(plan) + ) + + # Ask LLM for next step/agent + next_step = yield ctx.call_activity( + self._generate_next_step, + input={ + "task": task_text or "", + "agents": agents, + "plan": json.dumps( + self._convert_plan_objects_to_dicts(plan), indent=2 + ), + "next_step_schema": schemas.next_step, + }, + ) + + next_agent = next_step["next_agent"] + instruction = next_step["instruction"] + step_id = next_step.get("step") + substep_id = next_step.get("substep") + + # Validate the next step + is_valid = yield ctx.call_activity( + self._validate_next_step, + input={ + "instance_id": instance_id, + "plan": self._convert_plan_objects_to_dicts(plan), + "step": step_id, + "substep": substep_id, + }, + ) + + if is_valid: + result = yield ctx.call_activity( + self._execute_agent_task_with_progress_tracking, + input={ + "instance_id": instance_id, + "next_agent": next_agent, + "step_id": step_id, + "substep_id": substep_id, + "instruction": instruction, + "task": task_text or "", + "plan_objects": self._convert_plan_objects_to_dicts(plan), + }, + ) + plan = result["plan"] + + # Await response or timeout + event_task = ctx.wait_for_external_event("AgentTaskResponse") + timeout_task = ctx.create_timer(timedelta(seconds=self.timeout)) + winner = yield dt_task.when_any([event_task, timeout_task]) + + if winner == timeout_task: + if not ctx.is_replaying: + logger.warning( + "Turn %d timed out waiting for agent response (instance=%s)", + turn, + instance_id, + ) + task_results = { + "name": "timeout", + "content": "⏰ Timeout occurred. Continuing...", + } + else: + task_results = yield event_task + # Normalize + if hasattr(task_results, "model_dump"): + task_results = task_results.model_dump() + elif not isinstance(task_results, dict): + task_results = dict(task_results) + + if not ctx.is_replaying: + self.print_interaction( + source_agent_name=task_results.get("name", "agent"), + target_agent_name=self.name, + message=task_results.get("content", ""), + ) + processed = yield ctx.call_activity( + self._process_agent_response_with_progress, + input={ + "instance_id": instance_id, + "agent": next_agent, + "step_id": step_id, + "substep_id": substep_id, + "task_results": task_results, + "task": task_text or "", + "plan_objects": self._convert_plan_objects_to_dicts(plan), + }, + ) + plan = processed["plan"] + verdict = processed["verdict"] + else: + verdict = "continue" + task_results = { + "name": self.name, + "role": "user", + "content": f"Step {step_id}, substep {substep_id} not found. Adjusting workflow…", + } + + if verdict != "continue" or turn == self.execution_config.max_iterations: + final_summary = yield ctx.call_activity( + self._finalize_workflow_with_summary, + input={ + "instance_id": instance_id, + "task": task_text or "", + "verdict": verdict + if verdict != "continue" + else "max_iterations_reached", + "plan_objects": self._convert_plan_objects_to_dicts(plan), + "step_id": step_id, + "substep_id": substep_id, + "agent": next_agent if is_valid else self.name, + "result": task_results["content"], + "wf_time": ctx.current_utc_datetime.isoformat(), + }, + ) + if not ctx.is_replaying: + logger.info("Workflow %s finalized.", instance_id) + return final_summary + else: + task_text = task_results["content"] + + raise RuntimeError(f"{self.name} workflow {instance_id} exited without summary") + + @message_router(message_model=AgentTaskResponse) + def route_agent_response( + self, ctx: wf.DaprWorkflowContext, message: Dict[str, Any] + ) -> None: + """Route AgentTaskResponse messages into the running workflow.""" + instance_id = message.get("workflow_instance_id") + if not instance_id: + logger.error("AgentTaskResponse missing workflow_instance_id; ignoring.") + return + try: + self.raise_workflow_event( + instance_id=instance_id, + event_name="AgentTaskResponse", + data=message, + ) + except RuntimeError: + return + + # ------------------------------------------------------------------ + # Activities + # ------------------------------------------------------------------ + + def _broadcast_activity( + self, + ctx: wf.WorkflowActivityContext, + payload: Dict[str, Any], + ) -> None: + """Broadcast a message to all agents (if a broadcast topic is configured).""" + message = payload.get("message", {}) + if not isinstance(message, dict): + logger.warning( + "Skipping broadcast: payload is not a dict, type=%s", + type(message).__name__, + ) + return + if not self.broadcast_topic_name: + logger.warning( + "Skipping broadcast: no broadcast topic configured (broadcast_topic_name=%s)", + self.broadcast_topic_name, + ) + return + + logger.info( + "Broadcasting message from %s to topic %s", + self.name, + self.broadcast_topic_name, + ) + + try: + agents_metadata = self.list_team_agents( + include_self=False, team=self.effective_team() + ) + logger.info( + "Found %d agents to broadcast to: %s", + len(agents_metadata), + list(agents_metadata.keys()), + ) + except Exception: + logger.exception("Unable to load agents metadata; broadcast aborted.") + return + + message["role"] = message.get("role", "user") + message["name"] = self.name + broadcast_payload = BroadcastMessage(**message) + + async def _broadcast() -> None: + await broadcast_message( + message=broadcast_payload, + broadcast_topic=self.broadcast_topic_name, # type: ignore[union-attr] + message_bus=self.message_bus_name, # type: ignore[union-attr] + source=self.name, + agents_metadata=agents_metadata, + ) + + try: + self._run_asyncio_task(_broadcast()) + except Exception: # noqa: BLE001 + logger.exception("Failed to publish broadcast message.") + + def _get_available_agents(self, ctx: wf.WorkflowActivityContext) -> str: + """ + Return a human-formatted list of available agents (excluding orchestrators). + + Args: + ctx: The Dapr Workflow context. + + Returns: + A formatted string listing available agents. + """ + agents_metadata = self.list_team_agents( + include_self=False, team=self.effective_team() + ) + if not agents_metadata: + return "No available agents to assign tasks." + lines = [] + for name, meta in agents_metadata.items(): + role = meta.get("role", "Unknown role") + goal = meta.get("goal", "Unknown") + lines.append(f"- {name}: {role} (Goal: {goal})") + return "\n".join(lines) + + def _initialize_workflow_with_plan( + self, ctx: wf.WorkflowActivityContext, payload: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Generate (or rehydrate) a plan. + + Args: + ctx: The Dapr Workflow context. + payload: The input containing instance ID, task details, available agents, and workflow time. + + Returns: + A dictionary containing the plan and message to broadcast. + """ + instance_id = payload["instance_id"] + task = payload["task"] + agents = payload["agents"] + wf_time = payload["wf_time"] + + # Use flexible container accessor (supports custom state layouts) + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + + # Check if THIS instance already has a plan (from a previous turn/replay) + plan_dicts: List[Dict[str, Any]] + existing_plan = getattr(entry, "plan", None) if entry else None + + if existing_plan: + logger.info( + "Reusing existing plan with %d steps from instance %s", + len(existing_plan), + instance_id, + ) + # Convert Plan Step objects to dicts (existing_plan could be List[PlanStep] or List[dict]) + plan_dicts = self._convert_plan_objects_to_dicts(existing_plan) + else: + logger.info("Generating new plan for task: %s", task[:100]) + response = self.llm.generate( + messages=[ + { + "role": "user", + "content": TASK_PLANNING_PROMPT.format( + task=task, agents=agents, plan_schema=schemas.plan + ), + } + ], + response_format=IterablePlanStep, + ) + response_dict = response.model_dump() + plan_objects = [PlanStep(**d) for d in response_dict.get("objects", [])] + plan_dicts = self._convert_plan_objects_to_dicts(plan_objects) + logger.info("Generated new plan with %d steps", len(plan_dicts)) + logger.debug("Plan details: %s", json.dumps(plan_dicts, indent=2)) + + # Persist and broadcast + self.update_workflow_state( + instance_id=instance_id, plan=plan_dicts, wf_time=wf_time + ) + logger.info( + "Persisted plan with %d steps to state for instance %s", + len(plan_dicts), + instance_id, + ) + + formatted_message = TASK_INITIAL_PROMPT.format( + task=task, agents=agents, plan=json.dumps(plan_dicts, indent=2) + ) + initial_message = {"role": "user", "content": formatted_message} + logger.info( + "Returning plan with %d steps from initialization activity", len(plan_dicts) + ) + return {"plan": plan_dicts, "message": initial_message} + + def _generate_next_step( + self, ctx: wf.WorkflowActivityContext, payload: Dict[str, Any] + ) -> Dict[str, Any]: + """Ask the LLM which agent acts next, with instruction and step ids.""" + prompt = NEXT_STEP_PROMPT.format( + task=payload["task"], + agents=payload["agents"], + plan=payload["plan"], + next_step_schema=payload["next_step_schema"], + ) + resp = self.llm.generate( + messages=[{"role": "user", "content": prompt}], + response_format=NextStep, + structured_mode="json", + ) + if hasattr(resp, "choices") and resp.choices: + data = resp.choices[0].message.content + return dict(NextStep(**json.loads(data))) + # Fallback if your LLM client returns a pydantic instance directly + if isinstance(resp, NextStep): + return resp.model_dump() + return dict(resp) + + def _validate_next_step( + self, ctx: wf.WorkflowActivityContext, payload: Dict[str, Any] + ) -> bool: + """Return True if (step, substep) exists in the plan.""" + step = payload["step"] + substep = payload.get("substep") + plan = payload["plan"] + ok = bool(find_step_in_plan(plan, step, substep)) + if not ok: + logger.error( + "Step %s/%s not in plan for instance %s", + step, + substep, + payload.get("instance_id"), + ) + return ok + + def _execute_agent_task_with_progress_tracking( + self, ctx: wf.WorkflowActivityContext, payload: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Mark step in_progress, persist, and trigger the agent with an InternalTriggerAction. + """ + + async def _execute() -> List[Dict[str, Any]]: + return await self.execute_with_compensation( + self.trigger_agent_internal( + instance_id=payload["instance_id"], + name=payload["next_agent"], + step=payload["step_id"], + substep=payload["substep_id"], + instruction=payload["instruction"], + plan=list(payload["plan_objects"]), + ), + activity_name="execute_agent_task_with_progress_tracking", + instance_id=payload["instance_id"], + step_id=payload["step_id"], + substep_id=payload["substep_id"], + ) + + updated_plan = self._run_asyncio_task(_execute()) + return {"plan": updated_plan, "status": "agent_triggered"} + + def _process_agent_response_with_progress( + self, ctx: wf.WorkflowActivityContext, payload: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Append the agent's response, ask LLM for progress verdict/updates, and persist updated plan. + """ + instance_id = payload["instance_id"] + agent = payload["agent"] + step_id = payload["step_id"] + substep_id = payload["substep_id"] + task = payload["task"] + plan_objects = list(payload["plan_objects"]) + task_results = dict(payload["task_results"]) + + async def _process() -> Dict[str, Any]: + try: + await self.update_task_history_internal( + instance_id=instance_id, + agent=agent, + step=step_id, + substep=substep_id, + results=task_results, + plan=plan_objects, + ) + + progress_prompt = PROGRESS_CHECK_PROMPT.format( + task=task, + plan=json.dumps( + self._convert_plan_objects_to_dicts(plan_objects), indent=2 + ), + step=step_id, + substep=substep_id if substep_id is not None else "N/A", + results=task_results["content"], + progress_check_schema=schemas.progress_check, + ) + progress_resp = self.llm.generate( + messages=[{"role": "user", "content": progress_prompt}], + response_format=ProgressCheckOutput, + structured_mode="json", + ) + if hasattr(progress_resp, "choices") and progress_resp.choices: + data = progress_resp.choices[0].message.content + progress = ProgressCheckOutput(**json.loads(data)) + elif isinstance(progress_resp, ProgressCheckOutput): + progress = progress_resp + else: + # Best-effort parse + progress = ProgressCheckOutput( + **(progress_resp if isinstance(progress_resp, dict) else {}) + ) + + status_updates = [ + (u.model_dump() if hasattr(u, "model_dump") else u) + for u in (progress.plan_status_update or []) + ] + plan_updates = [ + (u.model_dump() if hasattr(u, "model_dump") else u) + for u in (progress.plan_restructure or []) + ] + + if status_updates or plan_updates: + updated_plan = await self.update_plan_internal( + instance_id=instance_id, + plan=plan_objects, + status_updates=status_updates, + plan_updates=plan_updates, + ) + else: + updated_plan = plan_objects + + return { + "plan": updated_plan, + "verdict": progress.verdict, + "status_updates": status_updates, + "plan_updates": plan_updates, + "status": "success", + } + + except Exception as exc: # noqa: BLE001 + logger.error("Failed to process agent response: %s", exc) + await self.rollback_agent_response_processing( + instance_id, agent, step_id, substep_id + ) + self.update_workflow_state( + instance_id=instance_id, + message={ + "name": agent, + "role": "system", + "content": f"Failed to process agent response: {exc}", + "step": step_id, + "substep": substep_id, + "timestamp": datetime.now(timezone.utc).isoformat(), + }, + ) + return { + "plan": plan_objects, + "verdict": "failed", + "status_updates": [], + "plan_updates": [], + "status": "failed", + } + + return self._run_asyncio_task(_process()) + + def _finalize_workflow_with_summary( + self, ctx: wf.WorkflowActivityContext, payload: Dict[str, Any] + ) -> str: + """ + Ask the LLM for a final summary and persist the finale (plan status + output + end time). + """ + instance_id = payload["instance_id"] + + async def _finalize() -> str: + prompt = SUMMARY_GENERATION_PROMPT.format( + task=payload["task"], + verdict=payload["verdict"], + plan=json.dumps(payload["plan_objects"], indent=2), + step=payload["step_id"], + substep=payload["substep_id"] + if payload["substep_id"] is not None + else "N/A", + agent=payload["agent"], + result=payload["result"], + ) + summary_resp = self.llm.generate( + messages=[{"role": "user", "content": prompt}] + ) + if hasattr(summary_resp, "choices") and summary_resp.choices: + summary = summary_resp.choices[0].message.content + elif hasattr(summary_resp, "results") and summary_resp.results: + # Handle LLMChatResponse with results list + summary = summary_resp.results[0].message.content + else: + # Fallback: try to extract content from the response object + summary = str(summary_resp) + + await self.finish_workflow_internal( + instance_id=instance_id, + plan=list(payload["plan_objects"]), + step=payload["step_id"], + substep=payload["substep_id"], + verdict=payload["verdict"], + summary=summary, + wf_time=payload["wf_time"], + ) + return summary + + return self._run_asyncio_task(_finalize()) diff --git a/dapr_agents/workflow/orchestrators/llm/prompts.py b/dapr_agents/agents/orchestrators/llm/prompts.py similarity index 100% rename from dapr_agents/workflow/orchestrators/llm/prompts.py rename to dapr_agents/agents/orchestrators/llm/prompts.py diff --git a/dapr_agents/workflow/orchestrators/llm/schemas.py b/dapr_agents/agents/orchestrators/llm/schemas.py similarity index 69% rename from dapr_agents/workflow/orchestrators/llm/schemas.py rename to dapr_agents/agents/orchestrators/llm/schemas.py index 7cd57415..e6c9289d 100644 --- a/dapr_agents/workflow/orchestrators/llm/schemas.py +++ b/dapr_agents/agents/orchestrators/llm/schemas.py @@ -3,8 +3,7 @@ from typing import List, Optional, Literal from pydantic import BaseModel, Field -from dapr_agents.workflow.orchestrators.llm.state import PlanStep -from dapr_agents.types.message import BaseMessage +from dapr_agents.agents.orchestrators.llm.state import PlanStep from dapr_agents.llm.utils import StructureHandler @@ -17,51 +16,6 @@ class IterablePlanStep(BaseModel): objects: List[PlanStep] = Field(description="A list of PlanStep objects") -class BroadcastMessage(BaseMessage): - """ - Represents a broadcast message from an agent. - """ - - -class AgentTaskResponse(BaseMessage): - """ - Represents a response message from an agent after completing a task. - """ - - workflow_instance_id: Optional[str] = Field( - default=None, description="Dapr workflow instance id from source if available" - ) - - -class TriggerAction(BaseModel): - """ - Represents a message used to trigger an agent's activity within the workflow by an end user. - """ - - task: Optional[str] = Field( - None, - description="The specific task to execute. If not provided, the agent can act based on its memory or predefined behavior.", - ) - workflow_instance_id: Optional[str] = Field( - default=None, description="Dapr workflow instance id from source if available" - ) - - -class InternalTriggerAction(BaseModel): - """ - Represents an internal message used by orchestrators to trigger agents. - This prevents self-triggering loops. - """ - - task: Optional[str] = Field( - None, - description="The specific task to execute. If not provided, the agent can act based on its memory or predefined behavior.", - ) - workflow_instance_id: Optional[str] = Field( - default=None, description="Dapr workflow instance id from source if available" - ) - - class NextStep(BaseModel): """ Represents the next step in a workflow, including the next agent to respond, diff --git a/dapr_agents/workflow/orchestrators/llm/state.py b/dapr_agents/agents/orchestrators/llm/state.py similarity index 96% rename from dapr_agents/workflow/orchestrators/llm/state.py rename to dapr_agents/agents/orchestrators/llm/state.py index a9154eac..8a9160cb 100644 --- a/dapr_agents/workflow/orchestrators/llm/state.py +++ b/dapr_agents/agents/orchestrators/llm/state.py @@ -87,8 +87,8 @@ class LLMWorkflowEntry(BaseModel): last_message: Optional[LLMWorkflowMessage] = Field( default=None, description="Last processed message in the workflow" ) - plan: Optional[List[PlanStep]] = Field( - None, description="Structured execution plan for the workflow." + plan: List[PlanStep] = Field( + default_factory=list, description="Structured execution plan for the workflow." ) task_history: List[TaskResult] = Field( default_factory=list, diff --git a/dapr_agents/workflow/orchestrators/llm/utils.py b/dapr_agents/agents/orchestrators/llm/utils.py similarity index 100% rename from dapr_agents/workflow/orchestrators/llm/utils.py rename to dapr_agents/agents/orchestrators/llm/utils.py diff --git a/dapr_agents/agents/orchestrators/random.py b/dapr_agents/agents/orchestrators/random.py new file mode 100644 index 00000000..09d6a09f --- /dev/null +++ b/dapr_agents/agents/orchestrators/random.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +import logging +import random +from datetime import timedelta +from typing import Any, Dict, Optional + +import dapr.ext.workflow as wf +from durabletask import task as dt_task + +from dapr_agents.agents.configs import ( + AgentPubSubConfig, + AgentRegistryConfig, + AgentStateConfig, + AgentExecutionConfig, +) +from dapr_agents.agents.orchestrators.base import OrchestratorBase +from dapr_agents.agents.schemas import ( + AgentTaskResponse, + BroadcastMessage, + TriggerAction, +) +from dapr_agents.workflow.decorators.routers import message_router +from dapr_agents.workflow.utils.pubsub import broadcast_message, send_message_to_agent + +logger = logging.getLogger(__name__) + + +class RandomOrchestrator(OrchestratorBase): + """ + Workflow-native orchestrator that randomly selects an agent each turn. + + Flow: + 1) Optionally broadcast the initial task to all agents. + 2) For up to ``max_iterations``: + - Pick a random agent (avoid the most recent speaker when possible), + - Trigger it, + - Wait for a response event or a timeout, + - Use the response content as the next turn's task. + 3) Return the final content. + """ + + def __init__( + self, + *, + name: str = "RandomOrchestrator", + pubsub_config: Optional[AgentPubSubConfig] = None, + state_config: Optional[AgentStateConfig] = None, + registry_config: Optional[AgentRegistryConfig] = None, + agent_metadata: Optional[Dict[str, Any]] = None, + execution_config: Optional[AgentExecutionConfig] = None, + timeout_seconds: int = 60, + runtime: Optional[wf.WorkflowRuntime] = None, + ) -> None: + super().__init__( + name=name, + pubsub_config=pubsub_config, + state_config=state_config, + registry_config=registry_config, + agent_metadata=agent_metadata, + runtime=runtime, + ) + self.execution_config = execution_config or AgentExecutionConfig() + self.timeout = max(1, timeout_seconds) + self.current_speaker: Optional[str] = None + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + def register_workflows(self, runtime: wf.WorkflowRuntime) -> None: + """Register workflows and activities for the random orchestrator.""" + runtime.register_workflow(self.random_workflow) + runtime.register_workflow(self.process_agent_response) + runtime.register_activity(self._process_input_activity) + runtime.register_activity(self._broadcast_activity) + runtime.register_activity(self._select_random_speaker_activity) + runtime.register_activity(self._trigger_agent_activity) + + # ------------------------------------------------------------------ + # Workflows + # ------------------------------------------------------------------ + @message_router(message_model=TriggerAction) + def random_workflow(self, ctx: wf.DaprWorkflowContext, message: dict): + """ + Entry workflow that drives random speaker selection and response handling. + + Args: + ctx: Dapr workflow context. + message: Input payload (expects optional ``task``). + + Returns: + The final message content (str). + + Raises: + RuntimeError: If no final output is produced by the end. + """ + task = message.get("task") + instance_id = ctx.instance_id + final_output: Optional[str] = None + + for turn in range(1, self.execution_config.max_iterations + 1): + if not ctx.is_replaying: + logger.info( + "Random workflow turn %d/%d (instance=%s)", + turn, + self.execution_config.max_iterations, + instance_id, + ) + + # On first turn, normalize + broadcast initial task + if turn == 1 and task is not None: + initial_message = yield ctx.call_activity( + self._process_input_activity, + input={"task": task}, + ) + if not ctx.is_replaying: + # Console UX: show "user -> orchestrator" + self.print_interaction( + source_agent_name=initial_message.get("name", "user"), + target_agent_name=self.name, + message=initial_message.get("content", ""), + ) + yield ctx.call_activity( + self._broadcast_activity, + input={"message": initial_message}, + ) + + # Select agent + selected_agent = yield ctx.call_activity( + self._select_random_speaker_activity, + input={}, + ) + if not ctx.is_replaying: + logger.info("Selected '%s' for turn %d", selected_agent, turn) + + # Trigger agent + if not ctx.is_replaying: + self.print_interaction( + source_agent_name=self.name, + target_agent_name=selected_agent, + message="TriggerAction", + ) + yield ctx.call_activity( + self._trigger_agent_activity, + input={"name": selected_agent, "instance_id": instance_id}, + ) + + # Await response or timeout + event_task = ctx.wait_for_external_event("AgentTaskResponse") + timeout_task = ctx.create_timer(timedelta(seconds=self.timeout)) + winner = yield dt_task.when_any([event_task, timeout_task]) + + if winner == timeout_task: + if not ctx.is_replaying: + logger.warning( + "Turn %d timed out waiting for agent response (instance=%s)", + turn, + instance_id, + ) + result = { + "name": "timeout", + "content": "⏰ Timeout occurred. Continuing...", + } + else: + result = yield event_task + # Normalize + if hasattr(result, "model_dump"): + result = result.model_dump() + elif not isinstance(result, dict): + result = dict(result) + + if not ctx.is_replaying: + self.print_interaction( + source_agent_name=result.get("name", "agent"), + target_agent_name=self.name, + message=result.get("content", ""), + ) + + if turn == self.execution_config.max_iterations: + final_output = result.get("content", "") + break + + task = result.get("content") + + if final_output is None: + raise RuntimeError( + "Random workflow completed without producing a final output." + ) + + return final_output + + @message_router(message_model=AgentTaskResponse) + def process_agent_response( + self, ctx: wf.DaprWorkflowContext, message: Dict[str, Any] + ) -> None: + """ + Route agent responses back into the workflow via an external event. + + Args: + ctx: Dapr workflow context. + message: Response payload from the agent (must include workflow_instance_id). + """ + instance_id = (message or {}).get("workflow_instance_id") + if not instance_id: + logger.error("AgentTaskResponse missing workflow_instance_id; ignoring.") + return + + try: + self.raise_workflow_event( + instance_id=instance_id, + event_name="AgentTaskResponse", + data=message, + ) + except RuntimeError: + # Already logged in helper; keep workflow alive. + return + + # ------------------------------------------------------------------ + # Activities + # ------------------------------------------------------------------ + def _process_input_activity( + self, + ctx: wf.WorkflowActivityContext, + payload: Dict[str, Any], + ) -> Dict[str, Any]: + """Normalize the initial task into a user message dict.""" + task = payload.get("task") or "" + return {"role": "user", "name": "user", "content": task} + + def _broadcast_activity( + self, + ctx: wf.WorkflowActivityContext, + payload: Dict[str, Any], + ) -> None: + """Broadcast a message to all agents (if a broadcast topic is configured).""" + message = payload.get("message", {}) + if not isinstance(message, dict): + logger.debug("Skipping broadcast: payload is not a dict.") + return + if not self.broadcast_topic_name: + logger.debug("Skipping broadcast: no broadcast topic configured.") + return + + try: + agents_metadata = self.list_team_agents( + include_self=False, team=self.effective_team() + ) + except Exception: + logger.exception("Unable to load agents metadata; broadcast aborted.") + return + + message["role"] = message.get("role", "user") + message["name"] = self.name + broadcast_payload = BroadcastMessage(**message) + + async def _broadcast() -> None: + await broadcast_message( + message=broadcast_payload, + broadcast_topic=self.broadcast_topic_name, # type: ignore[union-attr] + message_bus=self.message_bus_name, # type: ignore[union-attr] + source=self.name, + agents_metadata=agents_metadata, + ) + + try: + self._run_asyncio_task(_broadcast()) + except Exception: # noqa: BLE001 + logger.exception("Failed to publish broadcast message.") + + def _select_random_speaker_activity( + self, + ctx: wf.WorkflowActivityContext, + ) -> str: + """Pick a random agent from the registry, avoiding the most recent speaker when possible.""" + try: + agents_metadata = self.list_team_agents( + include_self=False, team=self.effective_team() + ) + except Exception: + logger.exception("Unable to load agents metadata; broadcast aborted.") + return + + if not agents_metadata: + raise ValueError("No agents available for selection.") + + names = list(agents_metadata.keys()) + if len(names) > 1 and self.current_speaker in names: + names.remove(self.current_speaker) + + choice = random.choice(names) + self.current_speaker = choice + return choice + + def _trigger_agent_activity( + self, + ctx: wf.WorkflowActivityContext, + payload: Dict[str, Any], + ) -> None: + """Send a TriggerAction to a specific agent via pub/sub.""" + name = payload.get("name") + instance_id = payload.get("instance_id") + if not name or not instance_id: + logger.debug("Trigger activity missing agent name or instance id.") + return + + trigger = TriggerAction(workflow_instance_id=instance_id) + + try: + agents_metadata = self.list_team_agents( + include_self=False, team=self.effective_team() + ) + except Exception: + logger.exception("Unable to load agents metadata; broadcast aborted.") + return + + async def _trigger() -> None: + await send_message_to_agent( + source=self.name, + target_agent=name, + message=trigger, + agents_metadata=agents_metadata, + ) + + try: + self._run_asyncio_task(_trigger()) + except Exception: # noqa: BLE001 + logger.exception("Failed to trigger agent %s", name) diff --git a/dapr_agents/agents/orchestrators/roundrobin.py b/dapr_agents/agents/orchestrators/roundrobin.py new file mode 100644 index 00000000..776a9388 --- /dev/null +++ b/dapr_agents/agents/orchestrators/roundrobin.py @@ -0,0 +1,305 @@ +from __future__ import annotations + +import logging +from datetime import timedelta +from typing import Any, Dict, Optional + +import dapr.ext.workflow as wf +from durabletask import task as dt_task + +from dapr_agents.agents.configs import ( + AgentPubSubConfig, + AgentRegistryConfig, + AgentStateConfig, + AgentExecutionConfig, +) +from dapr_agents.agents.orchestrators.base import OrchestratorBase +from dapr_agents.agents.schemas import ( + AgentTaskResponse, + BroadcastMessage, + TriggerAction, +) +from dapr_agents.workflow.decorators.routers import message_router +from dapr_agents.workflow.utils.pubsub import broadcast_message, send_message_to_agent + +logger = logging.getLogger(__name__) + + +class RoundRobinOrchestrator(OrchestratorBase): + """ + Orchestrator that selects agents in strict round-robin order each turn. + - Turn 1 broadcasts the initial task (if provided). + - Each turn triggers the next agent in a deterministic order. + - Waits for response or times out. + - Feeds the response content into the next turn; returns last content. + """ + + def __init__( + self, + *, + name: str = "RoundRobinOrchestrator", + pubsub_config: Optional[AgentPubSubConfig] = None, + state_config: Optional[AgentStateConfig] = None, + registry_config: Optional[AgentRegistryConfig] = None, + execution_config: Optional[AgentExecutionConfig] = None, + agent_metadata: Optional[Dict[str, Any]] = None, + timeout_seconds: int = 60, + runtime: Optional[wf.WorkflowRuntime] = None, + ) -> None: + super().__init__( + name=name, + pubsub_config=pubsub_config, + state_config=state_config, + registry_config=registry_config, + execution_config=execution_config, + agent_metadata=agent_metadata, + runtime=runtime, + ) + self.timeout = max(1, timeout_seconds) + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + def register_workflows(self, runtime: wf.WorkflowRuntime) -> None: + """Register workflows and activities.""" + runtime.register_workflow(self.round_robin_workflow) + runtime.register_workflow(self.process_agent_response) + runtime.register_activity(self._process_input_activity) + runtime.register_activity(self._broadcast_activity) + runtime.register_activity(self._select_next_speaker_activity) + runtime.register_activity(self._trigger_agent_activity) + + # ------------------------------------------------------------------ + # Workflows + # ------------------------------------------------------------------ + @message_router(message_model=TriggerAction) + def round_robin_workflow(self, ctx: wf.DaprWorkflowContext, message: dict): + """ + Drive round-robin loop for up to `self.execution_config.max_iterations`. + """ + task = message.get("task") + instance_id = ctx.instance_id + final_output: Optional[str] = None + + for turn in range(1, self.execution_config.max_iterations + 1): + if not ctx.is_replaying: + logger.info( + "Round-robin turn %d/%d (instance=%s)", + turn, + self.execution_config.max_iterations, + instance_id, + ) + + # Turn 1: normalize + broadcast initial task + if turn == 1 and task is not None: + initial_message = yield ctx.call_activity( + self._process_input_activity, + input={"task": task}, + ) + if not ctx.is_replaying: + self.print_interaction( + source_agent_name=initial_message.get("name", "user"), + target_agent_name=self.name, + message=initial_message.get("content", ""), + ) + yield ctx.call_activity( + self._broadcast_activity, + input={"message": initial_message}, + ) + + # Select next agent (round-robin) + selected_agent = yield ctx.call_activity( + self._select_next_speaker_activity, + input={"turn": turn}, + ) + if not ctx.is_replaying: + logger.info("Selected '%s' for turn %d", selected_agent, turn) + + # Trigger agent + if not ctx.is_replaying: + self.print_interaction( + source_agent_name=self.name, + target_agent_name=selected_agent, + message="TriggerAction", + ) + yield ctx.call_activity( + self._trigger_agent_activity, + input={"name": selected_agent, "instance_id": instance_id}, + ) + + # Await response or timeout + event_task = ctx.wait_for_external_event("AgentTaskResponse") + timeout_task = ctx.create_timer(timedelta(seconds=self.timeout)) + winner = yield dt_task.when_any([event_task, timeout_task]) + + if winner == timeout_task: + if not ctx.is_replaying: + logger.warning( + "Turn %d timed out waiting for agent response (instance=%s)", + turn, + instance_id, + ) + result = { + "name": "timeout", + "content": "⏰ Timeout occurred. Continuing...", + } + else: + result = yield event_task + # Normalize + if hasattr(result, "model_dump"): + result = result.model_dump() + elif not isinstance(result, dict): + result = dict(result) + + if not ctx.is_replaying: + self.print_interaction( + source_agent_name=result.get("name", "agent"), + target_agent_name=self.name, + message=result.get("content", ""), + ) + + if turn == self.execution_config.max_iterations: + final_output = result.get("content", "") + break + + task = result.get("content") + + if final_output is None: + raise RuntimeError( + "RoundRobin workflow completed without producing a final output." + ) + return final_output + + @message_router(message_model=AgentTaskResponse) + def process_agent_response( + self, ctx: wf.DaprWorkflowContext, message: Dict[str, Any] + ) -> None: + """ + Route agent responses back into the workflow via an external event. + """ + instance_id = (message or {}).get("workflow_instance_id") + if not instance_id: + logger.error("AgentTaskResponse missing workflow_instance_id; ignoring.") + return + + try: + self.raise_workflow_event( + instance_id=instance_id, + event_name="AgentTaskResponse", + data=message, + ) + except RuntimeError: + return # already logged + + # ------------------------------------------------------------------ + # Activities + # ------------------------------------------------------------------ + def _process_input_activity( + self, + ctx: wf.WorkflowActivityContext, + payload: Dict[str, Any], + ) -> Dict[str, Any]: + """Normalize the initial task into a user message dict.""" + task = payload.get("task") or "" + return {"role": "user", "name": "user", "content": task} + + def _broadcast_activity( + self, + ctx: wf.WorkflowActivityContext, + payload: Dict[str, Any], + ) -> None: + """Broadcast a message to all agents (if a broadcast topic is configured).""" + message = payload.get("message", {}) + if not isinstance(message, dict): + logger.debug("Skipping broadcast: payload is not a dict.") + return + if not self.broadcast_topic_name: + logger.debug("Skipping broadcast: no broadcast topic configured.") + return + + try: + agents_metadata = self.list_team_agents( + include_self=False, team=self.effective_team() + ) + except Exception: + logger.exception("Unable to load agents metadata; broadcast aborted.") + return + + message["role"] = message.get("role", "user") + message["name"] = self.name + broadcast_payload = BroadcastMessage(**message) + + async def _broadcast() -> None: + await broadcast_message( + message=broadcast_payload, + broadcast_topic=self.broadcast_topic_name, # type: ignore[union-attr] + message_bus=self.message_bus_name, # type: ignore[union-attr] + source=self.name, + agents_metadata=agents_metadata, + ) + + try: + self._run_asyncio_task(_broadcast()) + except Exception: + logger.exception("Failed to publish broadcast message.") + + def _select_next_speaker_activity( + self, + ctx: wf.WorkflowActivityContext, + payload: Dict[str, Any], + ) -> str: + """ + Deterministic round-robin selection based on turn number. + """ + turn = int(payload.get("turn", 1)) + try: + agents_metadata = self.list_team_agents( + include_self=False, team=self.effective_team() + ) + except Exception as exc: + logger.error("Unable to load agents metadata: %s", exc) + raise + + if not agents_metadata: + raise ValueError("No agents available for selection.") + + # Deterministic ordering across runs/replay + names = sorted(agents_metadata.keys()) + idx = (max(1, turn) - 1) % len(names) + return names[idx] + + def _trigger_agent_activity( + self, + ctx: wf.WorkflowActivityContext, + payload: Dict[str, Any], + ) -> None: + """Send a TriggerAction to a specific agent via pub/sub.""" + name = payload.get("name") + instance_id = payload.get("instance_id") + if not name or not instance_id: + logger.debug("Trigger activity missing agent name or instance id.") + return + + trigger = TriggerAction(workflow_instance_id=instance_id) + + try: + agents_metadata = self.list_team_agents( + include_self=False, + team=self.effective_team(), + ) + except Exception: + logger.exception("Unable to load agents metadata for trigger.") + return + + async def _trigger() -> None: + await send_message_to_agent( + source=self.name, + target_agent=name, + message=trigger, + agents_metadata=agents_metadata, + ) + + try: + self._run_asyncio_task(_trigger()) + except Exception: + logger.exception("Failed to trigger agent %s", name) diff --git a/dapr_agents/agents/prompting.py b/dapr_agents/agents/prompting.py new file mode 100644 index 00000000..fbbf70e7 --- /dev/null +++ b/dapr_agents/agents/prompting.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import logging +import re +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +from dapr_agents.agents.configs import AgentProfileConfig, PromptSection +from dapr_agents.agents.utils.text_printer import ColorTextFormatter +from dapr_agents.prompt.base import PromptTemplateBase +from dapr_agents.prompt.chat import ChatPromptTemplate +from dapr_agents.types import MessagePlaceHolder + +logger = logging.getLogger(__name__) + +_JINJA_PLACEHOLDER_PATTERN = re.compile(r"(? str: + """Convert single-brace placeholders to Jinja without touching existing Jinja blocks.""" + return _JINJA_PLACEHOLDER_PATTERN.sub(r"{{\1}}", text) + + +@dataclass +class PromptSpec: + """Declarative description used to build a system prompt.""" + + name: str = "Agent" + role: str = "Assistant" + goal: str = "Help users accomplish their tasks." + instructions: List[str] = field(default_factory=list) + style_guidelines: List[str] = field(default_factory=list) + system_prompt: Optional[str] = None + template_format: str = "jinja2" + + def build_sections(self) -> List[PromptSection]: + sections: List[PromptSection] = [ + PromptSection(title="Name", lines=["Your name is {{name}}."]), + PromptSection(title="Role", lines=["You are {{role}}."]), + PromptSection(title="Goal", lines=["Your goal is {{goal}}."]), + ] + + if self.instructions: + sections.append( + PromptSection(title="Primary Instructions", lines=self.instructions) + ) + + if self.style_guidelines: + sections.append( + PromptSection(title="Communication Style", lines=self.style_guidelines) + ) + + return sections + + def render_system_prompt(self) -> str: + if self.system_prompt: + return self._apply_template_format(self.system_prompt) + + date_line = datetime.now().strftime("%B %d, %Y") + header = f"# Today's date is: {date_line}" + rendered_sections: List[str] = [] + for section in self.build_sections(): + rendered = section.render(self.template_format) + if rendered: + rendered_sections.append(rendered) + sections = "\n\n".join(rendered_sections) + prompt = f"{header}\n\n{sections}".strip() + return self._apply_template_format(prompt) + + def _apply_template_format(self, text: str) -> str: + if self.template_format == "jinja2": + return _ensure_jinja_placeholders(text) + return text + + +class PromptTemplateFactory: + """Utility for constructing chat prompt templates from PromptSpec objects.""" + + @staticmethod + def build( + spec: PromptSpec, + *, + template_format: Optional[str] = None, + include_placeholders: Sequence[str] = ("chat_history",), + extra_messages: Optional[Iterable[Tuple[str, str]]] = None, + extra_sections: Optional[Iterable[PromptSection]] = None, + ) -> ChatPromptTemplate: + template_format = template_format or spec.template_format + system_message = spec.render_system_prompt() + + messages: List[Union[Tuple[str, str], Dict[str, Any], MessagePlaceHolder]] = [ + ("system", system_message) + ] + + if extra_sections: + for section in extra_sections: + rendered = section.render(template_format) + if rendered: + messages.append(("system", rendered)) + + for placeholder in include_placeholders: + messages.append(MessagePlaceHolder(variable_name=placeholder)) + + if extra_messages: + messages.extend(extra_messages) + + template = ChatPromptTemplate.from_messages( + messages=messages, + template_format=template_format, + ) + template.input_variables = sorted(set(template.input_variables + ["name"])) + return template + + +PROMPT_MODULE_REGISTRY: Dict[ + str, Callable[["AgentProfileConfig"], Optional[PromptSection]] +] = {} + + +def register_prompt_module( + name: str, + factory: Callable[["AgentProfileConfig"], Optional[PromptSection]], +) -> None: + """Register a callable that returns an additional prompt section.""" + PROMPT_MODULE_REGISTRY[name] = factory + + +def prompt_spec_from_profile( + profile: AgentProfileConfig, + *, + default_name: str, +) -> Tuple[PromptSpec, List[PromptSection]]: + """Build a PromptSpec and collect additional sections from the profile configuration.""" + + spec = PromptSpec( + name=profile.name or default_name, + role=profile.role or "Assistant", + goal=profile.goal or "Help users accomplish their tasks.", + instructions=list(profile.instructions), + style_guidelines=list(profile.style_guidelines), + system_prompt=profile.system_prompt, + template_format=profile.template_format, + ) + + extra_sections: List[PromptSection] = [] + + for module_name in profile.modules: + if module_name in profile.module_overrides: + section = profile.module_overrides[module_name] + else: + factory = PROMPT_MODULE_REGISTRY.get(module_name) + section = factory(profile) if factory else None + if section: + extra_sections.append(section) + + return spec, extra_sections + + +@dataclass +class PromptingAgentBase: + """ + Lightweight helper focused on prompt construction and history injection. + + This class intentionally omits workflow, tool-execution, and LLM wiring; higher-level + agent implementations can compose this helper alongside their own runtime concerns. + """ + + name: str = "Durable Agent" + role: str = "Assistant" + goal: str = "Deliver helpful responses." + instructions: List[str] = field(default_factory=list) + style_guidelines: List[str] = field(default_factory=list) + system_prompt: Optional[str] = None + template_format: str = "jinja2" + include_chat_history: bool = True + prompt_template: Optional[PromptTemplateBase] = None + profile_config: Optional[AgentProfileConfig] = None + + def __post_init__(self) -> None: + self._text_formatter = ColorTextFormatter() + + spec: PromptSpec + extra_sections: List[PromptSection] = [] + if self.profile_config is not None: + spec, extra_sections = prompt_spec_from_profile( + self.profile_config, + default_name=self.name, + ) + # adopt values from spec so other consumers see resolved persona + self.name = spec.name + self.role = spec.role + self.goal = spec.goal + self.instructions = list(spec.instructions) + self.style_guidelines = list(spec.style_guidelines) + self.system_prompt = spec.system_prompt + self.template_format = spec.template_format + else: + spec = self.to_prompt_spec() + + if self.prompt_template is None: + placeholders = ["chat_history"] if self.include_chat_history else [] + self.prompt_template = PromptTemplateFactory.build( + spec, + include_placeholders=placeholders, + extra_sections=extra_sections, + ) + + self._prefill_prompt_variables() + + def to_prompt_spec(self) -> PromptSpec: + return PromptSpec( + name=self.name, + role=self.role, + goal=self.goal, + instructions=list(self.instructions), + style_guidelines=list(self.style_guidelines), + system_prompt=self.system_prompt, + template_format=self.template_format, + ) + + def rebuild_prompt_template( + self, + *, + spec: Optional[PromptSpec] = None, + include_placeholders: Optional[Sequence[str]] = None, + ) -> None: + extra_sections: List[PromptSection] = [] + if spec is None: + if self.profile_config is not None: + spec, extra_sections = prompt_spec_from_profile( + self.profile_config, + default_name=self.name, + ) + else: + spec = self.to_prompt_spec() + placeholders = ( + include_placeholders + if include_placeholders is not None + else (["chat_history"] if self.include_chat_history else []) + ) + self.prompt_template = PromptTemplateFactory.build( + spec, + include_placeholders=placeholders, + extra_sections=extra_sections, + ) + self._prefill_prompt_variables() + + def _prefill_prompt_variables(self) -> None: + if not self.prompt_template: + return + + variables: Dict[str, Any] = { + "name": self.name, + "role": self.role, + "goal": self.goal, + } + if self.instructions: + variables["instructions"] = "\n".join(self.instructions) + if self.style_guidelines: + variables["style_guidelines"] = "\n".join(self.style_guidelines) + + self.prompt_template = self.prompt_template.pre_fill_variables(**variables) + + def build_initial_messages( + self, + user_input: Optional[Union[str, Dict[str, Any]]] = None, + chat_history: Optional[List[Dict[str, Any]]] = None, + **extra_variables: Any, + ) -> List[Dict[str, Any]]: + if not self.prompt_template: + raise ValueError("Prompt template has not been initialised.") + + variables = dict(extra_variables) + if self.include_chat_history and chat_history is not None: + variables.setdefault("chat_history", chat_history) + + messages = self.prompt_template.format_prompt(**variables) + + if isinstance(user_input, str): + messages.append({"role": "user", "content": user_input}) + elif isinstance(user_input, dict): + messages.append(user_input) + + return messages + + @property + def text_formatter(self) -> ColorTextFormatter: + return self._text_formatter diff --git a/dapr_agents/agents/durableagent/state.py b/dapr_agents/agents/schemas.py similarity index 57% rename from dapr_agents/agents/durableagent/state.py rename to dapr_agents/agents/schemas.py index 242dca22..73205d8d 100644 --- a/dapr_agents/agents/durableagent/state.py +++ b/dapr_agents/agents/schemas.py @@ -1,44 +1,86 @@ +import uuid +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + from pydantic import BaseModel, Field -from typing import List, Optional, Dict, Any + from dapr_agents.types import MessageContent, ToolExecutionRecord +from dapr_agents.types.message import BaseMessage from dapr_agents.types.workflow import DaprWorkflowStatus -from datetime import datetime -import uuid -class DurableAgentMessage(MessageContent): +def utcnow() -> datetime: + """Return current time as timezone-aware UTC datetime.""" + return datetime.now(timezone.utc) + + +class BroadcastMessage(BaseMessage): + """ + Represents a broadcast message from an agent. + """ + + +class AgentTaskResponse(BaseMessage): + """ + Represents a response message from an agent after completing a task. + """ + + workflow_instance_id: Optional[str] = Field( + default=None, description="Dapr workflow instance id from source if available" + ) + + +class TriggerAction(BaseModel): + """ + Represents a message used to trigger an agent's activity within the workflow. + """ + + task: Optional[str] = Field( + None, + description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.", + ) + workflow_instance_id: Optional[str] = Field( + default=None, description="Dapr workflow instance id from source if available" + ) + + +class AgentWorkflowMessage(MessageContent): id: str = Field( default_factory=lambda: str(uuid.uuid4()), description="Unique identifier for the message", ) timestamp: datetime = Field( - default_factory=datetime.now, + default_factory=utcnow, description="Timestamp when the message was created", ) -class DurableAgentWorkflowEntry(BaseModel): +class AgentWorkflowEntry(BaseModel): """Represents a workflow and its associated data, including metadata on the source of the task request.""" - input: str = Field( + input_value: str = Field( ..., description="The input or description of the Workflow to be performed" ) output: Optional[str] = Field( default=None, description="The output or result of the Workflow, if completed" ) start_time: datetime = Field( - default_factory=datetime.now, + default_factory=utcnow, description="Timestamp when the workflow was started", ) end_time: Optional[datetime] = Field( default=None, description="Timestamp when the workflow was completed or failed", ) - messages: List[DurableAgentMessage] = Field( + messages: List[AgentWorkflowMessage] = Field( default_factory=list, description="Messages exchanged during the workflow (user, assistant, or tool messages).", ) - last_message: Optional[DurableAgentMessage] = Field( + system_messages: List[AgentWorkflowMessage] = Field( + default_factory=list, + description="Rendered system prompt messages included when invoking the LLM.", + ) + last_message: Optional[AgentWorkflowMessage] = Field( default=None, description="Last processed message in the workflow" ) tool_history: List[ToolExecutionRecord] = Field( @@ -57,6 +99,10 @@ class DurableAgentWorkflowEntry(BaseModel): default=None, description="The name of the workflow.", ) + session_id: Optional[str] = Field( + default=None, + description="Conversation memory session identifier, when available.", + ) trace_context: Optional[Dict[str, Any]] = Field( default=None, description="OpenTelemetry trace context for workflow resumption.", @@ -67,14 +113,10 @@ class DurableAgentWorkflowEntry(BaseModel): ) -class DurableAgentWorkflowState(BaseModel): +class AgentWorkflowState(BaseModel): """Represents the state of multiple Agent workflows.""" - instances: Dict[str, DurableAgentWorkflowEntry] = Field( + instances: Dict[str, AgentWorkflowEntry] = Field( default_factory=dict, description="Workflow entries indexed by their instance_id.", ) - chat_history: List[DurableAgentMessage] = Field( - default_factory=list, - description="Chat history of messages exchanged during the workflow.", - ) diff --git a/dapr_agents/agents/standalone.py b/dapr_agents/agents/standalone.py new file mode 100644 index 00000000..bf9dce39 --- /dev/null +++ b/dapr_agents/agents/standalone.py @@ -0,0 +1,471 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import signal +import uuid +from datetime import datetime, timezone +from typing import Any, Awaitable, Dict, Iterable, List, Optional, Sequence, Union + +from dapr_agents.agents.base import AgentBase +from dapr_agents.agents.configs import ( + AgentMemoryConfig, + AgentProfileConfig, + AgentRegistryConfig, + AgentStateConfig, + AgentExecutionConfig, +) +from dapr_agents.llm.chat import ChatClientBase +from dapr_agents.prompt.base import PromptTemplateBase +from dapr_agents.types import ( + AgentError, + AssistantMessage, + LLMChatResponse, + ToolCall, + ToolExecutionRecord, + ToolMessage, +) +from dapr_agents.types.workflow import DaprWorkflowStatus + +logger = logging.getLogger(__name__) + + +class Agent(AgentBase): + """ + Standalone (non-workflow) agent built on AgentBase. + + Overview: + Reuses AgentBase for profile/prompting, LLM wiring, memory, and durable state. + Runs an in-process conversation loop (no Dapr Workflows), but persists + per-instance timeline using the flexible state model (custom `instances` layout allowed). + """ + + def __init__( + self, + *, + # Profile / prompt + profile_config: Optional[AgentProfileConfig] = None, + name: Optional[str] = None, + role: Optional[str] = None, + goal: Optional[str] = None, + instructions: Optional[Iterable[str]] = None, + style_guidelines: Optional[Iterable[str]] = None, + system_prompt: Optional[str] = None, + prompt_template: Optional[PromptTemplateBase] = None, + # Runtime + llm: Optional[ChatClientBase] = None, + tools: Optional[Iterable[Any]] = None, + memory_config: Optional[AgentMemoryConfig] = None, + # Persistence/registry + state_config: Optional[AgentStateConfig] = None, + registry_config: Optional[AgentRegistryConfig] = None, + # Behavior / execution + execution_config: Optional[AgentExecutionConfig] = None, + # Misc + agent_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Initialize behavior + infrastructure for a non-workflow agent. + + Args: + profile_config: High-level profile (can be overridden by explicit fields). + name, role, goal, instructions, style_guidelines, system_prompt: Prompt/profile fields. + prompt_template: Optional explicit prompt template instance. + llm: Chat client; defaults to `get_default_llm()`. + tools: Optional tool callables or `AgentTool` instances. + memory_config: Conversation memory config. + state_config: Durable state configuration/model customization. + registry_config: Team registry configuration. + execution_config: Execution dials for the agent run. + agent_metadata: Extra metadata to store in the registry. + """ + super().__init__( + profile_config=profile_config, + name=name, + role=role, + goal=goal, + instructions=instructions, + style_guidelines=style_guidelines, + system_prompt=system_prompt, + state_config=state_config, + memory_config=memory_config, + registry_config=registry_config, + execution_config=execution_config, + agent_metadata=agent_metadata, + llm=llm, + tools=tools, + prompt_template=prompt_template, + ) + + self._shutdown_event = asyncio.Event() + self._setup_signal_handlers() + + try: + self.load_state() + except Exception: + logger.debug( + "Standalone agent state load failed; using defaults.", exc_info=True + ) + + # ------------------------------------------------------------------ + # Public entrypoint + # ------------------------------------------------------------------ + async def run( + self, + input_data: Optional[Union[str, Dict[str, Any]]] = None, + *, + instance_id: Optional[str] = None, + ) -> Optional[AssistantMessage]: + """ + Execute a conversational run in-process. + + Args: + input_data: Optional user input (string or structured dict). + instance_id: Optional workflow-like instance id; auto-generated if omitted. + + Returns: + The final assistant message (if not cancelled), else None. + + Raises: + AgentError: Propagates structured errors from generation/tooling. + """ + try: + return await self._race( + self._run_agent(input_data=input_data, instance_id=instance_id) + ) + except asyncio.CancelledError: + logger.info("Standalone agent run was cancelled.") + return None + except Exception as exc: + logger.exception("Standalone agent run failed: %s", exc) + raise + + # ------------------------------------------------------------------ + # Internal runtime loop + # ------------------------------------------------------------------ + async def _race( + self, coro: Awaitable[Optional[AssistantMessage]] + ) -> Optional[AssistantMessage]: + """Race the agent execution against shutdown signals.""" + task = asyncio.create_task(coro) + shutdown_task = asyncio.create_task(self._shutdown_event.wait()) + done, pending = await asyncio.wait( + [task, shutdown_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for pending_task in pending: + pending_task.cancel() + if self._shutdown_event.is_set(): + logger.info( + "Shutdown requested during execution; cancelling standalone run." + ) + task.cancel() + return None + return await task + + async def _run_agent( + self, + *, + input_data: Optional[Union[str, Dict[str, Any]]], + instance_id: Optional[str], + ) -> Optional[AssistantMessage]: + """One-shot conversational run with tool loop and durable timeline.""" + self.load_state() + active_instance = instance_id or self._generate_instance_id() + + # Build initial messages with persistent + per-instance history + chat_history = self._construct_messages_with_instance_history(active_instance) + messages = self.prompting_helper.build_initial_messages( + user_input=input_data, + chat_history=chat_history, + ) + + # Keep per-instance system messages in sync with state + self._sync_system_messages_with_state(active_instance, messages) + + # Print + capture the user's message if present + user_message = self._get_last_user_message(messages) + user_message_copy = dict(user_message) if user_message else None + task_text = user_message_copy.get("content") if user_message_copy else None + + if user_message_copy is not None: + self.text_formatter.print_message( + {str(k): v for k, v in user_message_copy.items()} + ) + + # Ensure instance exists (flexible model via _get_entry_container) + created_instance = active_instance not in ( + getattr(self.workflow_state, "instances", {}) or {} + ) + self.ensure_instance_exists( + instance_id=active_instance, + input_value=task_text or "Triggered without input.", + triggering_workflow_instance_id=None, + time=datetime.now(timezone.utc), + ) + if created_instance: + self.save_state() + + # Persist the user message into timeline + memory + self._process_user_message(active_instance, task_text, user_message_copy) + + # Enter the tool/LLM loop + final_reply = await self._conversation_loop( + instance_id=active_instance, + messages=messages, + ) + + return final_reply + + def construct_messages( + self, + input_data: Optional[Union[str, Dict[str, Any]]] = None, + *, + instance_id: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """ + Build the outbound message list (without running the loop). + + Args: + input_data: Optional user input. + instance_id: Optional instance id to load per-instance history. + + Returns: + List of message dicts suitable for an LLM chat API. + """ + self.load_state() + active_instance = instance_id or self._generate_instance_id() + chat_history = self._construct_messages_with_instance_history(active_instance) + return self.prompting_helper.build_initial_messages( + user_input=input_data, + chat_history=chat_history, + ) + + async def _conversation_loop( + self, + *, + instance_id: str, + messages: List[Dict[str, Any]], + ) -> Optional[AssistantMessage]: + """ + Core in-process loop that alternates LLM responses and tool calls. + + Args: + instance_id: Per-run id used to persist timeline to state. + messages: Initial outbound messages (system + history + user). + + Returns: + Final assistant message, or None if cancelled mid-run. + + Raises: + AgentError: If chat generation fails or no assistant message is produced. + """ + pending_messages = list(messages) + final_reply: Optional[AssistantMessage] = None + + for turn in range(1, self.execution_config.max_iterations + 1): + logger.info( + "Iteration %d/%d started.", turn, self.execution_config.max_iterations + ) + try: + response: LLMChatResponse = self.llm.generate( + messages=pending_messages, + tools=self.get_llm_tools(), + **( + {"tool_choice": self.execution_config.tool_choice} + if self.execution_config.tool_choice is not None + else {} + ), + ) + assistant_message = response.get_message() + if assistant_message is None: + raise AgentError("LLM returned no assistant message.") + + assistant_dict = assistant_message.model_dump() + self._save_assistant_message(instance_id, assistant_dict) + self.text_formatter.print_message(assistant_dict) + + if assistant_message.has_tool_calls(): + tool_calls = assistant_message.get_tool_calls() + if tool_calls: + pending_messages.append(assistant_dict) + tool_msgs = await self._execute_tool_calls( + instance_id, tool_calls + ) + pending_messages.extend(tool_msgs) + if turn == self.execution_config.max_iterations: + final_reply = assistant_message + logger.info( + "Reached max iterations after tool calls; stopping." + ) + break + continue + + final_reply = assistant_message + break + except Exception as exc: + logger.error("Error on turn %d: %s", turn, exc) + raise AgentError(f"Failed during chat generation: {exc}") from exc + + self._update_instance_completion(instance_id, final_reply) + return final_reply + + async def _execute_tool_calls( + self, + instance_id: str, + tool_calls: Sequence[ToolCall], + ) -> List[Dict[str, Any]]: + """ + Execute tool calls concurrently with bounded parallelism. + + Args: + instance_id: Timeline instance id to append tool results to. + tool_calls: ToolCall objects from the assistant. + + Returns: + List of tool message dicts to append to the LLM turn. + """ + max_concurrent = 10 + semaphore = asyncio.Semaphore(max_concurrent) + + async def run_single(tool_call: ToolCall) -> Dict[str, Any]: + async with semaphore: + return await self._run_tool_call(instance_id, tool_call) + + return await asyncio.gather(*(run_single(call) for call in tool_calls)) + + async def _run_tool_call( + self, instance_id: str, tool_call: ToolCall + ) -> Dict[str, Any]: + """ + Execute one tool call and persist outcome to state + memory. + + Args: + instance_id: Timeline instance id to append messages/history. + tool_call: ToolCall describing the function and arguments. + + Returns: + A tool message dict ({role:"tool", name, content, id}) + + Raises: + AgentError: On invalid call or tool execution failure. + """ + function_name = tool_call.function.name + if not function_name: + error_msg = f"Tool call missing function name: {tool_call}" + logger.error(error_msg) + raise AgentError(error_msg) + + function_args = tool_call.function.arguments_dict + try: + result = await self.tool_executor.run_tool(function_name, **function_args) + except Exception as exc: + logger.error("Error executing tool %s: %s", function_name, exc) + raise AgentError(f"Error executing tool '{function_name}': {exc}") from exc + + # Safe serialization of tool result + if isinstance(result, str): + serialized_result = result + else: + try: + serialized_result = json.dumps(result) + except Exception: # noqa: BLE001 + serialized_result = str(result) + + # Build memory + durable messages + tool_message = ToolMessage( + tool_call_id=tool_call.id, + name=function_name, + content=serialized_result, + ) + message_dict = tool_message.model_dump() + + history_entry = ToolExecutionRecord( + tool_call_id=tool_call.id, + tool_name=function_name, + tool_args=function_args, + execution_result=serialized_result, + ) + + # Append to durable timeline using the flexible model/coercer path + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if entry is not None and hasattr(entry, "messages"): + # Prefer a custom coercer if configured; otherwise the configured message model, with a safe fallback. + try: + if getattr(self, "_message_coercer", None): + durable_message = self._message_coercer(message_dict) # type: ignore[attr-defined] + else: + durable_message = self._message_dict_to_message_model(message_dict) + except Exception: + # Last-resort: keep the raw dict so we don't drop tool output. + durable_message = dict(message_dict) + + entry.messages.append(durable_message) + if hasattr(entry, "tool_history"): + entry.tool_history.append(history_entry) + if hasattr(entry, "last_message"): + entry.last_message = durable_message + + # Always persist to memory + in-process history + self.text_formatter.print_message(message_dict) + self.memory.add_message(tool_message) + self.tool_history.append(history_entry) + self.save_state() + + # Return tool message dict so the next LLM turn can see it + return message_dict + + # ------------------------------------------------------------------ + # Finalization + # ------------------------------------------------------------------ + def _update_instance_completion( + self, + instance_id: str, + final_reply: Optional[AssistantMessage], + ) -> None: + """ + Mark an instance as completed/failed with end time and output. + + Args: + instance_id: Timeline instance id. + final_reply: AssistantMessage (if any) that ended the loop. + """ + container = self._get_entry_container() + entry = container.get(instance_id) if container else None + if entry is None: + return + + entry.status = ( + DaprWorkflowStatus.COMPLETED.value + if final_reply + else DaprWorkflowStatus.FAILED.value + ) + if final_reply and hasattr(entry, "output"): + entry.output = final_reply.content or "" + entry.end_time = datetime.now(timezone.utc) + self.save_state() + + def _generate_instance_id(self) -> str: + """Generate a unique instance id for standalone runs.""" + return f"{self.name}-{uuid.uuid4().hex}" + + # ------------------------------------------------------------------ + # Infrastructure hooks (signals) + # ------------------------------------------------------------------ + def _setup_signal_handlers(self) -> None: + """Install SIGINT/SIGTERM handlers to allow graceful shutdown.""" + try: + signal.signal(signal.SIGINT, self._signal_handler) + signal.signal(signal.SIGTERM, self._signal_handler) + except (OSError, ValueError): + # Not available in some environments (e.g., Windows/subthreads) + pass + + def _signal_handler( + self, signum, frame + ) -> None: # pragma: no cover - signal handler + """Signal handler that asks the run loop to stop.""" + logger.info("Received signal %s. Shutting down gracefully...", signum) + self._shutdown_event.set() diff --git a/dapr_agents/memory/daprstatestore.py b/dapr_agents/memory/daprstatestore.py index f2fa3734..ee17b178 100644 --- a/dapr_agents/memory/daprstatestore.py +++ b/dapr_agents/memory/daprstatestore.py @@ -98,7 +98,9 @@ def add_message(self, message: Union[Dict[str, Any], BaseMessage]) -> None: f"Adding message {message} with key {message_key} to session {self.session_id}" ) self.dapr_store.save_state( - self.session_id, json.dumps(existing), {"contentType": "application/json"} + self.session_id, + json.dumps(existing), + state_metadata={"contentType": "application/json"}, ) def add_messages(self, messages: List[Union[Dict[str, Any], BaseMessage]]) -> None: diff --git a/dapr_agents/observability/wrappers/agent.py b/dapr_agents/observability/wrappers/agent.py index 67060752..df922e90 100644 --- a/dapr_agents/observability/wrappers/agent.py +++ b/dapr_agents/observability/wrappers/agent.py @@ -135,7 +135,9 @@ def _build_agent_attributes( "agent.goal": getattr(instance, "goal", None), "agent.tools": [tool.name for tool in getattr(instance, "tools", [])], "agent.tools.count": len(getattr(instance, "tools", [])), - "agent.max_iterations": getattr(instance, "max_iterations", None), + "agent.execution_config.max_iterations": getattr( + instance, "execution_config.max_iterations", None + ), } # Extract actual input value - the user's query/request diff --git a/dapr_agents/observability/wrappers/workflow.py b/dapr_agents/observability/wrappers/workflow.py index 8e331c86..fef5c34d 100644 --- a/dapr_agents/observability/wrappers/workflow.py +++ b/dapr_agents/observability/wrappers/workflow.py @@ -403,7 +403,7 @@ def _store_context_in_workflow_state( """ try: # Import here to avoid circular imports - from dapr_agents.agents.durableagent.state import DurableAgentWorkflowEntry + from dapr_agents.agents.durableagent.state import AgentWorkflowEntry # This is a simplified approach - in practice, you'd need to access # the actual workflow state through the Dapr Workflow runtime diff --git a/dapr_agents/service/fastapi/base.py b/dapr_agents/service/fastapi/base.py index 0dbcd7b7..f60b4c9d 100644 --- a/dapr_agents/service/fastapi/base.py +++ b/dapr_agents/service/fastapi/base.py @@ -4,7 +4,7 @@ from pydantic import Field, ConfigDict from typing import List, Optional, Any from dapr_agents.service import APIServerBase -from dapr_agents.utils import SignalHandlingMixin +from dapr_agents.utils.signal.mixin import SignalMixin import uvicorn import asyncio import logging @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -class FastAPIServerBase(APIServerBase, SignalHandlingMixin): +class FastAPIServerBase(APIServerBase, SignalMixin): """ Abstract base class for FastAPI-based API server services. Provides core FastAPI functionality, with support for CORS, lifecycle management, and graceful shutdown. @@ -72,6 +72,7 @@ def model_post_init(self, __context: Any) -> None: # Call the base post-initialization super().model_post_init(__context) + SignalMixin.__init__(self) @asynccontextmanager async def lifespan(self, app: FastAPI): @@ -104,24 +105,27 @@ async def start(self, log_level=None): self.server: uvicorn.Server = uvicorn.Server(config) # Set up signal handlers using the mixin - self.setup_signal_handlers() + self.install_signal_handlers() - # Start in background so we can inspect the actual port - server_task = asyncio.create_task(self.server.serve()) - - # Wait for startup to complete - while not self.server.started: - await asyncio.sleep(0.1) - - # Extract the real port from the bound socket - if self.server.servers: - sock = list(self.server.servers)[0].sockets[0] - actual_port = sock.getsockname()[1] - self.service_port = actual_port - else: - logger.warning(f"{self.service_name} could not determine bound port") - - await server_task + try: + # Start in background so we can inspect the actual port + server_task = asyncio.create_task(self.server.serve()) + + # Wait for startup to complete + while not self.server.started: + await asyncio.sleep(0.1) + + # Extract the real port from the bound socket + if self.server.servers: + sock = list(self.server.servers)[0].sockets[0] + actual_port = sock.getsockname()[1] + self.service_port = actual_port + else: + logger.warning(f"{self.service_name} could not determine bound port") + + await server_task + finally: + self.remove_signal_handlers() async def graceful_shutdown(self) -> None: """ diff --git a/dapr_agents/storage/daprstores/__init__.py b/dapr_agents/storage/daprstores/__init__.py index b1f89a29..ce6595e2 100644 --- a/dapr_agents/storage/daprstores/__init__.py +++ b/dapr_agents/storage/daprstores/__init__.py @@ -1,4 +1,29 @@ from .base import DaprStoreBase from .statestore import DaprStateStore +from .stateservice import ( + StateStoreError, + StateStoreService, + load_state_dict, + load_state_with_etag, + load_state_many, + save_state_dict, + save_state_many, + delete_state, + state_exists, + execute_state_transaction, +) -__all__ = ["DaprStoreBase", "DaprStateStore"] +__all__ = [ + "DaprStoreBase", + "DaprStateStore", + "StateStoreError", + "StateStoreService", + "load_state_dict", + "load_state_with_etag", + "load_state_many", + "save_state_dict", + "save_state_many", + "delete_state", + "state_exists", + "execute_state_transaction", +] diff --git a/dapr_agents/storage/daprstores/stateservice.py b/dapr_agents/storage/daprstores/stateservice.py new file mode 100644 index 00000000..ea14d9e5 --- /dev/null +++ b/dapr_agents/storage/daprstores/stateservice.py @@ -0,0 +1,714 @@ +from __future__ import annotations + +import json +import logging +import os +import random +import tempfile +import threading +import time +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union + +from dapr.clients.grpc._response import BulkStateItem, StateResponse +from pydantic import BaseModel, ValidationError + +from dapr_agents.storage.daprstores.statestore import ( + DaprStateStore, + _coerce_state_options, +) + +logger = logging.getLogger(__name__) + + +class StateStoreError(RuntimeError): + """Raised when high-level state operations fail.""" + + +class StateStoreService: + """ + High-level state helper that composes a `DaprStateStore` instance. + + Prefer this in application code and workflow activities when you want dicts, + validation, retries, and convenient TTL support. + """ + + _mirror_lock = threading.Lock() + + def __init__( + self, + *, + store_name: str, + key_prefix: str = "", + model: Optional[Type[BaseModel]] = None, + mirror_writes: bool = False, + local_mirror_path: Optional[str] = None, + store_factory: Optional[Callable[[], DaprStateStore]] = None, + retry_attempts: int = 3, + retry_initial_backoff: float = 0.1, + retry_backoff_multiplier: float = 2.0, + retry_jitter: float = 0.1, + ) -> None: + """ + Args: + store_name: Dapr state component name (required). + key_prefix: Optional logical prefix applied to all keys (e.g., "blog:"). + model: Optional Pydantic model used to validate/shape payloads. + mirror_writes: If True, also mirror successful writes to local disk. + local_mirror_path: Directory where mirror files are written (defaults to CWD). + store_factory: Factory returning a `DaprStateStore` (DI/testing). + retry_attempts: Max attempts per Dapr call (>=1). + retry_initial_backoff: Initial backoff in seconds. + retry_backoff_multiplier: Backoff multiplier per attempt (>=1.0). + retry_jitter: Proportional jitter [0,1] applied to backoff. + """ + if not store_name: + raise StateStoreError("State store name is required (store_name).") + + self.store_name = store_name + self.key_prefix = key_prefix + self.model = model + self.mirror_writes = mirror_writes + self.local_mirror_path = local_mirror_path + self.retry_attempts = max(1, retry_attempts) + self.retry_initial_backoff = max(0.0, retry_initial_backoff) + self.retry_backoff_multiplier = max(1.0, retry_backoff_multiplier) + self.retry_jitter = max(0.0, retry_jitter) + + self._store_factory = store_factory or ( + lambda: DaprStateStore(store_name=self.store_name) + ) + self._store_cached: Optional[DaprStateStore] = None + + def _store(self) -> DaprStateStore: + """Return the lazily-constructed `DaprStateStore`.""" + if self._store_cached is None: + self._store_cached = self._store_factory() + return self._store_cached + + def _qualify(self, key: str) -> str: + """Apply the configured key prefix to a logical key.""" + return f"{self.key_prefix}{key}" if self.key_prefix else key + + def _strip_prefix(self, qualified: str) -> str: + """Remove the configured key prefix from a qualified key (best effort).""" + if self.key_prefix and qualified.startswith(self.key_prefix): + return qualified[len(self.key_prefix) :] + return qualified + + def _with_retries(self, func: Callable[[], Any]) -> Any: + """Execute a callable with retry/backoff/jitter.""" + delay = self.retry_initial_backoff + attempt = 0 + while True: + try: + return func() + except Exception as exc: # noqa: BLE001 + attempt += 1 + if attempt >= self.retry_attempts: + raise + sleep_for = delay * ( + 1 + random.uniform(-self.retry_jitter, self.retry_jitter) + ) + if sleep_for > 0: + time.sleep(max(0.0, sleep_for)) + delay *= self.retry_backoff_multiplier + logger.debug( + "Retrying state operation after error: %s", exc, exc_info=True + ) + + def _model_dump(self, model: BaseModel) -> Dict[str, Any]: + """Dump a Pydantic model to dict (v2/v1 support).""" + if hasattr(model, "model_dump"): + return model.model_dump() + if hasattr(model, "dict"): + return model.dict() + raise StateStoreError(f"Unsupported pydantic model type: {type(model)}") + + def _ensure_dict(self, value: Any) -> Dict[str, Any]: + """ + Coerce value into a dict. + + Accepts: + - dict (returned as-is) + - pydantic BaseModel (dumped) + - JSON str (parsed to dict) + - JSON bytes (decoded to str, parsed to dict) + """ + if isinstance(value, BaseModel): + return self._model_dump(value) + if isinstance(value, dict): + return value + if isinstance(value, str): + try: + parsed = json.loads(value) + except json.JSONDecodeError as exc: + raise StateStoreError(f"State string is not valid JSON: {exc}") from exc + if not isinstance(parsed, dict): + raise StateStoreError(f"Expected dict JSON, got {type(parsed)}") + return parsed + if isinstance(value, bytes): + return self._ensure_dict(value.decode("utf-8")) + raise StateStoreError( + f"Unsupported state type: {type(value)}. Expected dict, BaseModel, str, or bytes." + ) + + def _validate_model( + self, payload: Dict[str, Any], *, return_model: bool = False + ) -> Union[Dict[str, Any], BaseModel]: + """Validate payload with configured Pydantic model (if any).""" + if not self.model: + return payload + try: + parsed = self.model(**payload) + except ValidationError as exc: + raise StateStoreError(f"State validation failed: {exc.errors()}") from exc + return parsed if return_model else self._model_dump(parsed) + + def _save_local_copy(self, *, key: str, data: Dict[str, Any]) -> None: + """ + Write/merge a pretty-printed JSON file for the qualified key (debug/dev). + + Uses a temp file and atomic replace to avoid partial writes. + """ + directory = self.local_mirror_path or os.getcwd() + os.makedirs(directory, exist_ok=True) + filename = f"{key}.json" + file_path = os.path.join(directory, filename) + + tmp_fd, tmp_path = tempfile.mkstemp(dir=directory) + os.close(tmp_fd) + try: + with self._mirror_lock: + existing: Dict[str, Any] = {} + if os.path.exists(file_path): + try: + with open(file_path, "r", encoding="utf-8") as current: + existing = json.load(current) + except json.JSONDecodeError: + logger.debug( + "Existing state file corrupt; overwriting", exc_info=True + ) + + merged = _deep_merge(existing, data) + + with open(tmp_path, "w", encoding="utf-8") as tmp_file: + json.dump(merged, tmp_file, indent=2) + + os.replace(tmp_path, file_path) + except Exception as exc: # noqa: BLE001 + logger.error("Failed to mirror state locally", exc_info=True) + raise StateStoreError(f"Failed to save local state mirror: {exc}") from exc + finally: + if os.path.exists(tmp_path): + try: + os.remove(tmp_path) + except Exception: # noqa: BLE001 + pass + + def load( + self, + *, + key: str, + default: Optional[Dict[str, Any]] = None, + state_metadata: Optional[Dict[str, str]] = None, + return_model: bool = False, + ) -> Union[Dict[str, Any], BaseModel]: + """ + Load a JSON dict from the state store. + + Args: + key: Logical (unprefixed) key. + default: Returned if the item does not exist (must be dict or None). + state_metadata: Optional Dapr metadata. + return_model: If True and a Pydantic model is configured, return model instance. + + Returns: + Dict payload or model instance. + """ + qualified = self._qualify(key) + logger.debug("Loading state from %s key=%s", self.store_name, qualified) + + def call() -> StateResponse: + return self._store().get_state( + qualified, + state_metadata=state_metadata, + ) + + try: + response = self._with_retries(call) + except Exception as exc: # noqa: BLE001 + raise StateStoreError( + f"Failed to load state for key '{qualified}': {exc}" + ) from exc + + if not response or not getattr(response, "data", None): + return default.copy() if isinstance(default, dict) else (default or {}) + + try: + state_data = response.json() + except Exception as exc: # noqa: BLE001 + raise StateStoreError( + f"State for key '{qualified}' is not valid JSON: {exc}" + ) from exc + + if not isinstance(state_data, dict): + raise StateStoreError( + f"State for key '{qualified}' must be a dict, got {type(state_data)}" + ) + + return self._validate_model(state_data, return_model=return_model) + + def load_with_etag( + self, + *, + key: str, + default: Optional[Dict[str, Any]] = None, + state_metadata: Optional[Dict[str, str]] = None, + return_model: bool = False, + ) -> Tuple[Union[Dict[str, Any], BaseModel], Optional[str]]: + """ + Load a JSON dict and return `(payload, etag)`. + + Args: + key: Logical (unprefixed) key. + default: Returned payload when not found. + state_metadata: Optional Dapr metadata. + return_model: If True and model configured, return model instance. + + Returns: + (dict_or_model, etag_or_none) + """ + qualified = self._qualify(key) + logger.debug( + "Loading state with etag from %s key=%s", self.store_name, qualified + ) + + def call() -> StateResponse: + return self._store().get_state( + qualified, + state_metadata=state_metadata, + ) + + try: + response = self._with_retries(call) + except Exception as exc: # noqa: BLE001 + raise StateStoreError( + f"Failed to load state for key '{qualified}': {exc}" + ) from exc + + if not response or not getattr(response, "data", None): + data = default.copy() if isinstance(default, dict) else (default or {}) + return data, None + + try: + state_data = response.json() + except Exception as exc: # noqa: BLE001 + raise StateStoreError( + f"State for key '{qualified}' is not valid JSON: {exc}" + ) from exc + + if not isinstance(state_data, dict): + raise StateStoreError( + f"State for key '{qualified}' must be a dict, got {type(state_data)}" + ) + + payload = self._validate_model(state_data, return_model=return_model) + etag = getattr(response, "etag", None) + return payload, etag + + def load_many( + self, + keys: Sequence[str], + *, + parallelism: int = 1, + state_metadata: Optional[Dict[str, str]] = None, + return_model: bool = False, + ) -> Dict[str, Union[Dict[str, Any], BaseModel]]: + """ + Bulk load multiple keys. + + Args: + keys: Logical (unprefixed) keys. + parallelism: Backend-specific parallelism. + state_metadata: Optional Dapr metadata. + return_model: If True and model configured, return model instances. + + Returns: + Mapping of logical key -> dict/model for keys that existed. + """ + qualified_keys = [self._qualify(k) for k in keys] + logger.debug( + "Loading bulk state from %s keys=%s", self.store_name, qualified_keys + ) + + def call() -> Sequence[BulkStateItem]: + return self._store().get_bulk_state( + qualified_keys, + parallelism=parallelism, + states_metadata=state_metadata, + ) + + try: + items = self._with_retries(call) + except Exception as exc: # noqa: BLE001 + raise StateStoreError(f"Failed to bulk load state: {exc}") from exc + + results: Dict[str, Union[Dict[str, Any], BaseModel]] = {} + for item in items or []: + data_raw = item.data + if not data_raw: + continue + if isinstance(data_raw, bytes): + data_raw = data_raw.decode("utf-8") + try: + parsed = json.loads(data_raw) + except Exception as exc: # noqa: BLE001 + raise StateStoreError( + f"State for key '{item.key}' is not valid JSON: {exc}" + ) from exc + logical_key = self._strip_prefix(item.key) + results[logical_key] = self._validate_model( + parsed, return_model=return_model + ) + return results + + def save( + self, + *, + key: str, + value: Any, + etag: Optional[str] = None, + state_metadata: Optional[Dict[str, str]] = None, + state_options: Optional[Dict[str, Any]] = None, + ttl_in_seconds: Optional[int] = None, + ) -> None: + """ + Save a JSON payload under a logical key. + + Args: + key: Logical (unprefixed) key. + value: dict | BaseModel | JSON str | JSON bytes. + etag: Optional ETag for optimistic concurrency. + state_metadata: Optional Dapr metadata. + state_options: Dict of `StateOptions` fields (or a `StateOptions` instance). + ttl_in_seconds: Optional TTL; backend must support TTL via metadata. + """ + qualified = self._qualify(key) + payload_dict = self._ensure_dict(value) + payload_str = json.dumps(payload_dict) + + metadata = dict(state_metadata or {}) + if ttl_in_seconds is not None: + metadata.setdefault("ttlInSeconds", str(ttl_in_seconds)) + + logger.debug( + "Saving state to %s key=%s etag=%s ttl=%s", + self.store_name, + qualified, + etag, + ttl_in_seconds, + ) + + def call() -> None: + self._store().save_state( + qualified, + payload_str, + state_metadata=metadata or None, + etag=etag, + state_options=_coerce_state_options(state_options), + ) + + try: + self._with_retries(call) + except Exception as exc: # noqa: BLE001 + raise StateStoreError( + f"Failed to save state for key '{qualified}': {exc}" + ) from exc + + if self.mirror_writes: + self._save_local_copy(key=qualified, data=payload_dict) + + def delete( + self, + *, + key: str, + etag: Optional[str] = None, + state_metadata: Optional[Dict[str, str]] = None, + state_options: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Delete a logical key. + + Args: + key: Logical (unprefixed) key. + etag: Optional ETag for concurrency. + state_metadata: Optional Dapr metadata. + state_options: Dict or `StateOptions` controlling delete behavior. + """ + qualified = self._qualify(key) + logger.debug( + "Deleting state from %s key=%s etag=%s", self.store_name, qualified, etag + ) + + def call() -> None: + self._store().delete_state( + qualified, + etag=etag, + state_metadata=state_metadata, + state_options=_coerce_state_options(state_options), + ) + + try: + self._with_retries(call) + except Exception as exc: # noqa: BLE001 + raise StateStoreError( + f"Failed to delete state for key '{qualified}': {exc}" + ) from exc + + def exists(self, *, key: str) -> bool: + """ + Return True if the key exists (uses presence of an ETag as heuristic). + """ + _, etag = self.load_with_etag(key=key, default=None) + return etag is not None + + def save_many( + self, + items: Dict[str, Any], + *, + state_metadata: Optional[Dict[str, str]] = None, + state_options: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Save multiple logical keys (non-atomic sequence of individual saves). + + Args: + items: Mapping key -> JSON-serializable payload (same accepted types as `save()`). + state_metadata: Optional metadata applied to each save. + state_options: Dict or `StateOptions` applied to each save. + """ + logger.debug( + "Saving bulk state to %s keys=%s", self.store_name, list(items.keys()) + ) + metadata = state_metadata or {} + options = _coerce_state_options(state_options) + + def call() -> None: + store = self._store() + for key, value in items.items(): + payload_dict = self._ensure_dict(value) + payload_str = json.dumps(payload_dict) + store.save_state( + self._qualify(key), + payload_str, + state_metadata=metadata, + state_options=options, + ) + if self.mirror_writes: + self._save_local_copy(key=self._qualify(key), data=payload_dict) + + try: + self._with_retries(call) + except Exception as exc: # noqa: BLE001 + raise StateStoreError(f"Failed to bulk save state: {exc}") from exc + + def execute_transaction( + self, + operations: Sequence[Dict[str, Any]], + *, + metadata: Optional[Dict[str, str]] = None, + ) -> None: + """ + Execute a transactional batch (backend must support transactions). + + Args: + operations: Dapr transaction operations (upserts/deletes). + metadata: Optional request metadata. + """ + logger.debug( + "Executing state transaction on %s operations=%s", + self.store_name, + operations, + ) + + def call() -> None: + self._store().execute_state_transaction(operations, metadata=metadata) + + try: + self._with_retries(call) + except Exception as exc: # noqa: BLE001 + raise StateStoreError( + f"Failed to execute state transaction: {exc}" + ) from exc + + +def _deep_merge(original: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]: + """ + Recursively merge two dictionaries (values in `updates` override `original`). + """ + result = dict(original) + for key, value in updates.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = _deep_merge(result[key], value) + else: + result[key] = value + return result + + +# ---------------------------------------------------------------------- +# Optional convenience helpers (explicit service; no env lookups) +# ---------------------------------------------------------------------- + + +def load_state_dict( + service: StateStoreService, + key: str, + *, + default: Optional[Dict[str, Any]] = None, + state_metadata: Optional[Dict[str, str]] = None, + return_model: bool = False, +) -> Union[Dict[str, Any], BaseModel]: + """ + Convenience wrapper for `service.load(...)` using an injected StateStoreService. + """ + return service.load( + key=key, + default=default, + state_metadata=state_metadata, + return_model=return_model, + ) + + +def load_state_with_etag( + service: StateStoreService, + key: str, + *, + default: Optional[Dict[str, Any]] = None, + state_metadata: Optional[Dict[str, str]] = None, + return_model: bool = False, +) -> Tuple[Union[Dict[str, Any], BaseModel], Optional[str]]: + """ + Convenience wrapper for `service.load_with_etag(...)`. + """ + return service.load_with_etag( + key=key, + default=default, + state_metadata=state_metadata, + return_model=return_model, + ) + + +def load_state_many( + service: StateStoreService, + keys: Sequence[str], + *, + parallelism: int = 1, + state_metadata: Optional[Dict[str, str]] = None, + return_model: bool = False, +) -> Dict[str, Union[Dict[str, Any], BaseModel]]: + """ + Convenience wrapper for `service.load_many(...)`. + """ + return service.load_many( + keys, + parallelism=parallelism, + state_metadata=state_metadata, + return_model=return_model, + ) + + +def save_state_dict( + service: StateStoreService, + key: str, + value: Any, + *, + etag: Optional[str] = None, + state_metadata: Optional[Dict[str, str]] = None, + state_options: Optional[Dict[str, Any]] = None, + ttl_in_seconds: Optional[int] = None, +) -> None: + """ + Convenience wrapper for `service.save(...)`. + """ + service.save( + key=key, + value=value, + etag=etag, + state_metadata=state_metadata, + state_options=state_options, + ttl_in_seconds=ttl_in_seconds, + ) + + +def save_state_many( + service: StateStoreService, + items: Dict[str, Any], + *, + state_metadata: Optional[Dict[str, str]] = None, + state_options: Optional[Dict[str, Any]] = None, +) -> None: + """ + Convenience wrapper for `service.save_many(...)`. + """ + service.save_many( + items, + state_metadata=state_metadata, + state_options=state_options, + ) + + +def delete_state( + service: StateStoreService, + key: str, + *, + etag: Optional[str] = None, + state_metadata: Optional[Dict[str, str]] = None, + state_options: Optional[Dict[str, Any]] = None, +) -> None: + """ + Convenience wrapper for `service.delete(...)`. + """ + service.delete( + key=key, + etag=etag, + state_metadata=state_metadata, + state_options=state_options, + ) + + +def state_exists( + service: StateStoreService, + key: str, +) -> bool: + """ + Convenience wrapper for `service.exists(...)`. + """ + return service.exists(key=key) + + +def execute_state_transaction( + service: StateStoreService, + operations: Sequence[Dict[str, Any]], + *, + metadata: Optional[Dict[str, str]] = None, +) -> None: + """ + Convenience wrapper for `service.execute_transaction(...)`. + """ + service.execute_transaction(operations, metadata=metadata) + + +__all__ = [ + "StateStoreError", + "StateStoreService", + "DaprStateStore", + "_coerce_state_options", + "load_state_dict", + "load_state_with_etag", + "load_state_many", + "save_state_dict", + "save_state_many", + "delete_state", + "state_exists", + "execute_state_transaction", +] diff --git a/dapr_agents/storage/daprstores/statestore.py b/dapr_agents/storage/daprstores/statestore.py index d2d41e68..6705248f 100644 --- a/dapr_agents/storage/daprstores/statestore.py +++ b/dapr_agents/storage/daprstores/statestore.py @@ -1,151 +1,237 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from dapr.clients import DaprClient from dapr.clients.grpc._response import ( - BulkStatesResponse, BulkStateItem, - StateResponse, + BulkStatesResponse, QueryResponse, + StateResponse, ) -from dapr.clients import DaprClient -from dapr.clients.grpc._state import StateItem +from dapr.clients.grpc._state import StateItem, StateOptions + from dapr_agents.storage.daprstores.base import DaprStoreBase -from typing import Optional, Union, Dict, List, Tuple + + +def _coerce_state_options( + state_options: Optional[Union[StateOptions, Dict[str, Any]]], +) -> Optional[StateOptions]: + """ + Convert a dict of state options into a `StateOptions` instance, or pass + through an existing `StateOptions`. + + Args: + state_options: None, a dict matching `StateOptions` fields, or a `StateOptions`. + + Returns: + A `StateOptions` instance or None. + """ + if state_options is None or isinstance(state_options, StateOptions): + return state_options + return StateOptions(**state_options) class DaprStateStore(DaprStoreBase): + """ + Thin wrapper around Dapr state APIs returning raw gRPC response types. + + This class intentionally avoids JSON coercion, validation, prefixing, retries, + and mirroring. If you want those conveniences, use `StateStoreService`. + """ + def get_state( self, key: str, - state_metadata: Optional[Dict[str, str]] = dict(), + *, + state_metadata: Optional[Dict[str, str]] = None, ) -> StateResponse: """ - Retrieves a value from the state store using the provided key. + Retrieve a single state item. Args: - key (str): The key for the state store item. - state_metadata (Dict[str, str], optional): Dapr metadata for state request + key: Key to fetch (as stored in the Dapr component). + state_metadata: Optional Dapr metadata for the request. Returns: - StateResponse: gRPC metadata returned from callee and value obtained from the state store + `StateResponse` containing bytes payload, etag, and metadata. """ with DaprClient() as client: - response: StateResponse = client.get_state( - store_name=self.store_name, key=key, state_metadata=state_metadata + return client.get_state( + store_name=self.store_name, + key=key, + state_metadata=state_metadata, ) - return response def try_get_state( - self, key: str, state_metadata: Optional[Dict[str, str]] = dict() + self, + key: str, + *, + state_metadata: Optional[Dict[str, str]] = None, ) -> Tuple[bool, Optional[dict]]: """ - Attempts to retrieve a value from the state store using the provided key. + Attempt to get a JSON-encoded state item and decode it into a dict. Args: - key (str): The key for the state store item. - state_metadata (Dict[str, str], optional): Dapr metadata for state request. + key: Key to fetch. + state_metadata: Optional Dapr metadata. Returns: - Tuple[bool, Optional[dict]]: A tuple where the first element is a boolean indicating whether the state exists, - and the second element is the retrieved state data or None if not found. + (exists, payload_dict_or_none) """ - with DaprClient() as client: - response: StateResponse = client.get_state( - store_name=self.store_name, key=key, state_metadata=state_metadata - ) - if response and response.data: - return True, response.json() - return False, None + response = self.get_state( + key=key, + state_metadata=state_metadata, + ) + if response and response.data: + return True, response.json() + return False, None def get_bulk_state( self, keys: List[str], + *, parallelism: int = 1, states_metadata: Optional[Dict[str, str]] = None, ) -> List[BulkStateItem]: """ - Retrieves multiple values from the state store in bulk using a list of keys. + Retrieve multiple keys in one call. Args: - keys (List[str]): The keys to retrieve in bulk. - parallelism (int, optional): Number of keys to retrieve in parallel. - states_metadata (Dict[str, str], optional): Metadata for state request. + keys: Keys to fetch. + parallelism: How many to fetch in parallel (backend dependent). + states_metadata: Optional bulk metadata. Returns: - List[BulkStateItem]: A list of BulkStateItem objects representing the retrieved state. + List of `BulkStateItem`. Items with missing keys may have empty data. """ - states_metadata = states_metadata or {} - with DaprClient() as client: response: BulkStatesResponse = client.get_bulk_state( store_name=self.store_name, keys=keys, parallelism=parallelism, - states_metadata=states_metadata, + states_metadata=states_metadata or {}, ) - - if response and response.items: - return response.items - return [] + return response.items or [] def save_state( self, key: str, value: Union[str, bytes], - state_metadata: Optional[Dict[str, str]] = dict(), - ): + *, + state_metadata: Optional[Dict[str, str]] = None, + etag: Optional[str] = None, + state_options: Optional[Union[StateOptions, Dict[str, Any]]] = None, + ) -> None: """ - Saves a key-value pair in the state store. + Save a single key with raw bytes/str value. Args: - key (str): The key to save. - value (Union[str, bytes]): The value to save. - state_metadata (Dict[str, str], optional): Dapr metadata for state request + key: Key to write. + value: Bytes or string payload (caller handles JSON if desired). + state_metadata: Optional Dapr metadata. + etag: Optional ETag for concurrency. + state_options: `StateOptions` or dict fields for options. """ + options = _coerce_state_options(state_options) with DaprClient() as client: client.save_state( store_name=self.store_name, key=key, value=value, state_metadata=state_metadata, + etag=etag, + options=options, ) def save_bulk_state( - self, states: List[StateItem], metadata: Optional[Dict[str, str]] = None + self, + states: List[StateItem], + metadata: Optional[Dict[str, str]] = None, ) -> None: """ - Saves multiple key-value pairs to the state store in bulk. + Save multiple `StateItem`s. Caller constructs `StateItem` objects. Args: - states (List[StateItem]): The list of key-value pairs to save. - metadata (Dict[str, str], optional): Metadata for the save request. + states: List of StateItem to write. + metadata: Optional request metadata. """ with DaprClient() as client: client.save_bulk_state( - store_name=self.store_name, states=states, metadata=metadata + store_name=self.store_name, + states=states, + metadata=metadata, ) - def delete_state(self, key: str): + def delete_state( + self, + key: str, + *, + etag: Optional[str] = None, + state_options: Optional[Union[StateOptions, Dict[str, Any]]] = None, + state_metadata: Optional[Dict[str, str]] = None, + ) -> None: """ - Deletes a key-value pair from the state store. + Delete a single key. Args: - key (str): The key to delete. + key: Key to delete. + etag: Optional ETag for concurrency. + state_options: `StateOptions` or dict of options. + state_metadata: Optional Dapr metadata. """ + options = _coerce_state_options(state_options) with DaprClient() as client: - client.delete_state(store_name=self.store_name, key=key) + client.delete_state( + store_name=self.store_name, + key=key, + etag=etag, + state_options=options, + state_metadata=state_metadata, + ) def query_state( - self, query: str, states_metadata: Optional[Dict[str, str]] = None + self, + query: str, + *, + states_metadata: Optional[Dict[str, str]] = None, ) -> QueryResponse: """ - Queries the state store with a specific query. + Execute a state query (backend must support Dapr state queries). Args: - query (str): The query to be executed (in JSON format). - states_metadata (Dict[str, str], optional): Custom metadata for the state request. + query: JSON query string. + states_metadata: Optional Dapr metadata. Returns: - QueryResponse: Contains query results and metadata. + `QueryResponse` containing results and metadata. """ with DaprClient() as client: - client.query_state( - store_name=self.store_name, query=query, states_metadata=states_metadata + return client.query_state( + store_name=self.store_name, + query=query, + states_metadata=states_metadata, + ) + + def execute_state_transaction( + self, + operations: Sequence[Dict[str, Any]], + *, + metadata: Optional[Dict[str, str]] = None, + ) -> None: + """ + Execute a transactional batch of operations. + + Args: + operations: Dapr transaction operations (upserts/deletes). + metadata: Optional request metadata. + + Note: + Backend must support transactions (e.g., Redis in certain modes). + """ + with DaprClient() as client: + client.execute_state_transaction( + store_name=self.store_name, + operations=list(operations), + metadata=metadata, ) diff --git a/dapr_agents/types/message.py b/dapr_agents/types/message.py index 9bc4ad3f..eac34892 100644 --- a/dapr_agents/types/message.py +++ b/dapr_agents/types/message.py @@ -279,8 +279,7 @@ class LLMChatCandidate(BaseModel): message: AssistantMessage finish_reason: Optional[str] = None - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class LLMChatResponse(BaseModel): diff --git a/dapr_agents/types/workflow.py b/dapr_agents/types/workflow.py index aeeceb9d..ac3eff46 100644 --- a/dapr_agents/types/workflow.py +++ b/dapr_agents/types/workflow.py @@ -1,4 +1,6 @@ from enum import Enum +from dataclasses import dataclass +from typing import Any, Callable, List, Optional, Type class DaprWorkflowStatus(str, Enum): @@ -11,3 +13,51 @@ class DaprWorkflowStatus(str, Enum): TERMINATED = "terminated" # Workflow was canceled or forcefully terminated SUSPENDED = "suspended" # Workflow was temporarily paused PENDING = "pending" # Workflow is waiting to start + + +@dataclass +class PubSubRouteSpec: + """ + Pub/sub subscription that schedules a workflow when a message arrives. + + Attributes: + pubsub_name: Dapr pub/sub component name. + topic: Topic to subscribe to. + handler_fn: Bound workflow callable to run (method or function). + message_model: Optional schema (Pydantic/dataclass/dict). If omitted and + `handler_fn` is decorated with `@message_router`, the decorator's + first schema is used; otherwise `dict`. + dead_letter_topic: Optional DLQ topic name. + """ + + pubsub_name: str + topic: str + handler_fn: Callable[..., Any] + message_model: Optional[Type[Any]] = None + dead_letter_topic: Optional[str] = None + + +@dataclass +class HttpRouteSpec: + """ + HTTP endpoint that schedules a workflow when a request arrives. + + Attributes: + path: FastAPI path to mount (e.g., "/blog/start"). + handler_fn: Bound workflow callable to run (method or function). + method: HTTP method (default: POST). + request_model: Optional schema for request validation. If omitted and + `handler_fn` is decorated with `@http_router`, the decorator's first + schema is used; otherwise `dict`. + summary: Optional OpenAPI summary. + tags: Optional OpenAPI tags. + response_model: Optional Pydantic response model for docs. + """ + + path: str + handler_fn: Callable[..., Any] + method: str = "POST" + request_model: Optional[Type[Any]] = None + summary: Optional[str] = None + tags: Optional[List[str]] = None + response_model: Optional[Type[Any]] = None diff --git a/dapr_agents/utils/signal/mixin.py b/dapr_agents/utils/signal/mixin.py new file mode 100644 index 00000000..a780b442 --- /dev/null +++ b/dapr_agents/utils/signal/mixin.py @@ -0,0 +1,303 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import Callable, Optional + +from .tools import install_signal_handlers, uninstall_signal_handlers + +logger = logging.getLogger(__name__) + + +class SignalMixin: + """ + Mixin providing graceful shutdown via Ctrl+C (SIGINT) / SIGTERM. + + Features: + - `install_signal_handlers()` / `remove_signal_handlers()` to manage OS handlers. + - An internal `asyncio.Event` waited by `wait_for_shutdown()`. + - `request_shutdown()` to trigger the same path programmatically (thread-safe). + - Overridable async `graceful_shutdown()` hook. + - Schedules cleanup (never awaits) from the actual signal context. + - Safe to reuse the same instance across starts/stops in tests. + + Typical usage:: + + class Service(SignalMixin): + async def start(self): + self.install_signal_handlers() + # ... start tasks/resources ... + await self.wait_for_shutdown() + await self.graceful_shutdown() + self.remove_signal_handlers() + + async def graceful_shutdown(self): + # ... stop tasks/resources, close clients ... + ... + """ + + def __init__(self, *args, **kwargs) -> None: + """ + Initialize mixin state. + + Supports multiple inheritance (delegates to super()). + """ + super().__init__(*args, **kwargs) # type: ignore[misc] + self._shutdown_event: Optional[asyncio.Event] = None + self._cleanup_handlers: Optional[Callable[[], None]] = None + self._shutdown_task_scheduled: bool = False + self._signal_loop: Optional[ + asyncio.AbstractEventLoop + ] = None # loop used for handler install + self._last_signal: Optional[ + int + ] = None # last observed signal number (-1 if programmatic) + + # ------------------- Public API ------------------- + + def install_signal_handlers(self) -> None: + """ + Install OS signal handlers and initialize (or reinitialize) the shutdown event. + + Idempotent: safe to call multiple times. + + Also resets internal scheduling flags so the instance can be restarted. + + Returns: + None + + Raises: + RuntimeError: If no event loop can be obtained for the current thread. + """ + # (Re)create a fresh event if first install OR previous event is already set. + if self._shutdown_event is None or self._shutdown_event.is_set(): + self._shutdown_event = asyncio.Event() + + # Reset scheduling gate so new shutdowns can be scheduled after restart. + self._shutdown_task_scheduled = False + + # If handlers are already installed and we have a loop, nothing to do. + if self._cleanup_handlers is not None and self._signal_loop is not None: + return + + # Capture the loop we will always bounce into (thread-safe). + try: + loop = asyncio.get_running_loop() + except RuntimeError: + try: + loop = asyncio.get_event_loop() + except RuntimeError as exc: + raise RuntimeError( + "No asyncio event loop available to install signal handlers." + ) from exc + + self._signal_loop = loop + self._cleanup_handlers = install_signal_handlers(loop, self._on_signal) + + def remove_signal_handlers(self) -> None: + """ + Remove/uninstall previously installed signal handlers. + + Safe to call multiple times; no-op if nothing is installed. + + Returns: + None + """ + uninstall_signal_handlers(self._cleanup_handlers) + self._cleanup_handlers = None + # Keep _signal_loop for request_shutdown() bounce to succeed even after uninstall. + # You can clear it when disposing the object if desired. + + async def wait_for_shutdown(self) -> None: + """ + Block until a shutdown is requested (signal or programmatic). + + Returns: + None + + Raises: + RuntimeError: If handlers/event have not been installed first. + """ + if self._shutdown_event is None: + raise RuntimeError( + "Call install_signal_handlers() before wait_for_shutdown()." + ) + await self._shutdown_event.wait() + + def is_shutdown_requested(self) -> bool: + """ + Indicate whether shutdown has been requested. + + Returns: + bool: True if a shutdown signal has been received or requested programmatically. + """ + return bool(self._shutdown_event and self._shutdown_event.is_set()) + + def request_shutdown(self) -> None: + """ + Programmatically request shutdown (thread-safe, same path as OS signal). + + If a loop was captured during installation, this always uses + `loop.call_soon_threadsafe(...)` so callers from foreign threads + behave like real signals. + + Returns: + None + """ + if self._signal_loop is not None: + try: + self._signal_loop.call_soon_threadsafe(self._on_signal, -1) + return + except Exception: + # Fall through to direct call if loop is already closed. + pass + self._on_signal(sig=-1) + + @property + def last_signal(self) -> Optional[int]: + """ + The last OS signal observed (or -1 if programmatic), else None. + + Returns: + Optional[int]: Last observed signal. + """ + return self._last_signal + + # ------------------- Hooks ------------------- + + async def graceful_shutdown(self) -> None: + """ + Override to perform async cleanup (close clients, cancel tasks, flush logs). + + Default implementation: + If the class provides a `stop()` attribute callable, it is invoked. + If `stop()` returns a coroutine, it is awaited. + + Returns: + None + """ + stop = getattr(self, "stop", None) + if callable(stop): + maybe_coro = stop() + if asyncio.iscoroutine(maybe_coro): + await maybe_coro # type: ignore[func-returns-value] + + # ------------------- Internals ------------------- + + def _schedule_graceful_shutdown(self) -> None: + """ + Schedule `graceful_shutdown()` on the captured loop exactly once, + and reset the scheduling gate when it completes. + + This function is resilient to a stopped/closed loop: + - If `call_soon_threadsafe` fails, the scheduling flag is reset. + - We then try a best-effort fallback on the currently running loop (if any). + """ + loop = self._signal_loop + if loop is None: + # Best-effort: try current running loop, then thread's loop. + try: + loop = asyncio.get_running_loop() + except RuntimeError: + try: + loop = asyncio.get_event_loop() + except RuntimeError: + logger.debug( + "No event loop available to schedule graceful shutdown." + ) + return + + def _mark_done(task: asyncio.Task) -> None: + """ + Done-callback for the graceful shutdown task. + + - Resets the scheduling gate to allow reuse of the instance. + - Surfaces exceptions from the shutdown task into logs. + - Calls optional `on_shutdown_complete()` hook (non-async, best-effort). + """ + try: + exc = task.exception() + if exc is not None: + logger.exception("graceful_shutdown() raised", exc_info=exc) + hook = getattr(self, "on_shutdown_complete", None) + if callable(hook): + try: + hook() + except Exception: + logger.debug("on_shutdown_complete hook raised", exc_info=True) + except asyncio.CancelledError: + logger.debug("graceful_shutdown() task was cancelled", exc_info=True) + finally: + # Allow reuse of the instance in tests / restarts. + self._shutdown_task_scheduled = False + + # Try to bounce into the captured loop thread. + try: + + def _spawn() -> None: + t = loop.create_task(self.graceful_shutdown()) + t.add_done_callback(_mark_done) + + loop.call_soon_threadsafe(_spawn) + return + except Exception: + # If scheduling on the captured loop failed, reset the gate so callers can retry. + logger.debug( + "Failed to schedule via call_soon_threadsafe; attempting fallback.", + exc_info=True, + ) + self._shutdown_task_scheduled = False + + # Fallback: try the *current* running loop (e.g., called from a different alive loop). + try: + current = asyncio.get_running_loop() + except RuntimeError: + # No running loop; nothing else we can do. Leave the gate open for future attempts. + return + + try: + # Re-set the gate since we're attempting a new schedule path. + self._shutdown_task_scheduled = True + task = current.create_task(self.graceful_shutdown()) + task.add_done_callback(_mark_done) + except Exception: + # Scheduling failed again; open the gate so a later attempt can retry. + logger.debug("Fallback scheduling on current loop failed.", exc_info=True) + self._shutdown_task_scheduled = False + + def _on_signal(self, sig: int) -> None: + """ + Internal signal callback. Do not await here. + + Sets the shutdown event and schedules `graceful_shutdown()` (once). + Records the signal and calls an optional non-blocking hook. + + Args: + sig: The received signal number (or -1 if triggered programmatically). + + Returns: + None + """ + if self._shutdown_event is None: + return + + # Record & optional hook for observability + self._last_signal = sig + hook = getattr(self, "on_signal_received", None) + if callable(hook): + try: + hook(sig) # tiny, non-blocking hook; do not await here + except Exception: + logger.debug("on_signal_received hook raised", exc_info=True) + + if not self._shutdown_event.is_set(): + self._shutdown_event.set() + + if self._shutdown_task_scheduled: + return + self._shutdown_task_scheduled = True + + logger.debug( + "Shutdown requested (signal=%s); scheduling graceful shutdown.", sig + ) + self._schedule_graceful_shutdown() diff --git a/dapr_agents/utils/signal/tools.py b/dapr_agents/utils/signal/tools.py new file mode 100644 index 00000000..99817e95 --- /dev/null +++ b/dapr_agents/utils/signal/tools.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import asyncio +import logging +import signal +import threading +from typing import Callable, Dict, Iterable, List, Optional + +logger = logging.getLogger(__name__) + +HandlerCleanup = Callable[[], None] + + +def _is_main_thread() -> bool: + """ + Return True if running in the main thread. + + Signal handlers must be registered in the main thread; most runtimes + will ignore or error otherwise. + """ + return threading.current_thread() is threading.main_thread() + + +def install_signal_handlers( + loop: asyncio.AbstractEventLoop, + on_signal: Callable[[int], None], + signals_to_handle: Iterable[signal.Signals] = (signal.SIGINT, signal.SIGTERM), +) -> HandlerCleanup: + """ + Install Ctrl+C (SIGINT) / SIGTERM handlers in a cross-platform way. + + On Unix: + - Prefer `loop.add_signal_handler` (non-blocking, runs in the loop thread). + - If unsupported (e.g., loop not running), fall back to `signal.signal` + and forward into the loop via `loop.call_soon_threadsafe`. + + On Windows: + - Use `signal.signal`, forwarding into the loop with `loop.call_soon_threadsafe`. + + Handlers are only installed in the main thread. + + Args: + loop: The asyncio event loop to schedule work onto after a signal arrives. + on_signal: Callback invoked with the signal number (int). This is called + in the loop thread (via `call_soon_threadsafe`) or directly as a + best-effort fallback if the loop is already closed. + signals_to_handle: Iterable of signals to install (default: SIGINT, SIGTERM). + + Returns: + A zero-arg callable that cleans up the installed handlers. Always safe to call. + + Raises: + ValueError: If `on_signal` is not callable. + """ + if not callable(on_signal): + raise ValueError("on_signal must be callable") + + if not _is_main_thread(): + logger.debug("Skipping signal handler installation (not in main thread).") + return lambda: None + + sigs: List[signal.Signals] = list(signals_to_handle) + previous: Dict[signal.Signals, object] = {} + added_via_loop = False + + def _dispatch(sig_num: int) -> None: + """Forward the signal to the loop thread safely.""" + try: + loop.call_soon_threadsafe(on_signal, sig_num) + except RuntimeError: + # Loop likely closed; fall back to direct call (best effort). + try: + on_signal(sig_num) + except Exception: + logger.exception("Error dispatching shutdown after loop close") + + # Try loop-integrated handlers first (Unix). May raise on Windows or early init. + try: + for s in sigs: + loop.add_signal_handler(s, _dispatch, s) + added_via_loop = True + logger.debug("Installed signal handlers via loop.add_signal_handler") + except (NotImplementedError, RuntimeError): + # Fallback: traditional handlers + bounce to loop + for s in sigs: + try: + previous[s] = signal.getsignal(s) + + def _handler(signum: int, _frame) -> None: # noqa: ANN001 + _dispatch(signum) + + signal.signal(s, _handler) + except Exception: + logger.exception("Failed to install handler for %s", s) + + logger.debug("Installed signal handlers via signal.signal fallback") + + def _cleanup() -> None: + """Remove loop handlers or restore previous traditional handlers.""" + if added_via_loop: + for s in sigs: + try: + loop.remove_signal_handler(s) + except Exception: + # Loop may already be closed; ignore during shutdown. + pass + return + + for s, prev in previous.items(): + try: + # Only pass prev to signal.signal if it's a valid handler type + if prev is None or callable(prev) or isinstance(prev, int): + signal.signal(s, prev) + else: + signal.signal(s, None) + except Exception: + # Avoid raising from cleanup paths. + pass + + return _cleanup + + +def uninstall_signal_handlers(cleanup: Optional[HandlerCleanup]) -> None: + """ + Execute the cleanup closure returned by `install_signal_handlers`, if any. + + Args: + cleanup: The cleanup function returned by `install_signal_handlers()`. + + Returns: + None + + Notes: + Safe to call multiple times and never raises. + """ + if cleanup is None: + return + try: + cleanup() + except Exception: + # Never raise during shutdown cleanup. + logger.debug("Error while uninstalling signal handlers", exc_info=True) diff --git a/dapr_agents/workflow/__init__.py b/dapr_agents/workflow/__init__.py index 088db3c1..6652594a 100644 --- a/dapr_agents/workflow/__init__.py +++ b/dapr_agents/workflow/__init__.py @@ -1,17 +1,8 @@ -from .agentic import AgenticWorkflow -from .base import WorkflowApp -from .decorators import message_router, task, workflow -from .orchestrators import LLMOrchestrator, RandomOrchestrator, RoundRobinOrchestrator -from .task import WorkflowTask +from .decorators import message_router, http_router, llm_activity, agent_activity __all__ = [ - "WorkflowApp", - "WorkflowTask", - "AgenticWorkflow", - "LLMOrchestrator", - "RandomOrchestrator", - "RoundRobinOrchestrator", - "workflow", - "task", "message_router", + "http_router", + "llm_activity", + "agent_activity", ] diff --git a/dapr_agents/workflow/agentic.py b/dapr_agents/workflow/agentic.py deleted file mode 100644 index 49b1974b..00000000 --- a/dapr_agents/workflow/agentic.py +++ /dev/null @@ -1,412 +0,0 @@ -import asyncio -import json -import logging -import time -from typing import Any, Callable, Dict, Optional, Tuple, Type, List - -from cloudevents.http.conversion import from_http -from cloudevents.http.event import CloudEvent -from dapr.clients import DaprClient -from dapr.clients.grpc._request import ( - TransactionalStateOperation, - TransactionOperationType, -) -from dapr.clients.grpc._response import StateResponse -from dapr.clients.grpc._state import Concurrency, Consistency, StateOptions -from fastapi import Request, status -from fastapi.responses import JSONResponse -from pydantic import BaseModel, Field, PrivateAttr - -from dapr_agents.agents.utils.text_printer import ColorTextFormatter -from dapr_agents.memory import ( - ConversationListMemory, - ConversationVectorMemory, - MemoryBase, -) -from dapr_agents.storage.daprstores.statestore import DaprStateStore -from dapr_agents.workflow.base import WorkflowApp -from dapr_agents.workflow.mixins import ( - MessagingMixin, - PubSubMixin, - ServiceMixin, - StateManagementMixin, -) - -logger = logging.getLogger(__name__) - - -class AgenticWorkflow( - WorkflowApp, - PubSubMixin, - StateManagementMixin, - ServiceMixin, - MessagingMixin, -): - """ - A class for managing agentic workflows, extending `WorkflowApp`. - Handles agent interactions, workflow execution, messaging, and metadata management. - """ - - name: str = Field(..., description="The name of the agentic system.") - message_bus_name: str = Field( - ..., - description="The name of the message bus component, defining the pub/sub base.", - ) - broadcast_topic_name: Optional[str] = Field( - default=None, - description="Default topic for broadcasting messages. Set explicitly for multi-agent setups.", - ) - state_store_name: str = Field( - ..., description="Dapr state store for workflow state." - ) - state_key: str = Field( - default="workflow_state", - description="Dapr state key for workflow state storage.", - ) - state: dict = Field( - default_factory=dict, description="Current state of the workflow." - ) - state_format: Optional[Type[BaseModel]] = Field( - default=None, - description=( - "Optional Pydantic model used to validate the persisted workflow " - "state. If provided, state loaded from storage is coerced to this " - "schema." - ), - ) - agents_registry_store_name: str = Field( - ..., description="Dapr state store for agent metadata." - ) - agents_registry_key: str = Field( - default="agents_registry", description="Key for agents registry in state store." - ) - # TODO: test this is respected by runtime. - max_iterations: int = Field( - default=10, description="Maximum iterations for workflows.", ge=1 - ) - memory: MemoryBase = Field( - default_factory=ConversationListMemory, - description="Handles conversation history storage.", - ) - save_state_locally: bool = Field( - default=True, description="Whether to save workflow state locally." - ) - local_state_path: Optional[str] = Field( - default=None, description="Local path for saving state files." - ) - client: Optional[DaprClient] = Field( - default=None, init=False, description="Dapr client instance." - ) - - # Private internal attributes (not schema/validated) - _state_store_client: Optional[DaprStateStore] = PrivateAttr(default=None) - _text_formatter: ColorTextFormatter = PrivateAttr(default=ColorTextFormatter) - _agent_metadata: Optional[Dict[str, Any]] = PrivateAttr(default=None) - _workflow_name: str = PrivateAttr(default=None) - _dapr_client: Optional[DaprClient] = PrivateAttr(default=None) - _is_running: bool = PrivateAttr(default=False) - _shutdown_event: asyncio.Event = PrivateAttr(default_factory=asyncio.Event) - _http_server: Optional[Any] = PrivateAttr(default=None) - _subscriptions: Dict[str, Callable] = PrivateAttr(default_factory=dict) - _topic_handlers: Dict[ - Tuple[str, str], Dict[Type[BaseModel], Callable] - ] = PrivateAttr(default_factory=dict) - - def model_post_init(self, __context: Any) -> None: - """ - Post-initialization hook for the AgenticWorkflow. - - This method initializes the workflow service, messaging, and metadata storage. - It sets up the color formatter, state store client, loads or initializes the workflow state, - and creates a Dapr client for service-to-service calls or state interactions. - - Args: - __context (Any): The context passed by Pydantic's model initialization. - - Raises: - RuntimeError: If Dapr is not available in the current environment. - """ - self._dapr_client = DaprClient() - self._text_formatter = ColorTextFormatter() - self._state_store_client = DaprStateStore(store_name=self.state_store_name) - logger.info(f"State store '{self.state_store_name}' initialized.") - self.initialize_state() - super().model_post_init(__context) - - def get_chat_history(self, task: Optional[str] = None) -> List[Dict[str, Any]]: - """ - Retrieves the chat history from memory as a list of dictionaries. - - Args: - task (Optional[str]): The task or query provided by the user (used for vector search). - - Returns: - List[Dict[str, Any]]: The chat history as dictionaries. - """ - if isinstance(self.memory, ConversationVectorMemory) and task: - if ( - hasattr(self.memory.vector_store, "embedding_function") - and self.memory.vector_store.embedding_function - and hasattr( - self.memory.vector_store.embedding_function, "embed_documents" - ) - ): - query_embeddings = self.memory.vector_store.embedding_function.embed( - task - ) - messages = self.memory.get_messages(query_embeddings=query_embeddings) - else: - messages = self.memory.get_messages() - else: - messages = self.memory.get_messages() - return messages - - @property - def chat_history(self) -> List[Dict[str, Any]]: - """ - Returns the full chat history as a list of dictionaries. - - Returns: - List[Dict[str, Any]]: The chat history. - """ - return self.get_chat_history() - - def get_data_from_store(self, store_name: str, key: str) -> Optional[dict]: - """ - Retrieves data from the Dapr state store using the given key. - - Args: - store_name (str): The name of the Dapr state store component. - key (str): The key to fetch data from. - - Returns: - Optional[dict]: the retrieved dictionary or None if not found. - """ - try: - response: StateResponse = self._dapr_client.get_state( - store_name=store_name, key=key - ) - data = response.data - return json.loads(data) if data else None - except Exception: - logger.warning( - f"Error retrieving data for key '{key}' from store '{store_name}'" - ) - return None - - def get_agents_metadata( - self, exclude_self: bool = True, exclude_orchestrator: bool = False - ) -> Dict[str, Any]: - """ - Retrieve metadata for all registered agents. - - Args: - exclude_self: If True, excludes the current agent from results. - exclude_orchestrator: If True, excludes orchestrators from results. - - Returns: - Dict[str, Any]: Mapping of agent names to their metadata. - - Raises: - RuntimeError: If retrieval fails. - """ - try: - agents_metadata = ( - self.get_data_from_store( - self.agents_registry_store_name, self.agents_registry_key - ) - or {} - ) - - if agents_metadata: - logger.info( - f"Agents found in '{self.agents_registry_store_name}' for key '{self.agents_registry_key}'." - ) - filtered = { - name: metadata - for name, metadata in agents_metadata.items() - if not (exclude_self and name == self.name) - and not ( - exclude_orchestrator and metadata.get("orchestrator", False) - ) - } - if not filtered: - logger.info("No other agents found after filtering.") - return filtered - - logger.info( - f"No agents found in '{self.agents_registry_store_name}' for key '{self.agents_registry_key}'." - ) - return {} - except Exception as e: - logger.error(f"Failed to retrieve agents metadata: {e}", exc_info=True) - raise RuntimeError(f"Error retrieving agents metadata: {str(e)}") from e - - def print_interaction( - self, sender_agent_name: str, recipient_agent_name: str, message: str - ) -> None: - """ - Pretty-print an interaction between two agents. - - Args: - sender_agent_name: The name of the agent sending the message. - recipient_agent_name: The name of the agent receiving the message. - message: The message content to display. - """ - separator = "-" * 80 - interaction_text = [ - (sender_agent_name, "dapr_agents_mustard"), - (" -> ", "dapr_agents_teal"), - (f"{recipient_agent_name}\n\n", "dapr_agents_mustard"), - (message + "\n\n", None), - (separator + "\n", "dapr_agents_teal"), - ] - self._text_formatter.print_colored_text(interaction_text) - - def register_agent( - self, store_name: str, store_key: str, agent_name: str, agent_metadata: dict - ) -> None: - """ - Merges the existing data with the new data and updates the store. - - Args: - store_name (str): The name of the Dapr state store component. - key (str): The key to update. - data (dict): The data to update the store with. - """ - # retry the entire operation up to twenty times sleeping 1-2 seconds between each - # TODO: rm the custom retry logic here and use the DaprClient retry_policy instead. - for attempt in range(1, 21): - try: - response: StateResponse = self._dapr_client.get_state( - store_name=store_name, key=store_key - ) - if not response.etag: - # if there is no etag the following transaction won't work as expected - # so we need to save an empty object with a strong consistency to force the etag to be created - self._dapr_client.save_state( - store_name=store_name, - key=store_key, - value=json.dumps({}), - state_metadata={ - "contentType": "application/json", - "partitionKey": store_key, - }, - options=StateOptions( - concurrency=Concurrency.first_write, - consistency=Consistency.strong, - ), - ) - # raise an exception to retry the entire operation - raise Exception(f"No etag found for key: {store_key}") - existing_data = json.loads(response.data) if response.data else {} - if (agent_name, agent_metadata) in existing_data.items(): - logger.debug(f"agent {agent_name} already registered.") - return None - agent_data = {agent_name: agent_metadata} - merged_data = {**existing_data, **agent_data} - logger.debug(f"merged data: {merged_data} etag: {response.etag}") - try: - # using the transactional API to be able to later support the Dapr outbox pattern - self._dapr_client.execute_state_transaction( - store_name=store_name, - operations=[ - TransactionalStateOperation( - key=store_key, - data=json.dumps(merged_data), - etag=response.etag, - operation_type=TransactionOperationType.upsert, - ) - ], - transactional_metadata={ - "contentType": "application/json", - "partitionKey": store_key, - }, - ) - except Exception as e: - raise e - return None - except Exception as e: - logger.error(f"Error on transaction attempt: {attempt}: {e}") - # Add random jitter - import random - - delay = 1 + random.uniform(0, 1) # 1-2 seconds - logger.info( - f"Sleeping for {delay:.2f} seconds before retrying transaction..." - ) - time.sleep(delay) - raise Exception( - f"Failed to update state store key: {store_key} after 20 attempts." - ) - - def register_agentic_system(self) -> None: - """ - Register this agent's metadata in the Dapr state store. - - Raises: - Exception: If registration fails. - """ - try: - self.register_agent( - store_name=self.agents_registry_store_name, - store_key=self.agents_registry_key, - agent_name=self.name, - agent_metadata=self._agent_metadata, - ) - except Exception as e: - logger.error(f"Failed to register metadata for agent {self.name}: {e}") - raise e - - async def run_workflow_from_request(self, request: Request) -> JSONResponse: - """ - Run a workflow instance triggered by an HTTP POST request. - - Args: - request: The incoming FastAPI request. - - Returns: - JSONResponse: HTTP response with workflow instance ID or error. - """ - try: - workflow_name = request.query_params.get("name") or self._workflow_name - if not workflow_name: - return JSONResponse( - content={"error": "No workflow name specified."}, - status_code=status.HTTP_400_BAD_REQUEST, - ) - - if workflow_name not in self.workflows: - return JSONResponse( - content={ - "error": f"Unknown workflow '{workflow_name}'. Available: {list(self.workflows.keys())}" - }, - status_code=status.HTTP_400_BAD_REQUEST, - ) - - try: - event: CloudEvent = from_http( - dict(request.headers), await request.body() - ) - input_data = event.data - except Exception: - input_data = await request.json() - - logger.info(f"Starting workflow '{workflow_name}' with input: {input_data}") - instance_id = await self.run_and_monitor_workflow_async( - workflow=workflow_name, input=input_data - ) - - return JSONResponse( - content={ - "message": "Workflow initiated successfully.", - "workflow_instance_id": instance_id, - }, - status_code=status.HTTP_202_ACCEPTED, - ) - except Exception as e: - logger.error(f"Error starting workflow: {str(e)}", exc_info=True) - return JSONResponse( - content={"error": "Failed to start workflow", "details": str(e)}, - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) diff --git a/dapr_agents/workflow/base.py b/dapr_agents/workflow/base.py deleted file mode 100644 index f71feae5..00000000 --- a/dapr_agents/workflow/base.py +++ /dev/null @@ -1,1387 +0,0 @@ -import asyncio -import functools -import inspect -import json -import logging -import time -import sys -import uuid -from datetime import datetime, timezone -from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, Sequence - -from dapr.ext.workflow import ( - DaprWorkflowClient, - WorkflowActivityContext, - WorkflowRuntime, -) -from dapr.ext.workflow.workflow_state import WorkflowState -from durabletask import task as dtask -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from dapr_agents.agents.base import ChatClientBase -from dapr_agents.llm.utils.defaults import get_default_llm -from dapr_agents.types.workflow import DaprWorkflowStatus -from dapr_agents.utils import SignalHandlingMixin -from dapr_agents.workflow.task import WorkflowTask -from dapr_agents.workflow.utils.core import get_decorated_methods, is_pydantic_model - -logger = logging.getLogger(__name__) - -T = TypeVar("T") - - -class WorkflowApp(BaseModel, SignalHandlingMixin): - """ - A Pydantic-based class to encapsulate a Dapr Workflow runtime and manage workflows and tasks. - """ - - # NOTE: Workflow instrumentation is applied directly during instrumentor initialization - - llm: Optional[ChatClientBase] = Field( - default=None, - description="The default LLM client for tasks that explicitly require an LLM but don't specify one (optional).", - ) - # TODO: I think this should be within the wf client or wf runtime...? - timeout: int = Field( - default=300, - description="Default timeout duration in seconds for workflow tasks.", - ) - grpc_max_send_message_length: Optional[int] = Field( - default=None, - description="Maximum message length in bytes for gRPC send operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).", - ) - grpc_max_receive_message_length: Optional[int] = Field( - default=None, - description="Maximum message length in bytes for gRPC receive operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).", - ) - - # Initialized in model_post_init - wf_runtime: Optional[WorkflowRuntime] = Field( - default=None, init=False, description="Workflow runtime instance." - ) - wf_runtime_is_running: Optional[bool] = Field( - default=None, init=False, description="Is the Workflow runtime running?" - ) - wf_client: Optional[DaprWorkflowClient] = Field( - default=None, init=False, description="Workflow client instance." - ) - tasks: Dict[str, Callable] = Field( - default_factory=dict, init=False, description="Dictionary of registered tasks." - ) - workflows: Dict[str, Callable] = Field( - default_factory=dict, - init=False, - description="Dictionary of registered workflows.", - ) - - model_config = ConfigDict(arbitrary_types_allowed=True) - - @model_validator(mode="before") - def validate_grpc_chanell_options(cls, values: Any): - if not isinstance(values, dict): - return values - - if values.get("grpc_max_send_message_length") is not None: - if values["grpc_max_send_message_length"] < 0: - raise ValueError("grpc_max_send_message_length must be greater than 0") - - if values.get("grpc_max_receive_message_length") is not None: - if values["grpc_max_receive_message_length"] < 0: - raise ValueError( - "grpc_max_receive_message_length must be greater than 0" - ) - - return values - - def model_post_init(self, __context: Any) -> None: - """ - Initialize the Dapr workflow runtime and register tasks & workflows. - """ - if self.grpc_max_send_message_length or self.grpc_max_receive_message_length: - self._configure_grpc_channel_options() - - # Initialize LLM first - if self.llm is None: - self.llm = get_default_llm() - - # Initialize clients and runtime - self.wf_runtime = WorkflowRuntime() - self.wf_runtime_is_running = False - self.wf_client = DaprWorkflowClient() - logger.info("WorkflowApp initialized; discovering tasks and workflows.") - - self.start_runtime() - - # Set up automatic signal handlers for graceful shutdown - try: - self.setup_signal_handlers() - except Exception as e: - logger.warning(f"Could not set up signal handlers: {e}") - - super().model_post_init(__context) - - def _configure_grpc_channel_options(self) -> None: - """ - Configure gRPC channel options before workflow runtime initialization. - This patches the durabletask internal channel factory to support custom message size limits. - - This is particularly useful for AI-powered workflows that may need to handle large payloads - such as images, which can exceed the default 4MB gRPC message size limit. - """ - try: - import grpc - from durabletask.internal import shared - - # Create custom options list - options = [] - if self.grpc_max_send_message_length: - options.append( - ("grpc.max_send_message_length", self.grpc_max_send_message_length) - ) - logger.debug( - f"Configured gRPC max_send_message_length: {self.grpc_max_send_message_length} bytes ({self.grpc_max_send_message_length / (1024 * 1024):.2f} MB)" - ) - if self.grpc_max_receive_message_length: - options.append( - ( - "grpc.max_receive_message_length", - self.grpc_max_receive_message_length, - ) - ) - logger.debug( - f"Configured gRPC max_receive_message_length: {self.grpc_max_receive_message_length} bytes ({self.grpc_max_receive_message_length / (1024 * 1024):.2f} MB)" - ) - - # Patch the function to include our custom options - def get_grpc_channel_with_options( - host_address: Optional[str], - secure_channel: bool = False, - interceptors: Optional[Sequence["grpc.ClientInterceptor"]] = None, - ): - # This is a copy of the original get_grpc_channel function in durabletask.internal.shared at - # https://github.com/dapr/durabletask-python/blob/7070cb07d07978d079f8c099743ee4a66ae70e05/durabletask/internal/shared.py#L30C1-L61C19 - # but with my option overrides applied above. - if host_address is None: - host_address = shared.get_default_host_address() - - for protocol in getattr(shared, "SECURE_PROTOCOLS", []): - if host_address.lower().startswith(protocol): - secure_channel = True - # remove the protocol from the host name - host_address = host_address[len(protocol) :] - break - - for protocol in getattr(shared, "INSECURE_PROTOCOLS", []): - if host_address.lower().startswith(protocol): - secure_channel = False - # remove the protocol from the host name - host_address = host_address[len(protocol) :] - break - - # Create the base channel - if secure_channel: - credentials = grpc.ssl_channel_credentials() - channel = grpc.secure_channel( - host_address, credentials, options=options - ) - else: - channel = grpc.insecure_channel(host_address, options=options) - - # Apply interceptors ONLY if they exist - if interceptors: - channel = grpc.intercept_channel(channel, *interceptors) - - return channel - - # Replace the function - shared.get_grpc_channel = get_grpc_channel_with_options - - logger.debug( - "Successfully patched durabletask gRPC channel factory with custom options" - ) - - except ImportError as e: - logger.error( - f"Failed to import required modules for gRPC configuration: {e}" - ) - raise - except Exception as e: - logger.error(f"Failed to configure gRPC channel options: {e}") - raise - - def graceful_shutdown(self) -> None: - """ - Perform graceful shutdown operations for the WorkflowApp. - - This method stops the workflow runtime and cleans up resources. - Overrides the SignalHandlingMixin method to provide WorkflowApp-specific cleanup. - """ - logger.debug("Initiating graceful shutdown of WorkflowApp...") - - try: - if getattr(self, "wf_runtime_is_running", False): - logger.debug("Shutting down workflow runtime...") - self.stop_runtime() - logger.debug("Workflow runtime stopped successfully.") - except Exception as e: - logger.error(f"Error during workflow runtime shutdown: {e}") - - def __del__(self): - """ - Cleanup method called when WorkflowApp is garbage collected. - Ensures runtime is properly stopped. - """ - try: - if getattr(self, "wf_runtime_is_running", False): - logger.debug("Cleaning up workflow runtime in destructor...") - self.stop_runtime() - except Exception: - # Ignore errors during cleanup in destructor - pass - - def setup_shutdown_handlers(self) -> None: - """ - Set up signal handlers for graceful shutdown. - - Call this method to enable automatic cleanup when the process receives - shutdown signals (SIGINT, SIGTERM). - """ - self.setup_signal_handlers() - logger.debug("Shutdown signal handlers configured for WorkflowApp.") - - async def __aenter__(self): - """ - Async context manager entry. - Sets up signal handlers for automatic cleanup. - """ - self.setup_shutdown_handlers() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """ - Async context manager exit. - Ensures graceful shutdown when exiting the context. - """ - await self.graceful_shutdown() - - def _choose_llm_for(self, method: Callable) -> Optional[ChatClientBase]: - """ - Encapsulate LLM selection logic. - 1. Use per-task override if provided on decorator. - 2. Else if marked as explicitly requiring an LLM, fall back to default app LLM (if available). - 3. Otherwise, returns None. - """ - per_task = getattr(method, "_task_llm", None) - if per_task: - return per_task - if getattr(method, "_explicit_llm", False): - if self.llm is None: - logger.warning( - f"Task '{getattr(method, '_task_name', getattr(method, '__name__', str(method)))}' requires an LLM " - "but no default LLM is configured in WorkflowApp and no explicit LLM was provided." - ) - return self.llm - return None - - def _discover_tasks(self) -> Dict[str, Callable]: - """Gather all @task-decorated functions and methods.""" - module = sys.modules["__main__"] - tasks: Dict[str, Callable] = {} - # Free functions in __main__ - for name, fn in inspect.getmembers(module, inspect.isfunction): - if getattr(fn, "_is_task", False) and fn.__module__ == module.__name__: - tasks[getattr(fn, "_task_name", name)] = fn - # Bound methods (if any) discovered via helper - for name, method in get_decorated_methods(self, "_is_task").items(): - tasks[getattr(method, "_task_name", name)] = method - logger.debug(f"Discovered tasks: {list(tasks)}") - return tasks - - def register_task(self, task_func: Callable, name: Optional[str] = None) -> None: - """ - Manually register a @task-decorated function, similar to native Dapr pattern. - - This allows explicit registration of tasks from other modules or packages. - The task will be registered immediately with the Dapr runtime. - - Args: - task_func: The @task-decorated function to register - name: Optional custom name (defaults to function name or _task_name) - - Example: - from tasks.my_tasks import generate_queries - wfapp.register_task(generate_queries) - wfapp.register_task(generate_queries, name="custom_name") - """ - # Validate input - if not callable(task_func): - raise ValueError( - f"task_func must be callable, got {type(task_func)}: {task_func}" - ) - - if not getattr(task_func, "_is_task", False): - raise ValueError( - f"Function {getattr(task_func, '__name__', str(task_func))} is not decorated with @task" - ) - - task_name = name or getattr( - task_func, "_task_name", getattr(task_func, "__name__", "unknown_task") - ) - - # Check if already registered - if task_name in self.tasks: - logger.warning(f"Task '{task_name}' is already registered, skipping") - return - - # Register immediately with Dapr runtime using existing registration logic - llm = self._choose_llm_for(task_func) - logger.debug( - f"Manually registering task '{task_name}' with llm={getattr(llm, '__class__', None)}" - ) - - kwargs = getattr(task_func, "_task_kwargs", {}) - task_instance = WorkflowTask( - func=task_func, - description=getattr(task_func, "_task_description", None), - agent=getattr(task_func, "_task_agent", None), - llm=llm, - include_chat_history=getattr( - task_func, "_task_include_chat_history", False - ), - workflow_app=self, - **kwargs, - ) - - # Wrap for Dapr invocation - wrapped = self._make_task_wrapper(task_name, task_func, task_instance) - self.wf_runtime.register_activity(wrapped) - self.tasks[task_name] = wrapped - - def register_tasks_from_module( - self, module_name_or_object: Union[str, Any] - ) -> None: - """ - Register all @task-decorated functions from a specific module. - - Args: - module_name_or_object: Module name string (e.g., "tasks.queries") or imported module object - - Example: - # Using string name - wfapp.register_tasks_from_module("tasks.queries") - - # Using imported module - import tasks.queries - wfapp.register_tasks_from_module(tasks.queries) - - # Using from import - from tasks import queries - wfapp.register_tasks_from_module(queries) - """ - try: - # Handle both string names and module objects - if isinstance(module_name_or_object, str): - import importlib - - module = importlib.import_module(module_name_or_object) - module_name = module_name_or_object - else: - # Assume it's a module object - module = module_name_or_object - module_name = getattr(module, "__name__", str(module)) - - registered_count = 0 - for name, fn in inspect.getmembers(module, inspect.isfunction): - if getattr(fn, "_is_task", False): - task_name = getattr(fn, "_task_name", name) - if task_name not in self.tasks: # Skip if already registered - self.register_task(fn) - registered_count += 1 - - logger.info( - f"Registered {registered_count} tasks from module '{module_name}'" - ) - - except ImportError as e: - raise ImportError(f"Could not import module '{module_name_or_object}': {e}") - except Exception as e: - raise RuntimeError( - f"Error registering tasks from module '{module_name_or_object}': {e}" - ) - - def register_tasks_from_package(self, package_name: str) -> None: - """ - Register all @task-decorated functions from all modules in a package. - - Args: - package_name: Name of package to scan (e.g., "tasks") - - Example: - wfapp.register_tasks_from_package("tasks") # Scans tasks/*.py - """ - try: - import importlib - import pkgutil - - package = importlib.import_module(package_name) - - # Collect all tasks first, then register using the original _register_tasks method - discovered_tasks: Dict[str, Callable] = {} - - total_tasks = 0 - for importer, modname, ispkg in pkgutil.iter_modules(package.__path__): - if not ispkg: # Only scan modules, not sub-packages - full_module_name = f"{package_name}.{modname}" - try: - module = importlib.import_module(full_module_name) - - for name, fn in inspect.getmembers(module, inspect.isfunction): - if getattr(fn, "_is_task", False): - task_name = getattr(fn, "_task_name", name) - if ( - task_name not in self.tasks - ): # Skip if already registered - discovered_tasks[task_name] = fn - total_tasks += 1 - - except Exception as e: - logger.warning(f"Failed to scan module {full_module_name}: {e}") - - # Now register all discovered tasks using the original _register_tasks method - if discovered_tasks: - self._register_tasks(discovered_tasks) - - logger.info(f"Registered {total_tasks} tasks from package '{package_name}'") - except ImportError as e: - raise ImportError(f"Could not import package '{package_name}': {e}") - except Exception as e: - raise RuntimeError( - f"Error registering tasks from package '{package_name}': {e}" - ) - - def _register_tasks(self, tasks: Dict[str, Callable]) -> None: - """Register each discovered task with the Dapr runtime using direct registration.""" - for task_name, method in tasks.items(): - # Don't reregister tasks that are already registered - if task_name in self.tasks: - continue - - llm = self._choose_llm_for(method) - logger.debug( - f"Registering task '{task_name}' with llm={getattr(llm, '__class__', None)}" - ) - kwargs = getattr(method, "_task_kwargs", {}) - task_instance = WorkflowTask( - func=method, - description=getattr(method, "_task_description", None), - agent=getattr(method, "_task_agent", None), - llm=llm, - include_chat_history=getattr( - method, "_task_include_chat_history", False - ), - workflow_app=self, - **kwargs, - ) - # Wrap for Dapr invocation - wrapped = self._make_task_wrapper(task_name, method, task_instance) - - # Use direct registration like official Dapr examples - self.wf_runtime.register_activity(wrapped) - self.tasks[task_name] = wrapped - - def _make_task_wrapper( - self, task_name: str, method: Callable, task_instance: WorkflowTask - ) -> Callable: - """Produce the function that Dapr will invoke for each activity.""" - - def run_sync(coro): - # Try to get the running event loop and run until complete - try: - loop = asyncio.get_running_loop() - return loop.run_until_complete(coro) - except RuntimeError: - # If no running loop, create one - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - return loop.run_until_complete(coro) - finally: - loop.close() - - @functools.wraps(method) - def wrapper(ctx: WorkflowActivityContext, *args, **kwargs): - wf_ctx = WorkflowActivityContext(ctx) - try: - call = task_instance(wf_ctx, *args, **kwargs) - if asyncio.iscoroutine(call): - return run_sync(call) - return call - except Exception: - logger.exception(f"Task '{task_name}' failed") - raise - - return wrapper - - # TODO: workflow discovery can also come from dapr runtime - # Python workflows can be registered in a variety of ways, and we need to support all of them. - # This supports decorator-based registration; - # however, there is also manual registration approach. - # See example below: - # def setup_workflow_runtime(): - # wf_runtime = WorkflowRuntime() - # wf_runtime.register_workflow(order_processing_workflow) - # wf_runtime.register_workflow(fulfillment_workflow) - # wf_runtime.register_activity(process_payment) - # wf_runtime.register_activity(send_notification) - # return wf_runtime - - # runtime = setup_workflow_runtime() - # runtime.start() - def _discover_workflows(self) -> Dict[str, Callable]: - """Gather all @workflow-decorated functions and methods.""" - module = sys.modules["__main__"] - wfs: Dict[str, Callable] = {} - for name, fn in inspect.getmembers(module, inspect.isfunction): - if getattr(fn, "_is_workflow", False) and fn.__module__ == module.__name__: - wfs[getattr(fn, "_workflow_name", name)] = fn - for name, method in get_decorated_methods(self, "_is_workflow").items(): - wfs[getattr(method, "_workflow_name", name)] = method - logger.info(f"Discovered workflows: {list(wfs)}") - return wfs - - def _register_workflows(self, wfs: Dict[str, Callable]) -> None: - """Register each discovered workflow with the Dapr runtime.""" - for wf_name, method in wfs.items(): - # Don't reregister workflows that are already registered - if wf_name in self.workflows: - continue - - # Use a closure helper to avoid late-binding capture issues. - def make_wrapped(meth: Callable) -> Callable: - @functools.wraps(meth) - def wrapped(*args, **kwargs): - return meth(*args, **kwargs) - - return wrapped - - decorator = self.wf_runtime.workflow(name=wf_name) - self.workflows[wf_name] = decorator(make_wrapped(method)) - - def resolve_task(self, task: Union[str, Callable]) -> Callable: - """ - Resolves a registered task function by its name or decorated function. - - Args: - task (Union[str, Callable]): The task name or callable function. - - Returns: - Callable: The resolved task function. - - Raises: - AttributeError: If the task is not found. - """ - if isinstance(task, str): - task_name = task - elif callable(task): - task_name = getattr( - task, "_task_name", getattr(task, "__name__", "unknown_task") - ) - else: - raise ValueError(f"Invalid task reference: {task}") - - task_func = self.tasks.get(task_name) - if not task_func: - raise AttributeError(f"Task '{task_name}' not found.") - - return task_func - - def resolve_workflow(self, workflow: Union[str, Callable]) -> Callable: - """ - Resolves a registered workflow function by its name or decorated function. - - Args: - workflow (Union[str, Callable]): The workflow name or callable function. - - Returns: - Callable: The resolved workflow function. - - Raises: - AttributeError: If the workflow is not found. - """ - if isinstance(workflow, str): - workflow_name = workflow # Direct lookup by string name - elif callable(workflow): - workflow_name = getattr( - workflow, - "_workflow_name", - getattr(workflow, "__name__", "unknown_workflow"), - ) - else: - raise ValueError(f"Invalid workflow reference: {workflow}") - - workflow_func = self.workflows.get(workflow_name) - if not workflow_func: - raise AttributeError(f"Workflow '{workflow_name}' not found.") - - return workflow_func - - # NOTE: Workflow instrumentation is applied directly during instrumentor initialization - # since start_runtime() is called in model_post_init - def start_runtime(self): - """Idempotently start the Dapr workflow runtime.""" - if not self.wf_runtime_is_running: - logger.info("Starting workflow runtime.") - self.wf_runtime.start() - self.wf_runtime_is_running = True - - logger.info("Sleeping for 5 seconds to ensure runtime is started.") - time.sleep(5) - - # Sync database state with Dapr workflow status after runtime starts - # This ensures our database reflects the actual state of resumed workflows - self._sync_workflow_state_after_startup() - - # Start monitoring resumed workflows to keep database in sync and handle trace continuity - self._monitor_resumed_workflows() - else: - logger.debug("Workflow runtime already running; skipping.") - - self._ensure_activities_registered() - - def _ensure_activities_registered(self): - """Ensure all workflow activities are registered with the Dapr runtime.""" - # Discover and register tasks and workflows - discovered_tasks = self._discover_tasks() - self._register_tasks(discovered_tasks) - discovered_wfs = self._discover_workflows() - self._register_workflows(discovered_wfs) - logger.debug("Workflow activities registration completed.") - - def _sync_workflow_state_after_startup(self): - """ - Sync database workflow state with actual Dapr workflow status after runtime startup. - This ensures our database reflects the current state of any resumed workflows. - """ - try: - # Only sync if this class has state management capabilities - if ( - not hasattr(self, "state") - or not hasattr(self, "load_state") - or not hasattr(self, "save_state") - ): - logger.debug( - "No state management capabilities, skipping workflow state sync" - ) - return - - self.load_state() - instances = self.state.get("instances", {}) - - logger.debug(f"Found {len(instances)} workflow instances to sync") - - # Sync each instance with Dapr's actual status - for instance_id, instance_data in instances.items(): - try: - # Skip if already completed - end_time = instance_data.get("end_time") - if end_time is not None: - continue - - # Get actual status from Dapr - workflow_state = self.get_workflow_state(instance_id) - if workflow_state: - runtime_status = workflow_state.runtime_status.name - logger.debug( - f"Instance {instance_id}: Dapr status = {runtime_status}" - ) - - # Update our database state based on Dapr's status - if runtime_status.upper() in [ - DaprWorkflowStatus.COMPLETED.value.upper(), - DaprWorkflowStatus.FAILED.value.upper(), - DaprWorkflowStatus.TERMINATED.value.upper(), - ]: - # Mark as completed in our state - instance_data["end_time"] = datetime.now( - timezone.utc - ).isoformat() - instance_data["status"] = runtime_status.lower() - - logger.debug( - f"Marked workflow {instance_id} as {runtime_status.lower()} in database" - ) - elif runtime_status.upper() in [ - DaprWorkflowStatus.RUNNING.value.upper(), - DaprWorkflowStatus.PENDING.value.upper(), - ]: - # Ensure it's marked as running - instance_data["status"] = DaprWorkflowStatus.RUNNING.value - logger.debug(f"Confirmed workflow {instance_id} is running") - else: - logger.warning( - f"Unknown status for workflow {instance_id}: {runtime_status}" - ) - else: - # Workflow no longer exists in Dapr, mark as completed - instance_data["end_time"] = datetime.now( - timezone.utc - ).isoformat() - instance_data["status"] = DaprWorkflowStatus.COMPLETED.value - - logger.debug( - f"Workflow {instance_id} no longer in Dapr, marked as completed" - ) - - except Exception as e: - logger.warning(f"Error syncing workflow {instance_id}: {e}") - continue - - # Save updated state - self.save_state() - logger.debug("Workflow state sync completed") - - except Exception as e: - logger.error(f"Error during workflow state sync: {e}", exc_info=True) - - def _monitor_resumed_workflows(self): - """ - Monitor any resumed workflows in the background to keep database state synchronized - and handle trace continuity. This runs after the initial state sync. - """ - import asyncio - import threading - - def monitor_workflows(): - """Monitor resumed workflows in background thread.""" - try: - # Create event loop for this thread - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # Run the monitoring - loop.run_until_complete(self._monitor_workflows_async()) - - except Exception as e: - logger.error(f"Error monitoring resumed workflows: {e}", exc_info=True) - finally: - loop.close() - - # Start monitoring in background thread - monitor_thread = threading.Thread(target=monitor_workflows, daemon=True) - monitor_thread.start() - logger.debug("Started background monitoring for resumed workflows") - - async def _monitor_workflows_async(self): - """ - Monitor any running workflows and update database when they complete. - Also handles trace continuity for resumed workflows. - """ - try: - # Only monitor if this class has state management capabilities - if ( - not hasattr(self, "state") - or not hasattr(self, "load_state") - or not hasattr(self, "save_state") - ): - logger.debug( - "No state management capabilities, skipping workflow monitoring" - ) - return - - # Load current state - self.load_state() - instances = ( - getattr(self.state, "instances", {}) - if hasattr(self.state, "instances") - else self.state.get("instances", {}) - ) - - # Find workflows that need trace continuity restoration (resumed workflows only) - resumed_workflows = [ - (instance_id, instance_data) - for instance_id, instance_data in instances.items() - if instance_data.get("end_time") is None - and instance_data.get("status") == DaprWorkflowStatus.SUSPENDED.value - and instance_data.get("trace_context", {}).get( - "needs_agent_span_on_resume" - ) - ] - - if not resumed_workflows: - logger.debug("No resumed workflows found that need trace continuity") - return - - logger.debug( - f"Restoring trace continuity for {len(resumed_workflows)} resumed workflows..." - ) - for instance_id, instance_data in resumed_workflows: - try: - logger.debug( - f"Restoring trace continuity for resumed workflow {instance_id}..." - ) - await self._ensure_trace_continuity(instance_id, instance_data) - except Exception as e: - logger.error( - f"Error restoring trace continuity for workflow {instance_id}: {e}" - ) - - except Exception as e: - logger.error(f"Error in workflow monitoring: {e}", exc_info=True) - - async def _ensure_trace_continuity(self, instance_id: str, instance_data: dict): - """ - Ensure trace continuity for resumed workflows by creating proper AGENT spans. - This is the key fix for missing parent traces in resumed workflows. - """ - try: - stored_trace_context = instance_data.get("trace_context") - - # Check if this workflow needs trace context restored - if stored_trace_context and stored_trace_context.get( - "needs_agent_span_on_resume" - ): - logger.info( - f"Restoring trace context for resumed workflow {instance_id}" - ) - await self._create_agent_span_for_resumed_workflow( - instance_id, instance_data - ) - elif stored_trace_context and stored_trace_context.get("traceparent"): - logger.debug( - f"Restoring trace context for resumed workflow {instance_id}" - ) - - # Store the trace context for workflow tasks to use - from dapr_agents.observability.context_storage import ( - store_workflow_context, - ) - - context_data = { - "traceparent": stored_trace_context.get("traceparent"), - "tracestate": stored_trace_context.get("tracestate"), - "trace_id": stored_trace_context.get("trace_id"), - "span_id": stored_trace_context.get("span_id"), - "instance_id": instance_id, - "resumed": True, - "restored": True, - } - store_workflow_context( - f"__workflow_context_{instance_id}__", context_data - ) - logger.debug(f"Trace context restored for workflow {instance_id}") - - else: - # Create new trace context for resumed workflow without stored context - logger.debug( - f"Creating new trace context for resumed workflow {instance_id}" - ) - agent_name = getattr(self, "name", None) or "DurableAgent" - - from dapr_agents.observability.context_storage import _context_storage - - context_data = _context_storage.create_resumed_workflow_context( - instance_id, agent_name, stored_trace_context - ) - logger.debug(f"New trace context created for workflow {instance_id}") - - except Exception as e: - logger.warning( - f"Error ensuring trace continuity for workflow {instance_id}: {e}" - ) - - # TODO: This needs further work as resumed workflows remain intact on their workflow task traces, - # but the official agent span is not created. - async def _create_agent_span_for_resumed_workflow( - self, instance_id: str, instance_data: dict - ): - """ - Create a proper AGENT span for resumed workflows to restore trace hierarchy. - This ensures resumed workflows have the same trace structure as new workflows. - """ - try: - from dapr_agents.observability.context_storage import store_workflow_context - from opentelemetry import trace - from opentelemetry.trace import set_span_in_context - from opentelemetry.context import Context - - # Get stored trace context - stored_trace_context = instance_data.get("trace_context", {}) - - if stored_trace_context and stored_trace_context.get("traceparent"): - # Parse the stored traceparent to restore the original trace - traceparent = stored_trace_context.get("traceparent") - trace_id_hex = traceparent.split("-")[1] - parent_span_id_hex = traceparent.split("-")[2] - trace_id = int(trace_id_hex, 16) - parent_span_id = int(parent_span_id_hex, 16) - - # Create span context from stored trace - from opentelemetry.trace import SpanContext, TraceFlags - - parent_span_context = SpanContext( - trace_id=trace_id, - span_id=parent_span_id, - is_remote=True, - trace_flags=TraceFlags(0x01), # Sampled - ) - - parent_context = set_span_in_context( - trace.NonRecordingSpan(parent_span_context), Context() - ) - - # Get tracer and create AGENT span as child of the original trace - tracer = trace.get_tracer(__name__) - agent_name = getattr(self, "name", "DurableAgent") - workflow_name = instance_data.get("workflow_name", "AgenticWorkflow") - span_name = f"{agent_name}.{workflow_name}" - - # Create the AGENT span that will show up in the trace - agent_span = tracer.start_span( - span_name, - context=parent_context, - kind=trace.SpanKind.INTERNAL, - attributes={ - "openinference.span.kind": "AGENT", - "workflow.instance_id": instance_id, - "agent.name": agent_name, - "workflow.name": workflow_name, - "workflow.resumed": True, - "input.value": instance_data.get("input", ""), - "input.mime_type": "text/plain", - }, - ) - - # Make this span the current context for child spans - agent_span_context = agent_span.get_span_context() - context_data = { - "traceparent": f"00-{format(agent_span_context.trace_id, '032x')}-{format(agent_span_context.span_id, '016x')}-01", - "tracestate": stored_trace_context.get("tracestate", ""), - "trace_id": format(agent_span_context.trace_id, "032x"), - "span_id": format(agent_span_context.span_id, "016x"), - "instance_id": instance_id, - "resumed": True, - "restored": True, - "agent_span": agent_span, # Store span reference for lifecycle management - } - - # Store context for child spans to use - store_workflow_context( - f"__workflow_context_{instance_id}__", context_data - ) - - # Remove the resume flag and save state - stored_trace_context.pop("needs_agent_span_on_resume", None) - self.save_state() - - logger.debug( - f"Created AGENT span '{span_name}' for resumed workflow {instance_id}" - ) - - else: - logger.warning( - f"No valid trace context found for resumed workflow {instance_id}" - ) - - except Exception as e: - logger.error( - f"Failed to create agent span for resumed workflow {instance_id}: {e}" - ) - - # TODO: This needs further work as resumed workflows remain intact on their workflow task traces, - # but the official agent span is not created. - def _close_resumed_workflow_span(self, instance_id: str, final_output: str = None): - """ - Close the agent span for a resumed workflow when it completes. - This should be called when the workflow finishes execution. - """ - try: - from dapr_agents.observability.context_storage import get_workflow_context - from opentelemetry.trace import Status, StatusCode - - context = get_workflow_context(f"__workflow_context_{instance_id}__") - if context and context.get("resumed") and "agent_span" in context: - agent_span = context["agent_span"] - if final_output: - agent_span.set_attribute("output.value", str(final_output)[:1000]) - agent_span.set_attribute("output.mime_type", "text/plain") - - agent_span.set_status(Status(StatusCode.OK)) - agent_span.end() - - logger.debug( - f"Closed AGENT span for completed resumed workflow {instance_id}" - ) - context.pop("agent_span", None) - - except Exception as e: - logger.warning( - f"Failed to close resumed workflow span for {instance_id}: {e}" - ) - - def stop_runtime(self): - """Idempotently stop the Dapr workflow runtime.""" - if self.wf_runtime_is_running: - logger.info("Stopping workflow runtime.") - self.wf_runtime.shutdown() - self.wf_runtime_is_running = False - else: - logger.debug("Workflow runtime already stopped; skipping.") - - def run_workflow( - self, workflow: Union[str, Callable], input: Union[str, Dict[str, Any]] = None - ) -> str: - """ - Starts a workflow execution. - - Args: - workflow (Union[str, Callable]): The workflow name or callable. - input (Union[str, Dict[str, Any]], optional): Input data for the workflow. - - Returns: - str: The instance ID of the started workflow. - - Raises: - Exception: If workflow execution fails. - """ - try: - # Generate unique instance ID - instance_id = uuid.uuid4().hex - - # Resolve the workflow function - workflow_func = self.resolve_workflow(workflow) - - # Schedule workflow execution - instance_id = self.wf_client.schedule_new_workflow( - workflow=workflow_func, input=input, instance_id=instance_id - ) - - logger.info(f"Started workflow with instance ID {instance_id}.") - return instance_id - except Exception as e: - logger.error(f"Failed to start workflow {workflow}: {e}") - raise - - async def monitor_workflow_state(self, instance_id: str) -> Optional[WorkflowState]: - """ - Monitors and retrieves the final state of a workflow instance. - - Args: - instance_id (str): The workflow instance ID. - - Returns: - Optional[WorkflowState]: The final state of the workflow or None if not found. - """ - try: - state: WorkflowState = await asyncio.to_thread( - self.wait_for_workflow_completion, - instance_id, - fetch_payloads=True, - timeout_in_seconds=self.timeout, - ) - - if not state: - logger.error(f"Workflow '{instance_id}' not found.") - return None - - return state - except TimeoutError: - logger.error(f"Workflow '{instance_id}' monitoring timed out.") - return None - except Exception as e: - logger.error(f"Error retrieving workflow state for '{instance_id}': {e}") - return None - - async def monitor_workflow_completion(self, instance_id: str) -> None: - """ - Monitors the execution of a workflow and logs its final state. - - Args: - instance_id (str): The workflow instance ID. - """ - try: - logger.info(f"Monitoring workflow '{instance_id}'...") - - # Retrieve workflow state - state: WorkflowState = await self.monitor_workflow_state(instance_id) - if not state: - return # Error already logged in monitor_workflow_state - - # Extract relevant details - workflow_status = state.runtime_status.name - failure_details = ( - state.failure_details - ) # This is an object, not a dictionary - - if workflow_status.upper() == DaprWorkflowStatus.COMPLETED.value.upper(): - logger.info( - f"Workflow '{instance_id}' completed successfully. Status: {workflow_status}." - ) - - if state.serialized_output: - logger.debug( - f"Output: {json.dumps(state.serialized_output, indent=2)}" - ) - - elif workflow_status.upper() in ( - DaprWorkflowStatus.FAILED.value.upper(), - "ABORTED", - ): - # Ensure `failure_details` exists before accessing attributes - error_type = getattr(failure_details, "error_type", "Unknown") - message = getattr(failure_details, "message", "No message provided") - stack_trace = getattr( - failure_details, "stack_trace", "No stack trace available" - ) - - logger.error( - f"Workflow '{instance_id}' failed.\n" - f"Error Type: {error_type}\n" - f"Message: {message}\n" - f"Stack Trace:\n{stack_trace}\n" - f"Input: {json.dumps(state.serialized_input, indent=2)}" - ) - - self.terminate_workflow(instance_id) - - else: - logger.warning( - f"Workflow '{instance_id}' ended with status '{workflow_status}'.\n" - f"Input: {json.dumps(state.serialized_input, indent=2)}" - ) - - logger.debug( - f"Workflow Details: Instance ID={state.instance_id}, Name={state.name}, " - f"Created At={state.created_at}, Last Updated At={state.last_updated_at}" - ) - - except Exception as e: - logger.error( - f"Error monitoring workflow '{instance_id}': {e}", exc_info=True - ) - finally: - logger.info(f"Finished monitoring workflow '{instance_id}'.") - - async def run_and_monitor_workflow_async( - self, - workflow: Union[str, Callable], - input: Optional[Union[str, Dict[str, Any]]] = None, - ) -> Optional[str]: - """ - Runs a workflow asynchronously and monitors its completion. - - Args: - workflow (Union[str, Callable]): The workflow name or callable. - input (Optional[Union[str, Dict[str, Any]]]): The workflow input payload. - - Returns: - Optional[str]: The serialized output of the workflow. - """ - instance_id = None - try: - # Off-load the potentially blocking run_workflow call to a thread. - instance_id = await asyncio.to_thread(self.run_workflow, workflow, input) - - logger.debug( - f"Workflow '{workflow}' started with instance ID: {instance_id}" - ) - - # Await the asynchronous monitoring of the workflow state. - state = await self.monitor_workflow_state(instance_id) - - if not state: - raise RuntimeError(f"Workflow '{instance_id}' not found.") - - workflow_status = ( - DaprWorkflowStatus[state.runtime_status.name] - if state.runtime_status.name in DaprWorkflowStatus.__members__ - else DaprWorkflowStatus.UNKNOWN - ) - - if workflow_status == DaprWorkflowStatus.COMPLETED: - logger.info(f"Workflow '{instance_id}' completed successfully!") - logger.debug(f"Output: {state.serialized_output}") - else: - logger.error( - f"Workflow '{instance_id}' ended with status '{workflow_status.value}'." - ) - - # Return the final state output - return state.serialized_output - - except Exception as e: - if instance_id: - logger.error(f"Error during workflow '{instance_id}': {e}") - else: - logger.error(f"Error starting workflow '{workflow}': {e}") - raise - finally: - if instance_id: - logger.info(f"Finished workflow with Instance ID: {instance_id}.") - else: - logger.info(f"Finished workflow attempt for '{workflow}'.") - - def run_and_monitor_workflow_sync( - self, - workflow: Union[str, Callable], - input: Optional[Union[str, Dict[str, Any]]] = None, - ) -> Optional[str]: - """ - Synchronous wrapper for running and monitoring a workflow. - This allows calling code that is not async to still run the workflow. - - Args: - workflow (Union[str, Callable]): The workflow name or callable. - input (Optional[Union[str, Dict[str, Any]]]): The workflow input payload. - - Returns: - Optional[str]: The serialized output of the workflow. - """ - return asyncio.run(self.run_and_monitor_workflow_async(workflow, input)) - - def terminate_workflow( - self, instance_id: str, *, output: Optional[Any] = None - ) -> None: - """ - Terminates a running workflow. - - Args: - instance_id (str): The workflow instance ID. - output (Optional[Any]): Optional output to set for the terminated workflow. - - Raises: - Exception: If the termination fails. - """ - try: - self.wf_client.terminate_workflow(instance_id=instance_id, output=output) - logger.info( - f"Successfully terminated workflow '{instance_id}' with output: {output}" - ) - except Exception as e: - logger.error(f"Failed to terminate workflow '{instance_id}'. Error: {e}") - raise Exception(f"Error terminating workflow '{instance_id}': {e}") - - def get_workflow_state(self, instance_id: str) -> Optional[Any]: - """ - Retrieves the state of a workflow instance. - - Args: - instance_id (str): The workflow instance ID. - - Returns: - Optional[Any]: The workflow state if found. - - Raises: - RuntimeError: If retrieving the state fails. - """ - try: - state = self.wf_client.get_workflow_state(instance_id) - logger.info( - f"Retrieved state for workflow {instance_id}: {state.runtime_status}." - ) - return state - except Exception as e: - logger.error(f"Failed to retrieve workflow state for {instance_id}: {e}") - return None - - def wait_for_workflow_completion( - self, - instance_id: str, - fetch_payloads: bool = True, - timeout_in_seconds: int = 120, - ) -> Optional[WorkflowState]: - """ - Waits for a workflow to complete and retrieves its state. - - Args: - instance_id (str): The workflow instance ID. - fetch_payloads (bool): Whether to fetch input/output payloads. - timeout_in_seconds (int): Maximum wait time in seconds. - - Returns: - Optional[WorkflowState]: The final state or None if it times out. - - Raises: - RuntimeError: If waiting for completion fails. - """ - try: - state = self.wf_client.wait_for_workflow_completion( - instance_id, - fetch_payloads=fetch_payloads, - timeout_in_seconds=timeout_in_seconds, - ) - if state: - logger.info( - f"Workflow {instance_id} completed with status: {state.runtime_status}." - ) - else: - logger.warning( - f"Workflow {instance_id} did not complete within the timeout period." - ) - return state - except Exception as e: - logger.error( - f"Error while waiting for workflow {instance_id} completion: {e}" - ) - return None - - def raise_workflow_event( - self, instance_id: str, event_name: str, *, data: Any | None = None - ) -> None: - """ - Raises an event for a running workflow instance. - - Args: - instance_id (str): The workflow instance ID. - event_name (str): The name of the event to raise. - data (Any | None): Optional event data. - - Raises: - Exception: If raising the event fails. - """ - try: - logger.info( - f"Raising workflow event '{event_name}' for instance '{instance_id}'" - ) - # Ensure data is in a serializable format - if is_pydantic_model(type(data)): - # Convert Pydantic model to dict - data = data.model_dump() - # Raise the event using the Dapr workflow client with serialized data - self.wf_client.raise_workflow_event( - instance_id=instance_id, event_name=event_name, data=data - ) - logger.info( - f"Successfully raised workflow event '{event_name}' for instance '{instance_id}'!" - ) - except Exception as e: - logger.error( - f"Error raising workflow event '{event_name}' for instance '{instance_id}'. " - f"Data: {data}, Error: {e}" - ) - raise Exception( - f"Failed to raise workflow event '{event_name}' for instance '{instance_id}': {str(e)}" - ) - - def when_all(self, tasks: List[dtask.Task[T]]) -> dtask.WhenAllTask[T]: - """ - Waits for all given tasks to complete. - - Args: - tasks (List[dtask.Task[T]]): The tasks to wait for. - - Returns: - dtask.WhenAllTask[T]: A task that completes when all tasks finish. - """ - return dtask.when_all(tasks) - - def when_any(self, tasks: List[dtask.Task[T]]) -> dtask.WhenAnyTask: - """ - Waits for any one of the given tasks to complete. - - Args: - tasks (List[dtask.Task[T]]): The tasks to monitor. - - Returns: - dtask.WhenAnyTask: A task that completes when the first task finishes. - """ - return dtask.when_any(tasks) diff --git a/dapr_agents/workflow/decorators/__init__.py b/dapr_agents/workflow/decorators/__init__.py index 5809f6c3..62bd9942 100644 --- a/dapr_agents/workflow/decorators/__init__.py +++ b/dapr_agents/workflow/decorators/__init__.py @@ -1,13 +1,9 @@ -from .core import task, workflow -from .fastapi import route -from .messaging import message_router +from .routers import message_router, http_router from .activities import llm_activity, agent_activity __all__ = [ - "workflow", - "task", - "route", "message_router", + "http_router", "llm_activity", "agent_activity", ] diff --git a/dapr_agents/workflow/decorators/core.py b/dapr_agents/workflow/decorators/core.py deleted file mode 100644 index f3971968..00000000 --- a/dapr_agents/workflow/decorators/core.py +++ /dev/null @@ -1,144 +0,0 @@ -import logging -import functools -import warnings -from typing import Any, Callable, Optional - - -_TASK_DEPRECATION_MESSAGE = ( - "@task is deprecated and will be removed in a future release. " - "Switch to native Dapr activity registration (WorkflowRuntime.activity) and, " - "for LLM/agent helpers, prefer the new @llm_activity or @agent_activity decorators." -) - -_WORKFLOW_DEPRECATION_MESSAGE = ( - "@workflow is deprecated and will be removed in a future release. " - "Switch to native Dapr workflow registration (WorkflowRuntime.workflow) and combine with " - "@message_router or other decorators as needed." -) - - -def task( - func: Optional[Callable] = None, - *, - name: Optional[str] = None, - description: Optional[str] = None, - agent: Optional[Any] = None, - llm: Optional[Any] = None, - include_chat_history: bool = False, - **task_kwargs, -) -> Callable: - """ - Decorator to register a function as a Dapr workflow task. - - This allows configuring a task with an LLM, agent, chat history, and other options. - All additional keyword arguments are stored and forwarded to the WorkflowTask constructor. - - Args: - func (Optional[Callable]): The function to wrap. Can also be used as `@task(...)`. - name (Optional[str]): Optional custom task name. Defaults to the function name. - description (Optional[str]): Optional prompt template for LLM-based execution. - agent (Optional[Any]): Optional agent to handle the task instead of an LLM or function. - llm (Optional[Any]): Optional LLM client used to execute the task. - include_chat_history (bool): Whether to include prior messages in LLM calls. - **task_kwargs: Additional keyword arguments to forward to `WorkflowTask`. - - Returns: - Callable: The decorated function with attached task metadata. - """ - - if isinstance(func, str): - # Allow syntax: @task("some description") - description = func - func = None - - def decorator(f: Callable) -> Callable: - if not callable(f): - raise ValueError(f"@task must be applied to a function, got {type(f)}.") - - warnings.warn( - _TASK_DEPRECATION_MESSAGE, - DeprecationWarning, - stacklevel=2, - ) - - # Attach task metadata - f._is_task = True - f._task_name = name or f.__name__ - f._task_description = description - f._task_agent = agent - f._task_llm = llm - f._task_include_chat_history = include_chat_history - f._explicit_llm = llm is not None or bool(description) - f._task_kwargs = task_kwargs - - # wrap it so we can log, validate, etc., without losing signature/docs - @functools.wraps(f) - def wrapper(*args, **kwargs): - logging.getLogger(__name__).debug(f"Calling task '{f._task_name}'") - return f(*args, **kwargs) - - # copy our metadata onto the wrapper so discovery still sees it - for attr in ( - "_is_task", - "_task_name", - "_task_description", - "_task_agent", - "_task_llm", - "_task_include_chat_history", - "_explicit_llm", - "_task_kwargs", - ): - setattr(wrapper, attr, getattr(f, attr)) - - return wrapper - - return ( - decorator(func) if func else decorator - ) # Supports both @task and @task(name="custom") - - -def workflow( - func: Optional[Callable] = None, *, name: Optional[str] = None -) -> Callable: - """ - Decorator to register a function as a Dapr workflow. - - - Attaches workflow metadata for discovery and registration. - - Works seamlessly with standalone functions, instance methods, and class methods. - - Args: - func (Callable, optional): Function to be decorated as a workflow. - name (Optional[str]): The name to register the workflow with. - - Returns: - Callable: The decorated function with workflow metadata for Dapr compatibility. - """ - - def decorator(f: Callable) -> Callable: - """ - Minimal workflow decorator for Dapr workflows. - - Sets workflow metadata for discovery and registration. - """ - if not callable(f): - raise ValueError(f"@workflow must be applied to a function, got {type(f)}.") - - warnings.warn( - _WORKFLOW_DEPRECATION_MESSAGE, - DeprecationWarning, - stacklevel=2, - ) - - f._is_workflow = True - f._workflow_name = name or f.__name__ - - @functools.wraps(f) - def wrapper(*args, **kwargs): - return f(*args, **kwargs) - - wrapper._is_workflow = True - wrapper._workflow_name = f._workflow_name - return wrapper - - return ( - decorator(func) if func else decorator - ) # Supports both `@workflow` and `@workflow(name="custom")` diff --git a/dapr_agents/workflow/decorators/fastapi.py b/dapr_agents/workflow/decorators/fastapi.py deleted file mode 100644 index a72f0776..00000000 --- a/dapr_agents/workflow/decorators/fastapi.py +++ /dev/null @@ -1,23 +0,0 @@ -def route(path: str, method: str = "GET", **kwargs): - """ - Decorator to mark an instance method as a FastAPI route. - - Args: - path (str): The URL path to bind this route to. - method (str): The HTTP method to use (e.g., 'GET', 'POST'). Defaults to 'GET'. - **kwargs: Additional arguments passed to FastAPI's `add_api_route`. - - Example: - @route("/status", method="GET", summary="Show status", tags=["monitoring"]) - def health(self): - return {"ok": True} - """ - - def decorator(func): - func._is_fastapi_route = True - func._route_path = path - func._route_method = method.upper() - func._route_kwargs = kwargs - return func - - return decorator diff --git a/dapr_agents/workflow/decorators/messaging.py b/dapr_agents/workflow/decorators/messaging.py deleted file mode 100644 index 524d5bd0..00000000 --- a/dapr_agents/workflow/decorators/messaging.py +++ /dev/null @@ -1,101 +0,0 @@ -import logging -import warnings -from copy import deepcopy -from typing import Any, Callable, Optional, get_type_hints - -from dapr_agents.workflow.utils.core import is_valid_routable_model -from dapr_agents.workflow.utils.messaging import extract_message_models - -logger = logging.getLogger(__name__) - -_MESSAGE_ROUTER_DEPRECATION_MESSAGE = ( - "@message_router (legacy version from dapr_agents.workflow.decorators.messaging) " - "is deprecated and will be removed in a future release. " - "Please migrate to the updated decorator in " - "`dapr_agents.workflow.decorators.routers`, which supports " - "Union types, forward references, and explicit Dapr workflow integration." -) - - -def message_router( - func: Optional[Callable[..., Any]] = None, - *, - pubsub: Optional[str] = None, - topic: Optional[str] = None, - dead_letter_topic: Optional[str] = None, - broadcast: bool = False, -) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - """ - [DEPRECATED] Legacy decorator for registering message handlers by inspecting type hints - on the 'message' argument. - - This decorator: - - Extracts the expected message model type from function annotations. - - Stores metadata for routing messages by message schema instead of `event.type`. - - Supports broadcast messaging. - - Supports Union[...] and multiple models. - - Args: - func (Optional[Callable]): The function to decorate. - pubsub (Optional[str]): The name of the pub/sub component. - topic (Optional[str]): The topic name for the handler. - dead_letter_topic (Optional[str]): Dead-letter topic for failed messages. - broadcast (bool): If True, the message is broadcast to all agents. - - Returns: - Callable: The decorated function with additional metadata. - """ - - def decorator(f: Callable[..., Any]) -> Callable[..., Any]: - warnings.warn( - _MESSAGE_ROUTER_DEPRECATION_MESSAGE, - DeprecationWarning, - stacklevel=2, - ) - - is_workflow = hasattr(f, "_is_workflow") - workflow_name = getattr(f, "_workflow_name", None) - - type_hints = get_type_hints(f) - raw_hint = type_hints.get("message", None) - - message_models = extract_message_models(raw_hint) - - if not message_models: - raise ValueError( - f"Message handler '{f.__name__}' must have a 'message' parameter with a valid type hint." - ) - - for model in message_models: - if not is_valid_routable_model(model): - raise TypeError( - f"Handler '{f.__name__}' has unsupported message type: {model}" - ) - - logger.debug( - "@message_router (legacy): '%s' => models %s", - f.__name__, - [m.__name__ for m in message_models], - ) - - # Attach metadata for later registration - f._is_message_handler = True - f._message_router_data = deepcopy( - { - "pubsub": pubsub, - "topic": topic, - "dead_letter_topic": dead_letter_topic - or (f"{topic}_DEAD" if topic else None), - "is_broadcast": broadcast, - "message_schemas": message_models, - "message_types": [model.__name__ for model in message_models], - } - ) - - if is_workflow: - f._is_workflow = True - f._workflow_name = workflow_name - - return f - - return decorator(func) if func else decorator diff --git a/dapr_agents/workflow/decorators/routers.py b/dapr_agents/workflow/decorators/routers.py index 3b803cbd..1f46d88a 100644 --- a/dapr_agents/workflow/decorators/routers.py +++ b/dapr_agents/workflow/decorators/routers.py @@ -1,20 +1,16 @@ from __future__ import annotations -import inspect import logging from copy import deepcopy -from typing import ( - Any, - Callable, - Optional, - get_type_hints, -) +from typing import Any, Callable, List, Literal, Optional, Type, get_type_hints from dapr_agents.workflow.utils.core import is_supported_model from dapr_agents.workflow.utils.routers import extract_message_models logger = logging.getLogger(__name__) +HttpMethod = Literal["GET", "POST", "PUT", "PATCH", "DELETE"] + def message_router( func: Optional[Callable[..., Any]] = None, @@ -23,67 +19,54 @@ def message_router( topic: Optional[str] = None, dead_letter_topic: Optional[str] = None, broadcast: bool = False, + message_model: Optional[Any] = None, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """ - Decorate a message handler with routing metadata. - - The handler must accept a parameter named `message`. Its type hint defines the - expected payload model(s), e.g.: - - @message_router(pubsub="pubsub", topic="orders") - def on_order(message: OrderCreated): ... - - @message_router(pubsub="pubsub", topic="events") - def on_event(message: Union[Foo, Bar]): ... + Tag a callable as a **Pub/Sub → Workflow** entry with routing + schema metadata. Args: - func: (optional) bare-decorator form support. - pubsub: Name of the Dapr pub/sub component (required when used with args). - topic: Topic name to subscribe to (required when used with args). - dead_letter_topic: Optional dead-letter topic (defaults to f"{topic}_DEAD"). - broadcast: Optional flag you can use downstream for fan-out semantics. + func (Optional[Callable[..., Any]]): + The function to decorate (if used without parentheses). + pubsub (Optional[str]): + The name of the Dapr pub/sub component. Optional when wiring via `PubSubRouteSpec`. + topic (Optional[str]): + The pub/sub topic to subscribe to. Optional when wiring via `PubSubRouteSpec`. + dead_letter_topic (Optional[str]): + The dead-letter topic to publish failed messages to. + broadcast (bool): + Whether to treat this as a broadcast subscription. + message_model (Optional[Any]): + The message model class or Union[...] to use for validation. Returns: - The original function tagged with `_message_router_data`. + Callable[[Callable[..., Any]], Callable[..., Any]]: + The decorated function. """ def decorator(f: Callable[..., Any]) -> Callable[..., Any]: - # Validate required kwargs only when decorator is used with args - if pubsub is None or topic is None: - raise ValueError( - "`pubsub` and `topic` are required when using @message_router with arguments." - ) - - sig = inspect.signature(f) - if "message" not in sig.parameters: - raise ValueError(f"'{f.__name__}' must have a 'message' parameter.") - - # Resolve forward refs under PEP 563 / future annotations - try: - hints = get_type_hints(f, globalns=f.__globals__) - except Exception: - logger.debug( - "Failed to fully resolve type hints for %s", f.__name__, exc_info=True - ) - hints = getattr(f, "__annotations__", {}) or {} + # Resolve message model(s) + if message_model is None: + # Back-compat fallback: try to infer from a `message` param if present, but not required. + try: + hints = get_type_hints(f, globalns=f.__globals__) + except Exception: + logger.debug( + "Failed to resolve type hints for %s", f.__name__, exc_info=True + ) + hints = getattr(f, "__annotations__", {}) or {} + inferred = hints.get("message") + models = extract_message_models(inferred) if inferred else [] + else: + models = extract_message_models(message_model) - raw_hint = hints.get("message") - if raw_hint is None: - raise TypeError( - f"'{f.__name__}' must type-hint the 'message' parameter " - "(e.g., 'message: MyModel' or 'message: Union[A, B]')" - ) - - models = extract_message_models(raw_hint) if not models: raise TypeError( - f"Unsupported or unresolved message type for '{f.__name__}': {raw_hint!r}" + "`@message_router` requires `message_model` (class or Union[...])." ) - # Optional early validation of supported schema kinds for m in models: if not is_supported_model(m): - raise TypeError(f"Unsupported model type in '{f.__name__}': {m!r}") + raise TypeError(f"Unsupported model type: {m!r}") data = { "pubsub": pubsub, @@ -91,11 +74,10 @@ def decorator(f: Callable[..., Any]) -> Callable[..., Any]: "dead_letter_topic": dead_letter_topic or (f"{topic}_DEAD" if topic else None), "is_broadcast": broadcast, - "message_schemas": models, # list[type] - "message_types": [m.__name__ for m in models], # list[str] + "message_schemas": models, + "message_types": [m.__name__ for m in models], } - # Attach metadata; deepcopy for defensive isolation setattr(f, "_is_message_handler", True) setattr(f, "_message_router_data", deepcopy(data)) @@ -109,5 +91,82 @@ def decorator(f: Callable[..., Any]) -> Callable[..., Any]: ) return f - # Support both @message_router(...) and bare @message_router usage + return decorator if func is None else decorator(func) + + +def http_router( + func: Optional[Callable[..., Any]] = None, + *, + path: Optional[str] = None, + method: HttpMethod = "POST", + summary: Optional[str] = None, + tags: Optional[List[str]] = None, + response_model: Optional[Type[Any]] = None, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Tag a callable as a **plain-HTTP** endpoint with schema metadata for its JSON body. + + Args: + func (Optional[Callable[..., Any]]): + The function to decorate (if used without parentheses). + path (Optional[str]): + The HTTP path to route to. + method (HttpMethod): + The HTTP method to route to. + summary (Optional[str]): + A short summary of the endpoint. + tags (Optional[List[str]]): + A list of tags for grouping endpoints. + response_model (Optional[Type[Any]]): + The response model class to use for validation. + + Returns: + Callable[[Callable[..., Any]], Callable[..., Any]]: + The decorated function. + """ + + def decorator(f: Callable[..., Any]) -> Callable[..., Any]: + if path is None: + raise ValueError("`@http_router` requires `path`.") + method_upper = method.upper() + + try: + hints = get_type_hints(f, globalns=f.__globals__) + except Exception: + logger.debug( + "Failed to fully resolve type hints for %s", f.__name__, exc_info=True + ) + hints = getattr(f, "__annotations__", {}) or {} + + raw_hint = hints.get("request") + models = extract_message_models(raw_hint) if raw_hint is not None else [] + if not models: + raise TypeError( + "`@http_router` requires a type-hinted `request` parameter." + ) + + for m in models: + if not is_supported_model(m): + raise TypeError(f"Unsupported request model type: {m!r}") + + data = { + "path": path, + "method": method_upper, + "summary": summary, + "tags": (tags or []), + "response_model": response_model, + "request_schemas": models, + "request_type_names": [m.__name__ for m in models], + } + setattr(f, "_is_http_handler", True) + setattr(f, "_http_route_data", deepcopy(data)) + logger.debug( + "@http_router: '%s' => models %s (%s %s)", + f.__name__, + [m.__name__ for m in models], + method_upper, + path, + ) + return f + return decorator if func is None else decorator(func) diff --git a/dapr_agents/workflow/mixins/__init__.py b/dapr_agents/workflow/mixins/__init__.py deleted file mode 100644 index 8a414885..00000000 --- a/dapr_agents/workflow/mixins/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .messaging import MessagingMixin -from .pubsub import PubSubMixin -from .service import ServiceMixin -from .state import StateManagementMixin - -__all__ = [ - "StateManagementMixin", - "ServiceMixin", - "MessagingMixin", - "PubSubMixin", -] diff --git a/dapr_agents/workflow/mixins/messaging.py b/dapr_agents/workflow/mixins/messaging.py deleted file mode 100644 index 990c9fc8..00000000 --- a/dapr_agents/workflow/mixins/messaging.py +++ /dev/null @@ -1,87 +0,0 @@ -import logging -from typing import Union - -from pydantic import BaseModel - -logger = logging.getLogger(__name__) - - -class MessagingMixin: - """ - Mixin providing agent messaging capabilities, including broadcast, direct messaging, and agent metadata management. - """ - - async def broadcast_message( - self, - message: Union[BaseModel, dict], - exclude_orchestrator: bool = False, - **kwargs, - ) -> None: - """ - Send a message to all registered agents. - - Args: - message: The message content as a Pydantic model or dictionary. - exclude_orchestrator: If True, excludes orchestrators from receiving the message. - **kwargs: Additional metadata fields to include in the message. - """ - try: - # Skip broadcasting if no broadcast topic is set - if not self.broadcast_topic_name: - logger.info(f"{self.name} has no broadcast topic; skipping broadcast.") - return - # Skip broadcasting if no agents are registered - agents_metadata = self.get_agents_metadata( - exclude_orchestrator=exclude_orchestrator - ) - if not agents_metadata: - logger.warning("No agents available for broadcast.") - return - # Broadcast the message to all agents - logger.info( - f"{self.name} broadcasting message to {self.broadcast_topic_name}." - ) - await self.publish_event_message( - topic_name=self.broadcast_topic_name, - pubsub_name=self.message_bus_name, - source=self.name, - message=message, - **kwargs, - ) - logger.debug(f"{self.name} broadcasted message.") - except Exception as e: - logger.error(f"Failed to broadcast message: {e}", exc_info=True) - - async def send_message_to_agent( - self, name: str, message: Union[BaseModel, dict], **kwargs - ) -> None: - """ - Send a message to a specific agent. - - Args: - name: The name of the target agent. - message: The message content as a Pydantic model or dictionary. - **kwargs: Additional metadata fields to include in the message. - """ - try: - agents_metadata = self.get_agents_metadata() - if name not in agents_metadata: - logger.warning( - f"Target '{name}' is not registered as an agent. Skipping message send." - ) - return - - agent_metadata = agents_metadata[name] - logger.info(f"{self.name} sending message to agent '{name}'.") - await self.publish_event_message( - topic_name=agent_metadata["topic_name"], - pubsub_name=agent_metadata["pubsub_name"], - source=self.name, - message=message, - **kwargs, - ) - logger.debug(f"{self.name} sent message to agent '{name}'.") - except Exception as e: - logger.error( - f"Failed to send message to agent '{name}': {e}", exc_info=True - ) diff --git a/dapr_agents/workflow/mixins/pubsub.py b/dapr_agents/workflow/mixins/pubsub.py deleted file mode 100644 index a550841a..00000000 --- a/dapr_agents/workflow/mixins/pubsub.py +++ /dev/null @@ -1,403 +0,0 @@ -import logging -import json -import asyncio -import inspect -import threading -import functools -from dataclasses import is_dataclass, asdict -from typing import Optional, Any, Dict, Union, Callable - -from pydantic import BaseModel - -from dapr.aio.clients import DaprClient -from dapr.aio.clients.grpc.subscription import Subscription -from dapr.clients.grpc._response import TopicEventResponse -from dapr.clients.grpc.subscription import StreamInactiveError -from dapr.common.pubsub.subscription import StreamCancelledError, SubscriptionMessage -from dapr_agents.workflow.utils.messaging import ( - extract_cloudevent_data, - validate_message_model, -) -from dapr_agents.workflow.utils.core import ( - get_decorated_methods, - is_pydantic_model, - is_valid_routable_model, -) - -logger = logging.getLogger(__name__) - - -class PubSubMixin: - """ - Mixin providing Dapr-based pub/sub messaging, event publishing, and dynamic message routing. - - Features: - - Publishes messages and events to Dapr topics with optional CloudEvent metadata. - - Registers message handlers dynamically using decorated methods. - - Routes incoming messages to handlers based on CloudEvent `type` and message schema. - - Supports Pydantic models, dataclasses, and dictionaries as message payloads. - - Handles asynchronous message processing and workflow invocation. - - Manages topic subscriptions and message dispatch via Dapr client. - """ - - async def serialize_message(self, message: Any) -> str: - """ - Serializes a message to JSON format. - - Args: - message (Any): The message content to serialize. - - Returns: - str: JSON string of the message. - - Raises: - ValueError: If the message is not serializable. - """ - try: - return json.dumps(message if message is not None else {}) - except TypeError as te: - logger.error(f"Failed to serialize message: {message}. Error: {te}") - raise ValueError(f"Message contains non-serializable data: {te}") - - async def publish_message( - self, - pubsub_name: str, - topic_name: str, - message: Any, - metadata: Optional[Dict[str, Any]] = None, - ) -> None: - """ - Publishes a message to a specific topic with optional metadata. - - Args: - pubsub_name (str): The pub/sub component to use. - topic_name (str): The topic to publish the message to. - message (Any): The message content, can be None or any JSON-serializable type. - metadata (Optional[Dict[str, Any]]): Additional metadata to include in the publish event. - - Raises: - ValueError: If the message contains non-serializable data. - Exception: If publishing the message fails. - """ - try: - json_message = await self.serialize_message(message) - - # TODO: retry publish should be configurable - async with DaprClient() as client: - await client.publish_event( - pubsub_name=pubsub_name or self.message_bus_name, - topic_name=topic_name, - data=json_message, - data_content_type="application/json", - publish_metadata=metadata or {}, - ) - - logger.debug( - f"Message successfully published to topic '{topic_name}' on pub/sub '{pubsub_name}'." - ) - logger.debug(f"Serialized Message: {json_message}, Metadata: {metadata}") - except Exception as e: - logger.error( - f"Error publishing message to topic '{topic_name}' on pub/sub '{pubsub_name}'. " - f"Message: {message}, Metadata: {metadata}, Error: {e}" - ) - raise Exception( - f"Failed to publish message to topic '{topic_name}' on pub/sub '{pubsub_name}': {str(e)}" - ) - - async def publish_event_message( - self, - topic_name: str, - pubsub_name: str, - source: str, - message: Union[BaseModel, dict, Any], - message_type: Optional[str] = None, - **kwargs, - ) -> None: - """ - Publishes an event message to a specified topic with dynamic metadata. - - Args: - topic_name (str): The topic to publish the message to. - pubsub_name (str): The pub/sub component to use. - source (str): The source of the message (e.g., service or agent name). - message (Union[BaseModel, dict, dataclass, Any]): The message content, as a Pydantic model, dictionary, or dataclass instance. - message_type (Optional[str]): The type of the message. Required if `message` is a dictionary. - **kwargs: Additional metadata fields to include in the message. - """ - if isinstance(message, BaseModel): - message_type = message_type or message.__class__.__name__ - message_dict = message.model_dump() - - elif isinstance(message, dict): - if not message_type: - raise ValueError( - "message_type must be provided when message is a dictionary." - ) - message_dict = message - - elif is_dataclass(message): - message_type = message_type or message.__class__.__name__ - message_dict = asdict(message) - - else: - raise ValueError( - "Message must be a Pydantic BaseModel, a dictionary, or a dataclass instance." - ) - - metadata = { - "cloudevent.type": message_type, - "cloudevent.source": source, - } - metadata.update(kwargs) - - logger.debug( - f"{source} preparing to publish '{message_type}' to topic '{topic_name}'." - ) - logger.debug(f"Message: {message_dict}, Metadata: {metadata}") - - await self.publish_message( - topic_name=topic_name, - pubsub_name=pubsub_name or self.message_bus_name, - message=message_dict, - metadata=metadata, - ) - - logger.info(f"{source} published '{message_type}' to topic '{topic_name}'.") - - def register_message_routes(self) -> None: - """ - Registers message handlers dynamically by subscribing once per topic. - Incoming messages are dispatched by CloudEvent `type` to the appropriate handler. - - This function: - - Scans all class methods for the `@message_router` decorator. - - Extracts routing metadata and message model schemas. - - Wraps each handler and maps it by `(pubsub_name, topic_name)` and schema name. - - Ensures only one handler per schema per topic is allowed. - """ - message_handlers = get_decorated_methods(self, "_is_message_handler") - - for method_name, method in message_handlers.items(): - try: - router_data = method._message_router_data.copy() - pubsub_name = router_data.get("pubsub") or self.message_bus_name - is_broadcast = router_data.get("is_broadcast", False) - topic_name = router_data.get("topic") or ( - self.broadcast_topic_name if is_broadcast else self.name - ) - message_schemas = router_data.get("message_schemas", []) - - if not message_schemas: - raise ValueError( - f"No message models found for handler '{method_name}'." - ) - - wrapped_method = self._create_wrapped_method(method) - topic_key = (pubsub_name, topic_name) - - self._topic_handlers.setdefault(topic_key, {}) - - for schema in message_schemas: - if not is_valid_routable_model(schema): - raise ValueError( - f"Unsupported message model for handler '{method_name}': {schema}" - ) - - schema_name = schema.__name__ - logger.debug( - f"Registering handler '{method_name}' for topic '{topic_name}' with model '{schema_name}'" - ) - - # Prevent multiple handlers for the same schema - if schema_name in self._topic_handlers[topic_key]: - raise ValueError( - f"Duplicate handler for model '{schema_name}' on topic '{topic_name}'. " - f"Each model can only be handled by one function per topic." - ) - - self._topic_handlers[topic_key][schema_name] = { - "schema": schema, - "handler": wrapped_method, - } - - except Exception as e: - logger.error( - f"Failed to register handler '{method_name}': {e}", exc_info=True - ) - - # Subscribe once per topic - for pubsub_name, topic_name in self._topic_handlers.keys(): - if topic_name: - # Prevent subscribing to empty or None topics - self._subscribe_with_router(pubsub_name, topic_name) - - logger.info("All message routes registered.") - - def _create_wrapped_method(self, method: Callable) -> Callable: - """ - Wraps a message handler method to ensure it runs asynchronously, - with special handling for workflows. - """ - - @functools.wraps(method) - async def wrapped_method(message: dict): - try: - is_workflow = getattr(method, "_is_workflow", False) - message_type = ( - type(message).__name__ - if hasattr(message, "__class__") - else str(type(message)) - ) - logger.debug( - f"PubSub routing for {method.__name__}: _is_workflow={is_workflow}, message_type={message_type}" - ) - if is_workflow: - workflow_name = getattr(method, "_workflow_name", method.__name__) - # If the message is a Pydantic model, extract metadata and convert to dict - if is_pydantic_model(type(message)): - # Extract metadata if available - metadata = getattr(message, "_message_metadata", None) - # Convert to dict for workflow input - message_dict = message.model_dump() - if metadata is not None: - # Include metadata in the message dict - message_dict["_message_metadata"] = metadata - message = message_dict - - # Prevent triggering multiple orchestrator workflows if one is already running - if ( - workflow_name == "OrchestratorWorkflow" - or workflow_name == "main_workflow" - ): - triggering_workflow_id = message.get("workflow_instance_id") - if triggering_workflow_id: - if hasattr( - self, "_does_workflow_exist" - ) and self._does_workflow_exist(triggering_workflow_id): - logger.info( - f"Triggering workflow {triggering_workflow_id} is still running. Skipping new orchestrator instance." - ) - return None - - # Invoke the workflow - await self.run_and_monitor_workflow_async( - workflow_name, input=message - ) - return None - - if inspect.iscoroutinefunction(method): - return await method(message=message) - else: - return method(message=message) - - except Exception as e: - logger.error( - f"Error invoking handler '{method.__name__}': {e}", exc_info=True - ) - return None - - return wrapped_method - - def _subscribe_with_router(self, pubsub_name: str, topic_name: str): - subscription: Subscription = self._dapr_client.subscribe( - pubsub_name, topic_name - ) - loop = asyncio.get_running_loop() - - def stream_messages(sub: Subscription): - while True: - try: - for message in sub: - if message: - try: - future = asyncio.run_coroutine_threadsafe( - self._route_message( - pubsub_name, topic_name, message - ), - loop, - ) - response = future.result() - sub.respond(message, response.status) - except Exception as e: - print(f"Error handling message: {e}") - else: - continue - except (StreamInactiveError, StreamCancelledError): - break - - def close_subscription(): - subscription.close() - - self._subscriptions[(pubsub_name, topic_name)] = close_subscription - threading.Thread( - target=stream_messages, args=(subscription,), daemon=True - ).start() - - # TODO: retry setup should be configurable - async def _route_message( - self, pubsub_name: str, topic_name: str, message: SubscriptionMessage - ) -> TopicEventResponse: - """ - Routes an incoming message to the correct handler based on CloudEvent `type`. - - Args: - pubsub_name (str): The name of the pubsub component. - topic_name (str): The topic from which the message was received. - message (SubscriptionMessage): The incoming Dapr message. - - Returns: - TopicEventResponse: The response status for the message (success, drop, retry). - """ - try: - handler_map = self._topic_handlers.get((pubsub_name, topic_name), {}) - if not handler_map: - logger.warning( - f"No handlers for topic '{topic_name}' on pubsub '{pubsub_name}'. Dropping message." - ) - return TopicEventResponse("drop") - - # Step 1: Extract CloudEvent metadata and data - event_data, metadata = extract_cloudevent_data(message) - event_type = metadata.get("type") - - # Step 2: Find the handler for the event type - route_entry = handler_map.get(event_type) - if not route_entry: - # If no handler matches the event type, log and drop the message - logger.warning( - f"No handler matched CloudEvent type '{event_type}' on topic '{topic_name}'" - ) - return TopicEventResponse("drop") - - schema = route_entry["schema"] - handler = route_entry["handler"] - - try: - # Step 3: Validate the message against the schema - parsed_message = validate_message_model(schema, event_data) - # Step 4: Attach metadata to the parsed message - if isinstance(parsed_message, dict): - parsed_message["_message_metadata"] = metadata - else: - setattr(parsed_message, "_message_metadata", metadata) - - logger.info( - f"Dispatched to handler '{handler.__name__}' for event type '{event_type}'" - ) - # Step 5: Call the handler with the parsed message - result = await handler(parsed_message) - if result is not None: - return TopicEventResponse("success"), result - - return TopicEventResponse("success") - - except Exception as e: - logger.warning( - f"Failed to validate message against schema '{schema.__name__}': {e}" - ) - return TopicEventResponse("retry") - - except Exception as e: - logger.error(f"Unexpected error during message routing: {e}", exc_info=True) - return TopicEventResponse("retry") diff --git a/dapr_agents/workflow/mixins/service.py b/dapr_agents/workflow/mixins/service.py deleted file mode 100644 index dfd155be..00000000 --- a/dapr_agents/workflow/mixins/service.py +++ /dev/null @@ -1,216 +0,0 @@ -import asyncio -import inspect -import logging -from typing import Optional -from dapr_agents.types.workflow import DaprWorkflowStatus -from dapr_agents.utils import SignalHandlingMixin - -logger = logging.getLogger(__name__) - - -class ServiceMixin(SignalHandlingMixin): - """ - Mixin providing FastAPI service integration and lifecycle management for agentic workflows. - - Features: - - Initializes and manages a FastAPI server for agent workflows. - - Registers HTTP endpoints for workflow status, initiation, and custom user routes. - - Supports both FastAPI service mode and headless (no HTTP server) operation. - - Handles graceful shutdown via signal handling and resource cleanup. - - Integrates workflow execution via HTTP POST and custom endpoints. - - Manages subscription cleanup and workflow runtime shutdown on service stop. - - Provides property access to the FastAPI app instance. - """ - - wf_runtime_is_running: Optional[bool] = None - - @property - def app(self): - """ - Return the FastAPI application initialized via ``as_service``. - - Returns: - FastAPI: The FastAPI app instance. - - Raises: - RuntimeError: If the FastAPI server has not been initialized. - """ - if self._http_server: - return self._http_server.app - raise RuntimeError("FastAPI server not initialized. Call `as_service()` first.") - - def register_routes(self): - """ - Register user-defined FastAPI routes decorated with ``@route``. - """ - for name, method in inspect.getmembers(self, predicate=inspect.ismethod): - if getattr(method, "_is_fastapi_route", False): - path = getattr(method, "_route_path") - method_type = getattr(method, "_route_method", "GET") - extra_kwargs = getattr(method, "_route_kwargs", {}) - logger.info(f"Registering route {method_type} {path} -> {name}") - self.app.add_api_route( - path, method, methods=[method_type], **extra_kwargs - ) - - def as_service(self, port: Optional[int] = None, host: str = "0.0.0.0"): - """ - Enable FastAPI service mode for the agent. - - Args: - port: Required port number. - host: Host address to bind to. - - Returns: - self - - Raises: - ValueError: If port is not provided. - """ - from dapr_agents.service.fastapi import FastAPIServerBase - - if port is None: - raise ValueError("Port must be provided as a parameter") - - self._http_server = FastAPIServerBase( - service_name=self.name, - service_port=port, - service_host=host, - ) - - self.app.add_api_route("/status", lambda: {"ok": True}) - self.app.add_api_route( - "/start-workflow", self.run_workflow_from_request, methods=["POST"] - ) - - self.register_routes() - return self - - async def graceful_shutdown(self) -> None: - """ - Perform graceful shutdown operations for the service. - """ - await self.stop() - - async def start(self): - """ - Start the agent workflow service. - - This method starts the FastAPI server or runs in headless mode. - """ - if self._is_running: - logger.warning( - "Service is already running. Ignoring duplicate start request." - ) - return - - logger.info("Starting Agent Workflow Service...") - self._shutdown_event.clear() - - try: - if not hasattr(self, "_http_server") or self._http_server is None: - logger.info("Running in headless mode.") - # Set up signal handlers using the mixin - self.setup_signal_handlers() - self.register_message_routes() - self._is_running = True - # Wait for shutdown signal - await self.wait_for_shutdown() - else: - logger.info("Running in FastAPI service mode.") - self.register_message_routes() - self._is_running = True - await self._http_server.start() - except asyncio.CancelledError: - logger.info("Service received cancellation signal.") - finally: - await self.stop() - - async def stop(self): - """ - Stop the agent workflow service and clean up resources. - """ - if not self._is_running: - logger.warning("Service is not running. Ignoring stop request.") - return - - logger.info("Stopping Agent Workflow Service...") - - # Save state before shutting down to ensure persistence and agent durability to properly rerun after being stoped - try: - if hasattr(self, "save_state") and hasattr(self, "state"): - # Graceful shutdown compensation: Save incomplete instance if it exists - if hasattr(self, "workflow_instance_id") and self.workflow_instance_id: - if self.workflow_instance_id not in self.state.get("instances", {}): - # This instance was never saved, add it as incomplete - from datetime import datetime, timezone - - incomplete_entry = { - "messages": [], - "start_time": datetime.now(timezone.utc).isoformat(), - "source": "graceful_shutdown", - "triggering_workflow_instance_id": None, - "workflow_name": getattr(self, "_workflow_name", "Unknown"), - "dapr_status": DaprWorkflowStatus.PENDING, - "suspended_reason": "app_terminated", - "trace_context": {"needs_agent_span_on_resume": True}, - } - self.state.setdefault("instances", {})[ - self.workflow_instance_id - ] = incomplete_entry - logger.info( - f"Added incomplete instance {self.workflow_instance_id} during graceful shutdown" - ) - else: - # Mark running instances as needing AGENT spans on resume - if "instances" in self.state: - for instance_id, instance_data in self.state[ - "instances" - ].items(): - # Only mark instances that are still running (no end_time) - if not instance_data.get("end_time"): - instance_data[ - "dapr_status" - ] = DaprWorkflowStatus.SUSPENDED - instance_data["suspended_reason"] = "app_terminated" - - # Mark trace context for AGENT span creation on resume - if instance_data.get("trace_context"): - instance_data["trace_context"][ - "needs_agent_span_on_resume" - ] = True - logger.debug( - f"Marked trace context for AGENT span creation on resume for {instance_id}" - ) - - logger.info( - f"Marked instance {instance_id} as suspended due to app termination" - ) - - self.save_state() - logger.debug("Workflow state saved successfully.") - except Exception as e: - logger.error(f"Failed to save state during shutdown: {e}") - - for (pubsub_name, topic_name), close_fn in self._subscriptions.items(): - try: - logger.info( - f"Unsubscribing from pubsub '{pubsub_name}' topic '{topic_name}'" - ) - close_fn() - except Exception as e: - logger.error(f"Failed to unsubscribe from topic '{topic_name}': {e}") - - self._subscriptions.clear() - - if hasattr(self, "_http_server") and self._http_server: - logger.info("Stopping FastAPI server...") - await self._http_server.stop() - - if getattr(self, "_wf_runtime_is_running", False): - logger.info("Shutting down workflow runtime.") - self.stop_runtime() - self.wf_runtime_is_running = False - - self._is_running = False - logger.info("Agent Workflow Service stopped successfully.") diff --git a/dapr_agents/workflow/mixins/state.py b/dapr_agents/workflow/mixins/state.py deleted file mode 100644 index 03a78362..00000000 --- a/dapr_agents/workflow/mixins/state.py +++ /dev/null @@ -1,252 +0,0 @@ -import json -import logging -import os -import tempfile -import threading -from typing import Optional, Union - -from pydantic import BaseModel, ValidationError - -logger = logging.getLogger(__name__) - -state_lock = threading.Lock() - - -class StateManagementMixin: - """ - Mixin providing workflow state initialization, validation, and persistence. - """ - - def initialize_state(self) -> None: - """ - Initialize workflow state from provided value or storage. - - Raises: - RuntimeError: If state initialization or loading from storage fails. - """ - try: - if self.state is None: - logger.debug("No user-provided state. Attempting to load from storage.") - self.state = self.load_state() - - if isinstance(self.state, BaseModel): - logger.debug( - "User provided a state as a Pydantic model. Converting to dict." - ) - self.state = self.state.model_dump() - - if not isinstance(self.state, dict): - raise TypeError( - f"Invalid state type: {type(self.state)}. Expected dict." - ) - - logger.debug(f"Workflow state initialized with {len(self.state)} key(s).") - self.save_state() - except Exception as e: - raise RuntimeError(f"Error initializing workflow state: {e}") from e - - def validate_state(self, state_data: dict) -> dict: - """ - Validate the workflow state against ``state_format`` if provided. - - Args: - state_data: The raw state data to validate. - - Returns: - dict: The validated and structured state. - - Raises: - ValidationError: If the state data does not conform to the expected schema. - """ - try: - if not self.state_format: - logger.warning( - "No schema (state_format) provided; returning state as-is." - ) - return state_data - - logger.debug("Validating workflow state against schema.") - validated_state: BaseModel = self.state_format(**state_data) - return validated_state.model_dump() - except ValidationError as e: - raise ValidationError(f"Invalid workflow state: {e.errors()}") from e - - def load_state(self) -> dict: - """ - Load the workflow state from the configured Dapr state store. - - Returns: - dict: The loaded and optionally validated state. - - Raises: - RuntimeError: If the state store is not properly configured. - TypeError: If the retrieved state is not a dictionary. - ValidationError: If state schema validation fails. - """ - try: - if ( - not self._state_store_client - or not self.state_store_name - or not self.state_key - ): - logger.error("State store is not configured. Cannot load state.") - raise RuntimeError( - "State store is not configured. Please provide 'state_store_name' and 'state_key'." - ) - - # For durable agents, always load from database to ensure it's the source of truth - has_state, state_data = self._state_store_client.try_get_state( - self.state_key - ) - if has_state and state_data: - logger.debug( - f"Existing state found for key '{self.state_key}'. Validating it." - ) - if not isinstance(state_data, dict): - raise TypeError( - f"Invalid state type retrieved: {type(state_data)}. Expected dict." - ) - - # Set self.state to the loaded data - if self.state_format: - loaded_state = self.validate_state(state_data) - else: - loaded_state = state_data - - self.state = loaded_state - logger.debug(f"Set self.state to loaded data: {self.state}") - - return loaded_state - - logger.debug( - f"No existing state found for key '{self.state_key}'. Initializing empty state." - ) - return {} - except Exception as e: - logger.error(f"Failed to load state for key '{self.state_key}': {e}") - raise RuntimeError(f"Error loading workflow state: {e}") from e - - def get_local_state_file_path(self) -> str: - """ - Return the file path for saving the local state. - - Returns: - str: The absolute path to the local state file. - """ - directory = self.local_state_path or os.getcwd() - os.makedirs(directory, exist_ok=True) - return os.path.join(directory, f"{self.state_key}.json") - - def save_state_to_disk( - self, state_data: str, filename: Optional[str] = None - ) -> None: - """ - Safely save the workflow state to a local JSON file. - - Args: - state_data: The state data to save (as JSON string or dict). - filename: Optional filename for the state file. - - Raises: - RuntimeError: If saving to disk fails. - """ - try: - save_directory = self.local_state_path or os.getcwd() - os.makedirs(save_directory, exist_ok=True) - filename = filename or f"{self.name}_state.json" - file_path = os.path.join(save_directory, filename) - - with tempfile.NamedTemporaryFile( - "w", dir=save_directory, delete=False - ) as tmp_file: - tmp_file.write(state_data) - temp_path = tmp_file.name - - with state_lock: - existing_state = {} - if os.path.exists(file_path): - with open(file_path, "r", encoding="utf-8") as file: - try: - existing_state = json.load(file) - except json.JSONDecodeError: - logger.warning( - "Existing state file is corrupt or empty. Overwriting." - ) - - new_state = ( - json.loads(state_data) - if isinstance(state_data, str) - else state_data - ) - merged_state = {**existing_state, **new_state} - - with open(temp_path, "w", encoding="utf-8") as file: - json.dump(merged_state, file, indent=4) - - os.replace(temp_path, file_path) - - logger.debug(f"Workflow state saved locally at '{file_path}'.") - except Exception as e: - logger.error(f"Failed to save workflow state to disk: {e}") - raise RuntimeError(f"Error saving workflow state to disk: {e}") - - def save_state( - self, - state: Optional[Union[dict, BaseModel, str]] = None, - force_reload: bool = False, - ) -> None: - """ - Save the current workflow state to Dapr and optionally to disk. - - Args: - state: The new state to save. If not provided, saves the existing state. - force_reload: If True, reloads the state from the store after saving. - - Raises: - RuntimeError: If the state store is not configured. - TypeError: If the provided state is not a supported type. - ValueError: If the provided state is a string but not valid JSON. - """ - try: - if ( - not self._state_store_client - or not self.state_store_name - or not self.state_key - ): - logger.error("State store is not configured. Cannot save state.") - raise RuntimeError( - "State store is not configured. Please provide 'state_store_name' and 'state_key'." - ) - - self.state = state or self.state - if not self.state: - logger.warning("Skipping state save: Empty state.") - return - - if isinstance(self.state, BaseModel): - state_to_save = self.state.model_dump_json() - elif isinstance(self.state, dict): - state_to_save = json.dumps(self.state) - elif isinstance(self.state, str): - try: - json.loads(self.state) - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON string provided as state: {e}") - state_to_save = self.state - else: - raise TypeError( - f"Invalid state type: {type(self.state)}. Expected dict, BaseModel, or JSON string." - ) - - self._state_store_client.save_state(self.state_key, state_to_save) - logger.debug(f"Successfully saved state for key '{self.state_key}'.") - - if self.save_state_locally: - self.save_state_to_disk(state_data=state_to_save) - - if force_reload: - self.state = self.load_state() - logger.debug(f"State reloaded after saving for key '{self.state_key}'.") - except Exception as e: - logger.error(f"Failed to save state for key '{self.state_key}': {e}") - raise diff --git a/dapr_agents/workflow/orchestrators/README.md b/dapr_agents/workflow/orchestrators/README.md deleted file mode 100644 index a168d934..00000000 --- a/dapr_agents/workflow/orchestrators/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Orchestrators - -Available Workflow options to orchestrate communication between agents: - -- LLM-based: Uses a large language model (e.g., GPT-4o) to determine the most suitable agent based on the message and context. -- Random: Selects an agent randomly for each task. -- RoundRobin: Cycles through agents in a fixed order, ensuring each agent gets an equal opportunity to process tasks. - -## Visual representation of each orchestration option: -![Orchestrator workflows visualized](./orchestratorWorkflows.png) \ No newline at end of file diff --git a/dapr_agents/workflow/orchestrators/base.py b/dapr_agents/workflow/orchestrators/base.py deleted file mode 100644 index 8ec06f86..00000000 --- a/dapr_agents/workflow/orchestrators/base.py +++ /dev/null @@ -1,76 +0,0 @@ -import logging -from abc import ABC, abstractmethod -from typing import Any, Optional - -from dapr.ext.workflow import DaprWorkflowContext -from pydantic import Field, model_validator - -from dapr_agents.workflow.agentic import AgenticWorkflow - -logger = logging.getLogger(__name__) - - -class OrchestratorWorkflowBase(AgenticWorkflow, ABC): - orchestrator_topic_name: Optional[str] = Field( - None, - description="The topic name dedicated to this specific orchestrator, derived from the orchestrator's name if not provided.", - ) - - @model_validator(mode="before") - def set_orchestrator_topic_name(cls, values: dict): - # Derive orchestrator_topic_name from agent name - if not values.get("orchestrator_topic_name") and values.get("name"): - values["orchestrator_topic_name"] = values["name"] - - return values - - def model_post_init(self, __context: Any) -> None: - """ - Register agentic workflow. - """ - - # Complete post-initialization - super().model_post_init(__context) - - # Prepare agent metadata - self._agent_metadata = { - "name": self.name, - "topic_name": self.orchestrator_topic_name, - "pubsub_name": self.message_bus_name, - "orchestrator": True, - } - - # Register agent metadata - self.register_agentic_system() - - # Start the runtime if it's not already running - self.start_runtime() - - @abstractmethod - def main_workflow(self, ctx: DaprWorkflowContext, message: Any) -> Any: - """ - Execute the primary workflow that coordinates agent interactions. - - Args: - ctx (DaprWorkflowContext): The workflow execution context - message (Any): The input for this workflow iteration - - Returns: - Any: The workflow result or continuation - """ - pass - - @abstractmethod - async def process_agent_response(self, message: Any) -> None: - """Process responses from agents.""" - pass - - @abstractmethod - async def broadcast_message_to_agents(self, **kwargs) -> None: - """Broadcast a message to all registered agents.""" - pass - - @abstractmethod - async def trigger_agent(self, name: str, instance_id: str, **kwargs) -> None: - """Trigger a specific agent to perform an action.""" - pass diff --git a/dapr_agents/workflow/orchestrators/llm/orchestrator.py b/dapr_agents/workflow/orchestrators/llm/orchestrator.py deleted file mode 100644 index aad561b5..00000000 --- a/dapr_agents/workflow/orchestrators/llm/orchestrator.py +++ /dev/null @@ -1,1568 +0,0 @@ -import json -import logging -from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional - -from dapr.ext.workflow import DaprWorkflowContext -from pydantic import Field - -from dapr_agents.workflow.decorators import message_router, task, workflow -from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase -from dapr_agents.workflow.orchestrators.llm.prompts import ( - NEXT_STEP_PROMPT, - PROGRESS_CHECK_PROMPT, - SUMMARY_GENERATION_PROMPT, - TASK_INITIAL_PROMPT, - TASK_PLANNING_PROMPT, -) -from dapr_agents.workflow.orchestrators.llm.schemas import ( - AgentTaskResponse, - BroadcastMessage, - IterablePlanStep, - NextStep, - ProgressCheckOutput, - TriggerAction, - InternalTriggerAction, - schemas, -) -from dapr_agents.workflow.orchestrators.llm.state import ( - LLMWorkflowEntry, - LLMWorkflowMessage, - LLMWorkflowState, - PlanStep, - TaskResult, -) -from dapr_agents.workflow.orchestrators.llm.utils import ( - find_step_in_plan, - restructure_plan, - update_step_statuses, -) -from dapr_agents.memory import ConversationDaprStateMemory - -logger = logging.getLogger(__name__) - - -class LLMOrchestrator(OrchestratorWorkflowBase): - """ - Implements an agentic workflow where an LLM dynamically selects the next speaker. - The workflow iterates through conversations, updating its state and persisting messages. - - Uses the `continue_as_new` pattern to restart the workflow with updated input at each iteration. - """ - - workflow_instance_id: Optional[str] = Field( - default=None, - description="The current workflow instance ID for this orchestrator.", - ) - memory: Optional[ConversationDaprStateMemory] = Field( - default=None, - description="Persistent memory with session-based state hydration.", - ) - - def model_post_init(self, __context: Any) -> None: - """ - Initializes and configures the LLM-based workflow service. - """ - - # Call OrchestratorWorkflowBase's model_post_init first to initialize state store and other dependencies - # This will properly load state from storage if it exists - super().model_post_init(__context) - - self._workflow_name = "OrchestratorWorkflow" - # TODO(@Sicoyle): fix this later!! - self._is_orchestrator = True # Flag for PubSub deduplication to prevent orchestrator workflows from being triggered multiple times - - if not self.state: - logger.debug("No state found, initializing empty state") - self.state = {"instances": {}} - - if self.memory is not None: - self.memory = ConversationDaprStateMemory( - store_name=self.memory.store_name, - session_id=f"{self.name or 'orchestrator'}_session", - ) - - print(f"sam memory store name is {self.memory.store_name}") - logger.info(f"sam memory store name is {self.memory.store_name}") - - # Load the current workflow instance ID from state using session_id) - if self.state and self.state.get("instances"): - logger.debug(f"Found {len(self.state['instances'])} instances in state") - - current_session_id = ( - self.memory.session_id - if self.memory - else f"{self.name}_default_session" - ) - for instance_id, instance_data in self.state["instances"].items(): - stored_workflow_name = instance_data.get("workflow_name") - stored_session_id = instance_data.get("session_id") - logger.debug( - f"Instance {instance_id}: workflow_name={stored_workflow_name}, session_id={stored_session_id}, current_workflow_name={self._workflow_name}, current_session_id={current_session_id}" - ) - if ( - stored_workflow_name == self._workflow_name - and stored_session_id == current_session_id - ): - self.workflow_instance_id = instance_id - logger.debug( - f"Loaded current workflow instance ID from state using session_id: {instance_id}" - ) - break - else: - logger.debug("No instances found in state or state is empty") - - # Sync workflow state with Dapr runtime after loading - # This ensures our database reflects the actual state of resumed workflows - self._sync_workflow_state_after_startup() - - def _convert_plan_objects_to_dicts( - self, plan_objects: List[Any] - ) -> List[Dict[str, Any]]: - """ - Convert plan objects to dictionaries for JSON serialization. - Handles both Pydantic models and regular dictionaries. - """ - result = [] - for obj in plan_objects: - if hasattr(obj, "model_dump"): - # Pydantic model - result.append(obj.model_dump()) - elif isinstance(obj, dict): - # Already a dictionary - result.append(obj) - else: - # Fallback: try to convert to dict - result.append(dict(obj) if hasattr(obj, "__dict__") else obj) - return result - - def _does_workflow_exist(self, instance_id: str) -> bool: - """ - Check if a workflow instance exists and is accessible via the Dapr client. - - This function attempts to retrieve the workflow metadata from Dapr. A successful - response indicates the workflow exists in Dapr's state store, while failures - (e.g., not found errors) indicate the workflow is no longer accessible. - - Args: - instance_id (str): The workflow instance ID to check - - Returns: - bool: True if the workflow exists and is accessible, False if not found or on error - """ - try: - # Use Dapr client to get workflow instance status - response = self._dapr_client.get_workflow(instance_id=instance_id) - # If we get a response, the workflow exists and is accessible - return response is not None - except Exception as e: - logger.debug(f"Workflow {instance_id} not found or not accessible: {e}") - return False - - @message_router - async def handle_external_trigger(self, message: TriggerAction): - """ - Handle external TriggerAction messages from end users. - This starts a new orchestrator workflow. - """ - try: - logger.info(f"External trigger received: {message.task}") - await self.run_and_monitor_workflow_async( - workflow="OrchestratorWorkflow", input=message - ) - except Exception as e: - logger.error(f"Error handling external trigger: {e}", exc_info=True) - - @workflow(name="OrchestratorWorkflow") - # TODO: set retry policies on the activities! - # TODO: utilize prompt verdict value of failed as we do not currently use. - # https://github.com/dapr/dapr-agents/pull/136#discussion_r2175751545 - def main_workflow(self, ctx: DaprWorkflowContext, message: TriggerAction): - """ - Executes an LLM-driven agentic workflow where the next agent is dynamically selected - based on task progress. Runs for up to `self.max_iterations` turns, then summarizes. - - Args: - ctx (DaprWorkflowContext): The workflow execution context. - message (TriggerAction): Contains the current `task`. - - Returns: - str: The final summary when the workflow terminates. - - Raises: - RuntimeError: If the workflow ends unexpectedly without a final summary. - """ - # Step 1: Retrieve initial task and ensure state entry exists - task = message.get("task") - instance_id = ctx.instance_id - self.state.setdefault("instances", {}).setdefault( - instance_id, LLMWorkflowEntry(input=task).model_dump(mode="json") - ) - # Initialize plan as empty list - it will be set after turn 1 - plan = [] - final_summary: Optional[str] = None - - # Single loop from turn 1 to max_iterations - for turn in range(1, self.max_iterations + 1): - if not ctx.is_replaying: - logger.debug( - f"Workflow turn {turn}/{self.max_iterations} (Instance ID: {instance_id})" - ) - - # Get available agents - agents = yield ctx.call_activity(self.get_available_agents) - - # On turn 1, atomically generate plan and broadcast task - if turn == 1: - if not ctx.is_replaying: - logger.info(f"Initial message from User -> {self.name}") - - init_result = yield ctx.call_activity( - self.initialize_workflow_with_plan, - input={ - "instance_id": instance_id, - "task": task, - "agents": agents, - "wf_time": ctx.current_utc_datetime.isoformat(), - }, - ) - logger.info(f"Workflow initialized with plan: {init_result['status']}") - plan = init_result["plan"] - - # Determine next step and dispatch - # Plan is now always a list of dictionaries after turn 1 - plan_objects = plan if plan else [] - - # If plan is empty, read from workflow state - if not plan_objects: - plan_objects = self.state["instances"][instance_id].get("plan", []) - plan = plan_objects - next_step = yield ctx.call_activity( - self.generate_next_step, - input={ - "task": task, - "agents": agents, - "plan": json.dumps( - self._convert_plan_objects_to_dicts(plan_objects), indent=2 - ), - "next_step_schema": schemas.next_step, - }, - ) - # Additional Properties from NextStep - next_agent = next_step["next_agent"] - instruction = next_step["instruction"] - step_id = next_step.get("step", None) - substep_id = next_step.get("substep", None) - - # Validate Step Before Proceeding - valid_step = yield ctx.call_activity( - self.validate_next_step, - input={ - "instance_id": instance_id, - "plan": self._convert_plan_objects_to_dicts(plan_objects), - "step": step_id, - "substep": substep_id, - }, - ) - - if valid_step: - # Atomically execute agent task and mark step as in_progress - execution_result = yield ctx.call_activity( - self.execute_agent_task_with_progress_tracking, - input={ - "instance_id": instance_id, - "next_agent": next_agent, - "step_id": step_id, - "substep_id": substep_id, - "instruction": instruction, - "task": task, - "plan_objects": self._convert_plan_objects_to_dicts( - plan_objects - ), - }, - ) - plan_objects = execution_result["plan"] - - # Wait for agent response or timeout - if not ctx.is_replaying: - logger.debug(f"Waiting for {next_agent}'s response...") - - event_data = ctx.wait_for_external_event("AgentTaskResponse") - timeout_task = ctx.create_timer(timedelta(seconds=self.timeout)) - any_results = yield self.when_any([event_data, timeout_task]) - - # Handle Agent Response or Timeout - if any_results == timeout_task: - if not ctx.is_replaying: - logger.warning( - f"Agent response timed out (Iteration: {turn}, Instance ID: {instance_id})." - ) - task_results = { - "name": self.name, - "role": "user", - "content": f"Timeout occurred. {next_agent} did not respond on time. We need to try again...", - } - else: - task_results = yield event_data - if not ctx.is_replaying: - logger.info(f"{task_results['name']} sent a response.") - - # Atomically process agent response, update history, check progress, and update plan - response_result = yield ctx.call_activity( - self.process_agent_response_with_progress, - input={ - "instance_id": instance_id, - "agent": next_agent, - "step_id": step_id, - "substep_id": substep_id, - "task_results": task_results, - "task": task, - "plan_objects": self._convert_plan_objects_to_dicts( - plan_objects - ), - }, - ) - - # Update local variables with results - plan_objects = response_result["plan"] - verdict = response_result["verdict"] - if not ctx.is_replaying: - logger.debug(f"Progress verdict: {verdict}") - logger.debug(f"Status updates: {response_result['status_updates']}") - logger.debug(f"Plan updates: {response_result['plan_updates']}") - - # Update the plan variable to reflect the current state - plan = plan_objects - else: - if not ctx.is_replaying: - logger.warning( - f"Invalid step {step_id}/{substep_id} in plan for instance {instance_id}. Retrying..." - ) - # Recovery Task: No updates, just iterate again - verdict = "continue" - task_results = { - "name": self.name, - "role": "user", - "content": f"Step {step_id}, Substep {substep_id} does not exist in the plan. Adjusting workflow...", - } - - # Process progress suggestions and next iteration count - if verdict != "continue" or turn == self.max_iterations: - if not ctx.is_replaying: - finale = ( - "max_iterations_reached" - if turn == self.max_iterations - else verdict - ) - logger.info(f"Ending workflow with verdict: {finale}") - - # Atomically generate summary and finalize workflow - final_summary = yield ctx.call_activity( - self.finalize_workflow_with_summary, - input={ - "instance_id": instance_id, - "task": task, - "verdict": verdict, - "plan_objects": self._convert_plan_objects_to_dicts( - plan_objects - ), - "step_id": step_id, - "substep_id": substep_id, - "agent": next_agent, - "result": task_results["content"], - "wf_time": ctx.current_utc_datetime.isoformat(), - }, - ) - - # Return the final summary - this should terminate the workflow - if not ctx.is_replaying: - logger.info(f"Workflow {instance_id} finalized.") - return final_summary - else: - # --- PREPARE NEXT TURN --- - task = task_results["content"] - - # Should never reach here - raise RuntimeError(f"OrchestratorWorkflow {instance_id} exited without summary") - - @task - def get_available_agents(self) -> str: - """ - Retrieves and formats metadata about available agents. - - Returns: - str: A formatted string listing the available agents and their roles. - """ - agents_metadata = self.get_agents_metadata(exclude_orchestrator=True) - if not agents_metadata: - return "No available agents to assign tasks." - - # Format agent details into a readable string - agent_list = "\n".join( - [ - f"- {name}: {metadata.get('role', 'Unknown role')} (Goal: {metadata.get('goal', 'Unknown')})" - for name, metadata in agents_metadata.items() - ] - ) - - return agent_list - - @task(description=NEXT_STEP_PROMPT) - async def generate_next_step( - self, task: str, agents: str, plan: str, next_step_schema: str - ) -> NextStep: - """ - Determines the next agent to respond in a workflow. - - Args: - task (str): The current task description. - agents (str): A list of available agents. - plan (str): The structured execution plan. - next_step_schema (str): The next step schema. - - Returns: - NextStep: A structured response with the next agent, an instruction, and step ids. - """ - # Use the original prompt template - prompt = NEXT_STEP_PROMPT.format( - task=task, agents=agents, plan=plan, next_step_schema=next_step_schema - ) - - # Call LLM with prompt - response = self.llm.generate( - inputs=[{"role": "user", "content": prompt}], - response_format=NextStep, - structured_mode="json", - ) - - # Parse the response - if hasattr(response, "choices") and response.choices: - # If it's still a raw response, parse it - next_step_data = response.choices[0].message.content - logger.debug(f"Next step generation response: {next_step_data}") - next_step_dict = json.loads(next_step_data) - return NextStep(**next_step_dict) - else: - # If it's already a Pydantic model - return response - - @task - async def validate_next_step( - self, - instance_id: str, - plan: List[Dict[str, Any]], - step: int, - substep: Optional[float], - ) -> bool: - """ - Validates if the next step exists in the current execution plan. - - Args: - instance_id (str): The workflow instance ID. - plan (List[Dict[str, Any]]): The current execution plan. - step (int): The step number. - substep (Optional[float]): The substep number. - - Returns: - bool: True if the step exists, False if it does not. - """ - step_entry = find_step_in_plan(plan, step, substep) - if not step_entry: - logger.error( - f"Step {step}, Substep {substep} not found in plan for instance {instance_id}." - ) - return False - return True - - # ============================================================================ - # ABSTRACT METHOD IMPLEMENTATIONS REQUIRED BY OrchestratorWorkflowBase - # ============================================================================ - - async def broadcast_message_to_agents(self, **kwargs) -> None: - """ - Broadcast a message to all registered agents. - Required by OrchestratorWorkflowBase abstract method. - """ - instance_id = kwargs.get("instance_id") - message = kwargs.get("message") - if instance_id and message: - await self.broadcast_message_to_agents_internal(instance_id, message) - - async def trigger_agent(self, name: str, instance_id: str, **kwargs) -> None: - """ - Trigger a specific agent to perform an action. - Required by OrchestratorWorkflowBase abstract method. - """ - step = kwargs.get("step") - substep = kwargs.get("substep") - instruction = kwargs.get("instruction") - plan = kwargs.get("plan", []) - - if step is not None and instruction: - await self.trigger_agent_internal( - instance_id=instance_id, - name=name, - step=step, - substep=substep, - instruction=instruction, - plan=plan, - ) - - # NOTE: The remaining @task decorated functions handle specific workflow activities - - @task - async def update_plan( - self, - instance_id: str, - plan: List[Dict[str, Any]], - status_updates: Optional[List[Dict[str, Any]]] = None, - plan_updates: Optional[List[Dict[str, Any]]] = None, - ): - """ - Updates the execution plan based on status changes and/or plan restructures. - - Args: - instance_id (str): The workflow instance ID. - plan (List[Dict[str, Any]]): The current execution plan. - status_updates (Optional[List[Dict[str, Any]]]): List of updates for step statuses. - plan_updates (Optional[List[Dict[str, Any]]]): List of full step modifications. - - Raises: - ValueError: If a specified step or substep is not found. - """ - logger.debug(f"Updating plan for instance {instance_id}") - - # Step 1: Apply status updates directly to `plan` - if status_updates: - logger.info(f"Applying {len(status_updates)} status updates to plan") - for update in status_updates: - step_id = update["step"] - substep_id = update.get("substep") - new_status = update["status"] - - logger.info( - f"Updating step {step_id}, substep {substep_id} to '{new_status}'" - ) - step_entry = find_step_in_plan(plan, step_id, substep_id) - if not step_entry: - error_msg = f"Step {step_id}, Substep {substep_id} not found in the current plan." - logger.error(error_msg) - raise ValueError(error_msg) - - # Apply status update - step_entry["status"] = new_status - logger.info( - f"Successfully updated status of step {step_id}, substep {substep_id} to '{new_status}'" - ) - - # Step 2: Apply plan restructuring updates (if provided) - if plan_updates: - plan = restructure_plan(plan, plan_updates) - logger.debug( - f"Applied restructuring updates for {len(plan_updates)} steps." - ) - - # Step 3: Apply global consistency checks for statuses - plan = update_step_statuses(plan) - - # Save to state and update workflow - await self.update_workflow_state(instance_id=instance_id, plan=plan) - - logger.info(f"Plan successfully updated for instance {instance_id}") - - @task - async def finish_workflow( - self, - instance_id: str, - plan: List[Dict[str, Any]], - step: int, - substep: Optional[float], - verdict: str, - summary: str, - wf_time: str, - ): - """ - Finalizes the workflow by updating the plan, marking the provided step/substep as completed if applicable, - and storing the summary and verdict. - - Args: - instance_id (str): The workflow instance ID. - plan (List[Dict[str, Any]]): The current execution plan. - step (int): The step that was last worked on. - substep (Optional[float]): The substep that was last worked on (if applicable). - verdict (str): The final workflow verdict (`completed`, `failed`, or `max_iterations_reached`). - summary (str): The generated summary of the workflow execution. - - Returns: - None - """ - status_updates = [] - - if verdict == "completed": - # Find and validate the step or substep - step_entry = find_step_in_plan(plan, step, substep) - if not step_entry: - raise ValueError( - f"Step {step}, Substep {substep} not found in the current plan. Cannot mark as completed." - ) - - # Mark the step or substep as completed - step_entry["status"] = "completed" - status_updates.append( - {"step": step, "substep": substep, "status": "completed"} - ) - - # If it's a substep, check if all sibling substeps are completed - parent_step = find_step_in_plan( - plan, step - ) # Retrieve parent without `substep` - if parent_step: - # Ensure "substeps" is a valid list before iteration - if not isinstance(parent_step.get("substeps"), list): - parent_step["substeps"] = [] - - all_substeps_completed = all( - ss.get("status") == "completed" for ss in parent_step["substeps"] - ) - if all_substeps_completed: - parent_step["status"] = "completed" - status_updates.append({"step": step, "status": "completed"}) - - # Apply updates in one call - if status_updates: - await self.update_plan( - instance_id=instance_id, plan=plan, status_updates=status_updates - ) - - # Store the final summary and verdict in workflow state - await self.update_workflow_state( - instance_id=instance_id, wf_time=wf_time, final_output=summary - ) - - @task - async def initialize_workflow_with_plan( - self, instance_id: str, task: str, agents: str, wf_time: str - ) -> Dict[str, Any]: - """ - Atomically generates a plan and broadcasts it to all agents. - If a plan already exists in state, it will be reused (state hydration). - - Args: - instance_id (str): The workflow instance ID. - task (str): The task description. - agents (str): Formatted list of available agents. - wf_time (str): Workflow timestamp. - - Returns: - Dict containing the generated plan and broadcast status - """ - try: - # Look for existing plan using session_id - existing_plan = None - for stored_instance_id, instance_data in self.state.get( - "instances", {} - ).items(): - stored_session_id = instance_data.get("session_id") - current_session_id = ( - self.memory.session_id - if self.memory - else f"{self.name}_default_session" - ) - if stored_session_id == current_session_id: - existing_plan = instance_data.get("plan", []) - logger.debug( - f"Found existing plan for session_id {current_session_id} in instance {stored_instance_id}" - ) - break - - if existing_plan: - logger.debug( - f"Found existing plan in workflow state, reusing it: {len(existing_plan)} steps" - ) - plan_objects = existing_plan - else: - # Generate new plan using the LLM - logger.debug( - "No existing plan found in workflow state, generating new plan" - ) - response = self.llm.generate( - messages=[ - { - "role": "user", - "content": TASK_PLANNING_PROMPT.format( - task=task, - agents=agents, - plan_schema=schemas.plan, - ), - } - ], - response_format=IterablePlanStep, - structured_mode="json", - ) - - # Parse the response - if isinstance(response, str): - # If it's a raw JSON string - plan_dict = json.loads(response) - plan_objects = [ - PlanStep(**step_dict) - for step_dict in plan_dict.get("objects", []) - ] - elif hasattr(response, "choices") and response.choices: - # If it's an OpenAI-style response with multiple choices - plan_objects = [] - for choice in response.choices: - plan_data = choice.message.content - if isinstance(plan_data, str): - plan_dict = json.loads(plan_data) - plan_objects.extend( - PlanStep(**step_dict) - for step_dict in plan_dict.get("objects", []) - ) - elif hasattr(plan_data, "objects"): - plan_objects.extend(plan_data.objects) - else: - # If it's already a Pydantic model - plan_objects = ( - response.objects if hasattr(response, "objects") else [] - ) - - logger.debug( - f"Plan generation response with {len(plan_objects)} objects: {plan_objects}" - ) - - # Format and broadcast message - plan_dicts = self._convert_plan_objects_to_dicts(plan_objects) - formatted_message = TASK_INITIAL_PROMPT.format( - task=task, agents=agents, plan=json.dumps(plan_dicts, indent=2) - ) - - if not existing_plan: - await self.update_workflow_state( - instance_id=instance_id, plan=plan_dicts, wf_time=wf_time - ) - - # Store the workflow instance ID for session-based state rehydration - self.workflow_instance_id = instance_id - logger.debug(f"Stored workflow instance ID: {instance_id}") - - # Broadcast to agents - await self.broadcast_message_to_agents_internal( - instance_id=instance_id, message=formatted_message - ) - - return {"plan": plan_dicts, "broadcast_sent": True, "status": "success"} - - except Exception as e: - logger.error(f"Failed to initialize workflow: {e}") - # Rollback: clear any partial state - await self.rollback_workflow_initialization(instance_id) - raise - - @task - async def execute_agent_task_with_progress_tracking( - self, - instance_id: str, - next_agent: str, - step_id: int, - substep_id: Optional[float], - instruction: str, - task: str, - plan_objects: List[Dict[str, Any]], - ) -> Dict[str, Any]: - """ - Atomically executes agent task and marks step as in_progress. - - Args: - instance_id (str): The workflow instance ID. - next_agent (str): The agent to trigger. - step_id (int): The step number. - substep_id (Optional[float]): The substep number. - instruction (str): The instruction for the agent. - task (str): The current task description. - plan_objects (List[Dict[str, Any]]): The current plan. - - Returns: - Dict containing updated plan and status - """ - try: - # Trigger agent and mark step as in_progress - updated_plan = await self.trigger_agent_internal( - instance_id=instance_id, - name=next_agent, - step=step_id, - substep=substep_id, - instruction=instruction, - plan=plan_objects, - ) - - return {"plan": updated_plan, "status": "agent_triggered"} - - except Exception as e: - logger.error(f"Failed to execute agent task: {e}") - # Rollback: revert step status - await self.rollback_agent_trigger(instance_id, step_id, substep_id) - raise - - @task - async def process_agent_response_with_progress( - self, - instance_id: str, - agent: str, - step_id: int, - substep_id: Optional[float], - task_results: Dict[str, Any], - task: str, - plan_objects: List[Dict[str, Any]], - ) -> Dict[str, Any]: - """ - Atomically processes agent response, updates history, checks progress, and updates plan. - - Args: - instance_id (str): The workflow instance ID. - agent (str): The agent name. - step_id (int): The step number. - substep_id (Optional[float]): The substep number. - task_results (Dict[str, Any]): The agent's response. - task (str): The current task description. - plan_objects (List[Dict[str, Any]]): The current plan. - - Returns: - Dict containing updated plan, verdict, and status updates - """ - try: - # Step 1: Update task history - await self.update_task_history_internal( - instance_id=instance_id, - agent=agent, - step=step_id, - substep=substep_id, - results=task_results, - plan=plan_objects, - ) - - # Step 2: Check progress using LLM directly - progress_response = self.llm.generate( - messages=[ - { - "role": "user", - "content": PROGRESS_CHECK_PROMPT.format( - task=task, - plan=json.dumps( - self._convert_plan_objects_to_dicts(plan_objects), - indent=2, - ), - step=step_id, - substep=substep_id if substep_id is not None else "N/A", - results=task_results["content"], - progress_check_schema=schemas.progress_check, - ), - } - ], - response_format=ProgressCheckOutput, - structured_mode="json", - ) - - # Parse the response - now we get a Pydantic model directly - if hasattr(progress_response, "choices") and progress_response.choices: - # If it's still a raw response, parse it - progress_data = progress_response.choices[0].message.content - logger.debug(f"Progress check response: {progress_data}") - progress_dict = json.loads(progress_data) - progress = ProgressCheckOutput(**progress_dict) - else: - # If it's already a Pydantic model - progress = progress_response - logger.debug(f"Progress check response (Pydantic): {progress}") - - # Step 3: Apply plan updates atomically - verdict = progress.verdict - status_updates = progress.plan_status_update or [] - plan_updates = progress.plan_restructure or [] - - # Convert Pydantic models to dictionaries for JSON serialization - status_updates_dicts = [ - update.model_dump() if hasattr(update, "model_dump") else update - for update in status_updates - ] - plan_updates_dicts = [ - update.model_dump() if hasattr(update, "model_dump") else update - for update in plan_updates - ] - - if status_updates or plan_updates: - updated_plan = await self.update_plan_internal( - instance_id=instance_id, - plan=plan_objects, - status_updates=status_updates_dicts, - plan_updates=plan_updates_dicts, - ) - else: - updated_plan = plan_objects - - return { - "plan": updated_plan, - "verdict": verdict, - "status_updates": status_updates_dicts, - "plan_updates": plan_updates_dicts, - "status": "success", - } - - except Exception as e: - logger.error(f"Failed to process agent response: {e}") - # Rollback: revert task history and plan changes - await self.rollback_agent_response_processing( - instance_id, agent, step_id, substep_id - ) - - # Save failure state to workflow state - from datetime import timezone - - await self.update_workflow_state( - instance_id=instance_id, - message={ - "name": agent, - "role": "system", - "content": f"Failed to process agent response: {str(e)}", - "step": step_id, - "substep": substep_id, - "timestamp": datetime.now(timezone.utc).isoformat(), - }, - ) - - # Return a failure response to prevent workflow from continuing - return { - "plan": plan_objects, - "verdict": "failed", - "status_updates": [], - "plan_updates": [], - "status": "failed", - } - - @task - async def finalize_workflow_with_summary( - self, - instance_id: str, - task: str, - verdict: str, - plan_objects: List[Dict[str, Any]], - step_id: int, - substep_id: Optional[float], - agent: str, - result: str, - wf_time: str, - ) -> str: - """ - Atomically generates summary and finalizes workflow. - - Args: - instance_id (str): The workflow instance ID. - task (str): The original task description. - verdict (str): The final verdict. - plan_objects (List[Dict[str, Any]]): The current plan. - step_id (int): The last step worked on. - substep_id (Optional[float]): The last substep worked on. - agent (str): The last agent that worked. - result (str): The last result. - wf_time (str): Workflow timestamp. - - Returns: - Final summary string - """ - try: - # Step 1: Generate summary using LLM directly - summary_response = self.llm.generate( - messages=[ - { - "role": "user", - "content": SUMMARY_GENERATION_PROMPT.format( - task=task, - verdict=verdict, - plan=json.dumps( - self._convert_plan_objects_to_dicts(plan_objects), - indent=2, - ), - step=step_id, - substep=substep_id if substep_id is not None else "N/A", - agent=agent, - result=result, - ), - } - ], - ) - - # Parse the response - handle both raw responses and direct content - if hasattr(summary_response, "choices") and summary_response.choices: - # If it's still a raw response, parse it - summary = summary_response.choices[0].message.content - logger.debug(f"Summary generation response: {summary}") - else: - # If it's already processed content - summary = str(summary_response) - logger.debug(f"Summary generation response (processed): {summary}") - - # Step 2: Finalize workflow with plan updates - await self.finish_workflow_internal( - instance_id=instance_id, - plan=plan_objects, - step=step_id, - substep=substep_id, - verdict=verdict, - summary=summary, - wf_time=wf_time, - ) - - return summary - - except Exception as e: - logger.error(f"Failed to finalize workflow: {e}") - # Rollback: ensure workflow state is consistent - await self.rollback_workflow_finalization(instance_id) - raise - - # TODO: this should be a compensating activity called in the event of an error from any other activity. - async def update_workflow_state( - self, - instance_id: str, - message: Optional[Dict[str, Any]] = None, - final_output: Optional[str] = None, - plan: Optional[List[Dict[str, Any]]] = None, - wf_time: Optional[str] = None, - ): - """ - Updates the workflow state with a new message, execution plan, or final output. - - Args: - instance_id (str): The unique identifier of the workflow instance. - message (Optional[Dict[str, Any]]): A structured message to be added to the workflow state. - final_output (Optional[str]): The final result of the workflow execution. - plan (Optional[List[Dict[str, Any]]]): The execution plan associated with the workflow instance. - - Raises: - ValueError: If the workflow instance ID is not found in the local state. - """ - workflow_entry = self.state["instances"].get(instance_id) - if not workflow_entry: - raise ValueError( - f"No workflow entry found for instance_id {instance_id} in local state." - ) - - # Only update the provided fields - if plan is not None: - workflow_entry["plan"] = plan - if message is not None: - serialized_message = LLMWorkflowMessage(**message).model_dump(mode="json") - - # Update workflow state messages - workflow_entry["messages"].append(serialized_message) - workflow_entry["last_message"] = serialized_message - - # Update the local chat history if memory is enabled - if self.memory: - self.memory.add_message(message) - - if final_output is not None: - workflow_entry["output"] = final_output - if wf_time is not None: - workflow_entry["end_time"] = wf_time - - # Store workflow instance ID, workflow name, and session_id for session-based state rehydration - workflow_entry["workflow_instance_id"] = instance_id - workflow_entry["workflow_name"] = self._workflow_name - workflow_entry["session_id"] = ( - self.memory.session_id if self.memory else f"{self.name}_default_session" - ) - - # Persist updated state - self.save_state() - - @message_router - async def process_agent_response(self, message: AgentTaskResponse): - """ - Processes agent response messages sent directly to the agent's topic. - - Args: - message (AgentTaskResponse): The agent's response containing task results. - - Returns: - None: The function raises a workflow event with the agent's response. - """ - try: - workflow_instance_id = getattr(message, "workflow_instance_id", None) - - if not workflow_instance_id: - logger.error( - f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring." - ) - return - # Log the received response - logger.debug( - f"{self.name} received response for workflow {workflow_instance_id}" - ) - logger.debug(f"Full response: {message}") - # Raise a workflow event with the Agent's Task Response - self.raise_workflow_event( - instance_id=workflow_instance_id, - event_name="AgentTaskResponse", - data=message, - ) - - except Exception as e: - logger.exception(f"Error processing agent response: {e}", exc_info=True) - - async def broadcast_message_to_agents_internal( - self, instance_id: str, message: str - ) -> None: - """ - Internal helper for broadcasting messages to agents. - """ - logger.info(f"Broadcasting message to all agents (Instance ID: {instance_id})") - - # Create broadcast message - broadcast_msg = BroadcastMessage(content=message, name=self.name, role="user") - - # Add workflow instance ID to metadata - broadcast_msg._message_metadata = { - "workflow_instance_id": instance_id, - "source": self.name, - "type": "BroadcastMessage", - } - - # Send to beacon_channel topic - await self.send_message_to_agent(name="beacon_channel", message=broadcast_msg) - - async def trigger_agent_internal( - self, - instance_id: str, - name: str, - step: int, - substep: Optional[float], - instruction: str, - plan: List[Dict[str, Any]], - ) -> List[Dict[str, Any]]: - """ - Internal helper for triggering agents and updating plan status. - """ - logger.info( - f"Triggering agent {name} for step {step}, substep {substep} (Instance ID: {instance_id})" - ) - - # Get the workflow entry from self.state - workflow_entry = self.state["instances"].get(instance_id) - if not workflow_entry: - raise ValueError(f"No workflow entry found for instance_id: {instance_id}") - - # Ensure step or substep exists - step_entry = find_step_in_plan(plan, step, substep) - if not step_entry: - if substep is not None: - raise ValueError( - f"Substep {substep} in Step {step} not found in the current plan." - ) - raise ValueError(f"Step {step} not found in the current plan.") - - # Mark step or substep as "in_progress" - step_entry["status"] = "in_progress" - logger.debug(f"Marked step {step}, substep {substep} as 'in_progress'") - - # Apply global status updates to maintain consistency - updated_plan = update_step_statuses(plan) - - # Save updated plan state - await self.update_workflow_state(instance_id=instance_id, plan=updated_plan) - - # Send message to agent with specific task instruction - await self.send_message_to_agent( - name=name, - message=InternalTriggerAction( - task=instruction, workflow_instance_id=instance_id - ), - ) - - return updated_plan - - async def update_task_history_internal( - self, - instance_id: str, - agent: str, - step: int, - substep: Optional[float], - results: Dict[str, Any], - plan: List[Dict[str, Any]], - ): - """ - Internal helper for updating task history. - """ - logger.debug( - f"Updating task history for {agent} at step {step}, substep {substep} (Instance ID: {instance_id})" - ) - - # Store the agent's response in the message history - await self.update_workflow_state(instance_id=instance_id, message=results) - - # Retrieve Workflow state - workflow_entry = self.state["instances"].get(instance_id) - if not workflow_entry: - raise ValueError(f"No workflow entry found for instance_id: {instance_id}") - - # Create a TaskResult object - task_result = TaskResult( - agent=agent, step=step, substep=substep, result=results["content"] - ) - - # Append the result to task history - workflow_entry["task_history"].append(task_result.model_dump(mode="json")) - - # Persist state - await self.update_workflow_state( - instance_id=instance_id, plan=workflow_entry["plan"] - ) - - async def update_plan_internal( - self, - instance_id: str, - plan: List[Dict[str, Any]], - status_updates: Optional[List[Dict[str, Any]]] = None, - plan_updates: Optional[List[Dict[str, Any]]] = None, - ) -> List[Dict[str, Any]]: - """ - Internal helper for updating the execution plan. - """ - logger.debug(f"Updating plan for instance {instance_id}") - - # Step 1: Apply status updates directly to `plan` - if status_updates: - logger.info(f"Applying {len(status_updates)} status updates to plan") - for update in status_updates: - step_id = update["step"] - substep_id = update.get("substep") - new_status = update["status"] - - logger.info( - f"Updating step {step_id}, substep {substep_id} to '{new_status}'" - ) - step_entry = find_step_in_plan(plan, step_id, substep_id) - if not step_entry: - error_msg = f"Step {step_id}, Substep {substep_id} not found in the current plan." - logger.error(error_msg) - raise ValueError(error_msg) - - # Apply status update - step_entry["status"] = new_status - logger.info( - f"Successfully updated status of step {step_id}, substep {substep_id} to '{new_status}'" - ) - - # Step 2: Apply plan restructuring updates (if provided) - if plan_updates: - plan = restructure_plan(plan, plan_updates) - logger.debug( - f"Applied restructuring updates for {len(plan_updates)} steps." - ) - - # Step 3: Apply global consistency checks for statuses - plan = update_step_statuses(plan) - - # Save to state and update workflow - await self.update_workflow_state(instance_id=instance_id, plan=plan) - - logger.info(f"Plan successfully updated for instance {instance_id}") - return plan - - async def finish_workflow_internal( - self, - instance_id: str, - plan: List[Dict[str, Any]], - step: int, - substep: Optional[float], - verdict: str, - summary: str, - wf_time: str, - ): - """ - Internal helper for finalizing workflow. - """ - status_updates = [] - - if verdict == "completed": - # Find and validate the step or substep - step_entry = find_step_in_plan(plan, step, substep) - if not step_entry: - raise ValueError( - f"Step {step}, Substep {substep} not found in the current plan. Cannot mark as completed." - ) - - # Mark the step or substep as completed - step_entry["status"] = "completed" - status_updates.append( - {"step": step, "substep": substep, "status": "completed"} - ) - - # If it's a substep, check if all sibling substeps are completed - parent_step = find_step_in_plan( - plan, step - ) # Retrieve parent without `substep` - if parent_step: - # Ensure "substeps" is a valid list before iteration - if not isinstance(parent_step.get("substeps"), list): - parent_step["substeps"] = [] - - all_substeps_completed = all( - ss.get("status") == "completed" for ss in parent_step["substeps"] - ) - if all_substeps_completed: - parent_step["status"] = "completed" - status_updates.append({"step": step, "status": "completed"}) - - # Apply updates in one call - if status_updates: - await self.update_plan_internal( - instance_id=instance_id, plan=plan, status_updates=status_updates - ) - - # Store the final summary and verdict in workflow state - await self.update_workflow_state( - instance_id=instance_id, wf_time=wf_time, final_output=summary - ) - - # ============================================================================ - # ROLLBACK AND COMPENSATION METHODS - # ============================================================================ - - async def rollback_workflow_initialization(self, instance_id: str): - """ - Rollback workflow initialization by clearing partial state. - """ - try: - if instance_id in self.state["instances"]: - # Clear the plan if it was partially created - self.state["instances"][instance_id]["plan"] = [] - self.save_state() - logger.info(f"Rolled back workflow initialization for {instance_id}") - except Exception as e: - logger.error(f"Failed to rollback workflow initialization: {e}") - - async def rollback_agent_trigger( - self, instance_id: str, step_id: int, substep_id: Optional[float] - ): - """ - Rollback agent trigger by reverting step status. - """ - try: - workflow_entry = self.state["instances"].get(instance_id) - if workflow_entry and "plan" in workflow_entry: - plan = workflow_entry["plan"] - step_entry = find_step_in_plan(plan, step_id, substep_id) - if step_entry and step_entry["status"] == "in_progress": - step_entry["status"] = "not_started" - await self.update_workflow_state(instance_id=instance_id, plan=plan) - logger.info( - f"Rolled back agent trigger for step {step_id}, substep {substep_id}" - ) - except Exception as e: - logger.error(f"Failed to rollback agent trigger: {e}") - - async def rollback_agent_response_processing( - self, instance_id: str, agent: str, step_id: int, substep_id: Optional[float] - ): - """ - Rollback agent response processing by reverting changes. - """ - try: - workflow_entry = self.state["instances"].get(instance_id) - if workflow_entry: - # Remove the last task result if it was added - if "task_history" in workflow_entry and workflow_entry["task_history"]: - # Find and remove the last entry for this agent/step - task_history = workflow_entry["task_history"] - for i in range(len(task_history) - 1, -1, -1): - task = task_history[i] - if ( - task.get("agent") == agent - and task.get("step") == step_id - and task.get("substep") == substep_id - ): - task_history.pop(i) - break - - # Revert step status if it was changed - if "plan" in workflow_entry: - plan = workflow_entry["plan"] - step_entry = find_step_in_plan(plan, step_id, substep_id) - if step_entry and step_entry["status"] == "completed": - step_entry["status"] = "in_progress" - await self.update_workflow_state( - instance_id=instance_id, plan=plan - ) - - logger.info( - f"Rolled back agent response processing for {agent} at step {step_id}, substep {substep_id}" - ) - except Exception as e: - logger.error(f"Failed to rollback agent response processing: {e}") - - async def rollback_workflow_finalization(self, instance_id: str): - """ - Rollback workflow finalization to ensure consistent state. - """ - try: - workflow_entry = self.state["instances"].get(instance_id) - if workflow_entry: - # Clear final output if it was set - if "output" in workflow_entry: - workflow_entry["output"] = None - if "end_time" in workflow_entry: - workflow_entry["end_time"] = None - - self.save_state() - logger.info(f"Rolled back workflow finalization for {instance_id}") - except Exception as e: - logger.error(f"Failed to rollback workflow finalization: {e}") - - # ============================================================================ - # COMPENSATION ACTIVITY FOR FAILED COMBINED ACTIVITIES - # ============================================================================ - - @task - async def compensate_failed_activity( - self, instance_id: str, failed_activity: str, activity_context: Dict[str, Any] - ) -> Dict[str, Any]: - """ - Compensates for a failed combined activity by rolling back changes and restoring state. - - Args: - instance_id (str): The workflow instance ID. - failed_activity (str): The name of the failed activity. - activity_context (Dict[str, Any]): Context about the failed activity. - - Returns: - Dict containing compensation status and any recovery actions taken. - """ - try: - logger.warning( - f"Compensating for failed activity: {failed_activity} (Instance: {instance_id})" - ) - - compensation_actions = [] - - if failed_activity == "initialize_workflow_with_plan": - await self.rollback_workflow_initialization(instance_id) - compensation_actions.append("cleared_partial_plan") - - elif failed_activity == "execute_agent_task_with_progress_tracking": - step_id = activity_context.get("step_id") - substep_id = activity_context.get("substep_id") - if step_id is not None: - await self.rollback_agent_trigger(instance_id, step_id, substep_id) - compensation_actions.append("reverted_step_status") - - elif failed_activity == "process_agent_response_with_progress": - agent = activity_context.get("agent") - step_id = activity_context.get("step_id") - substep_id = activity_context.get("substep_id") - if agent and step_id is not None: - await self.rollback_agent_response_processing( - instance_id, agent, step_id, substep_id - ) - compensation_actions.append("reverted_response_processing") - - elif failed_activity == "finalize_workflow_with_summary": - await self.rollback_workflow_finalization(instance_id) - compensation_actions.append("reverted_finalization") - - # Ensure workflow state is consistent after compensation - await self.ensure_workflow_state_consistency(instance_id) - - return { - "status": "compensated", - "failed_activity": failed_activity, - "compensation_actions": compensation_actions, - "instance_id": instance_id, - } - - except Exception as e: - logger.error(f"Failed to compensate for activity {failed_activity}: {e}") - return { - "status": "compensation_failed", - "failed_activity": failed_activity, - "error": str(e), - "instance_id": instance_id, - } - - async def ensure_workflow_state_consistency(self, instance_id: str): - """ - Ensures workflow state is consistent after compensation. - """ - try: - workflow_entry = self.state["instances"].get(instance_id) - if not workflow_entry: - logger.warning( - f"No workflow entry found for {instance_id} during consistency check" - ) - return - - # Ensure plan exists and is valid - if "plan" not in workflow_entry or not workflow_entry["plan"]: - workflow_entry["plan"] = [] - - # Ensure task_history exists - if "task_history" not in workflow_entry: - workflow_entry["task_history"] = [] - - # Ensure messages exists - if "messages" not in workflow_entry: - workflow_entry["messages"] = [] - - # Save the consistent state - self.save_state() - logger.info(f"Ensured workflow state consistency for {instance_id}") - - except Exception as e: - logger.error(f"Failed to ensure workflow state consistency: {e}") - - # ============================================================================ - # ERROR HANDLING WRAPPER FOR COMBINED ACTIVITIES - # ============================================================================ - - async def execute_with_compensation( - self, activity_func, activity_name: str, instance_id: str, **kwargs - ) -> Any: - """ - Executes a combined activity with automatic compensation on failure. - - Args: - activity_func: The activity function to execute. - activity_name: The name of the activity for logging and compensation. - instance_id: The workflow instance ID. - **kwargs: Arguments to pass to the activity function. - - Returns: - The result of the activity function. - - Raises: - Exception: If the activity fails and compensation also fails. - """ - try: - return await activity_func(**kwargs) - except Exception as e: - logger.error(f"Activity {activity_name} failed: {e}") - - # Prepare context for compensation - activity_context = {"instance_id": instance_id, "error": str(e), **kwargs} - - # Attempt compensation - compensation_result = await self.compensate_failed_activity( - instance_id=instance_id, - failed_activity=activity_name, - activity_context=activity_context, - ) - - if compensation_result["status"] == "compensated": - logger.info(f"Successfully compensated for failed {activity_name}") - else: - logger.error( - f"Compensation failed for {activity_name}: {compensation_result}" - ) - - # Re-raise the original exception - raise diff --git a/dapr_agents/workflow/orchestrators/orchestratorWorkflows.png b/dapr_agents/workflow/orchestrators/orchestratorWorkflows.png deleted file mode 100644 index 78d43231..00000000 Binary files a/dapr_agents/workflow/orchestrators/orchestratorWorkflows.png and /dev/null differ diff --git a/dapr_agents/workflow/orchestrators/orchestratorWorkflows.svg b/dapr_agents/workflow/orchestrators/orchestratorWorkflows.svg deleted file mode 100644 index b3df5da4..00000000 --- a/dapr_agents/workflow/orchestrators/orchestratorWorkflows.svg +++ /dev/null @@ -1,31 +0,0 @@ -

Orchestrator workflows - managing comm btw agents


  1. llm/orchestrator.py



get_agents_metadata_as_string

generate_plan

prepare_initial_message

broadcast_message_to_agents

generate_next_step

validate_next_step

broadcast_message_to_agents

if valid

trigger_agent

trigger next agent

start

AgentTaskResponse

wait for external event of  -agent task responsewait for external event of  -agent task response

timeout timer

wait for agent response  -timeoutwait for agent response  -timeout

whichever comes first!

update_task_history

check_progress

verdict = "continue"

generate_summary

finish_workflow

if verdict != "continue  -and still have next  -iteration countif verdict != "continue  -and still have next  -iteration count

update_plan

status updates  -or plan updatesstatus updates  -or plan updatescontinue as newcontinue as new
  1. RandomWorkflow

start

process_input

iteration 0iteration 0

select_random_speaker

broadcast_message_to_agents

iteration !0iteration !0

trigger_agent

trigger next agent

AgentTaskResponse

wait for external event of  -agent task responsewait for external event of  -agent task response

timeout timer

wait for agent response  -timeoutwait for agent response  -timeoutmore iterations leftmore iterations left

end

>  -max_iteration -s>  -max_iteration -scontinue as newcontinue as new
  1. RoundRobinWorkflow

start

process_input

iteration 0iteration 0

select_next_speaker

broadcast_message_to_agents

iteration !0iteration !0

trigger_agent

trigger next agent

AgentTaskResponse

wait for external event of  -agent task responsewait for external event of  -agent task response

timeout timer

wait for agent response  -timeoutwait for agent response  -timeoutmore iterations leftmore iterations left

end

>  -max_iteration -s>  -max_iteration -scontinue as newcontinue as new
\ No newline at end of file diff --git a/dapr_agents/workflow/orchestrators/random.py b/dapr_agents/workflow/orchestrators/random.py deleted file mode 100644 index e823da59..00000000 --- a/dapr_agents/workflow/orchestrators/random.py +++ /dev/null @@ -1,223 +0,0 @@ -import logging -import random -from datetime import timedelta -from typing import Any, Dict, Optional - -from dapr.ext.workflow import DaprWorkflowContext -from pydantic import BaseModel, Field - -from dapr_agents.types import BaseMessage -from dapr_agents.workflow.decorators import message_router, task, workflow -from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase - -logger = logging.getLogger(__name__) - - -class BroadcastMessage(BaseMessage): - """ - Represents a broadcast message from an agent. - """ - - -class AgentTaskResponse(BaseMessage): - """ - Represents a response message from an agent after completing a task. - """ - - workflow_instance_id: Optional[str] = Field( - default=None, description="Dapr workflow instance id from source if available" - ) - - -class TriggerAction(BaseModel): - """ - Represents a message used to trigger an agent's activity within the workflow. - """ - - task: Optional[str] = Field( - None, - description="The specific task to execute. If not provided, the agent will act " - "based on its memory or predefined behavior.", - ) - workflow_instance_id: Optional[str] = Field( - default=None, description="Dapr workflow instance id from source if available" - ) - - -class RandomOrchestrator(OrchestratorWorkflowBase): - """ - Implements a random workflow where agents are selected randomly to perform tasks. - The workflow iterates through conversations, selecting a random agent at each step. - - Runs in a single for-loop, breaking when max_iterations is reached. - """ - - current_speaker: Optional[str] = Field( - default=None, - init=False, - description="Current speaker in the conversation, to avoid immediate repeats when possible.", - ) - - def model_post_init(self, __context: Any) -> None: - """ - Initializes and configures the random workflow service. - """ - self._workflow_name = "RandomWorkflow" - super().model_post_init(__context) - - @workflow(name="RandomWorkflow") - # TODO: add retry policies on activities. - def main_workflow(self, ctx: DaprWorkflowContext, input: TriggerAction): - """ - Executes the random workflow in up to `self.max_iterations` turns, selecting - a different (or same) agent at random each turn. - - Args: - ctx (DaprWorkflowContext): Workflow context. - input (TriggerAction): Contains `task`. - - Returns: - str: The final message content when the workflow terminates. - """ - # Step 1: Gather initial task and instance ID - task = input.get("task") - instance_id = ctx.instance_id - final_output: Optional[str] = None - - # Single loop from turn 1 to max_iterations inclusive - for turn in range(1, self.max_iterations + 1): - if not ctx.is_replaying: - logger.info( - f"Random workflow turn {turn}/{self.max_iterations} " - f"(Instance ID: {instance_id})" - ) - - # Step 2: On turn 1, process initial task and broadcast - if turn == 1: - message = yield ctx.call_activity( - self.process_input, input={"task": task} - ) - logger.info(f"Initial message from {message['role']} -> {self.name}") - yield ctx.call_activity( - self.broadcast_message_to_agents, input={"message": message} - ) - - # Step 3: Select a random speaker - random_speaker = yield ctx.call_activity(self.select_random_speaker) - if not ctx.is_replaying: - logger.info(f"{self.name} selected {random_speaker} (Turn {turn}).") - - # Step 4: Trigger the agent - yield ctx.call_activity( - self.trigger_agent, - input={"name": random_speaker, "instance_id": instance_id}, - ) - - # Step 5: Await for agent response or timeout - if not ctx.is_replaying: - logger.debug("Waiting for agent response...") - event_data = ctx.wait_for_external_event("AgentTaskResponse") - timeout_task = ctx.create_timer(timedelta(seconds=self.timeout)) - any_results = yield self.when_any([event_data, timeout_task]) - - # Step 6: Handle response or timeout - if any_results == timeout_task: - if not ctx.is_replaying: - logger.warning( - f"Turn {turn}: agent response timed out (Instance ID: {instance_id})." - ) - result = { - "name": "timeout", - "content": "⏰ Timeout occurred. Continuing...", - } - else: - result = yield event_data - if not ctx.is_replaying: - logger.info(f"{result['name']} -> {self.name}") - - # Step 7: If this is the last allowed turn, mark final_output and break - if turn == self.max_iterations: - if not ctx.is_replaying: - logger.info( - f"Turn {turn}: max iterations reached (Instance ID: {instance_id})." - ) - final_output = result["content"] - break - - # Otherwise, feed into next turn - task = result["content"] - - # Sanity check (should never happen) - if final_output is None: - raise RuntimeError( - "RandomWorkflow completed without producing a final_output" - ) - - # Return the final message content - return final_output - - @task - async def process_input(self, task: str) -> Dict[str, Any]: - """ - Wraps the raw task into a UserMessage dict. - """ - return {"role": "user", "name": self.name, "content": task} - - @task - async def broadcast_message_to_agents(self, message: Dict[str, Any]): - """ - Broadcasts a message to all agents (excluding orchestrator). - """ - task_message = BroadcastMessage(**message) - await self.broadcast_message(message=task_message, exclude_orchestrator=True) - - @task - def select_random_speaker(self) -> str: - """ - Selects a random speaker, avoiding repeats when possible. - """ - agents = self.get_agents_metadata(exclude_orchestrator=True) - if not agents: - logger.error("No agents available for selection.") - raise ValueError("Agents list is empty.") - - names = list(agents.keys()) - # Avoid repeating previous speaker if more than one agent - if len(names) > 1 and self.current_speaker in names: - names.remove(self.current_speaker) - - choice = random.choice(names) - self.current_speaker = choice - return choice - - @task - async def trigger_agent(self, name: str, instance_id: str) -> None: - """ - Sends a TriggerAction to the selected agent. - """ - logger.info(f"Triggering agent {name} (Instance ID: {instance_id})") - await self.send_message_to_agent( - name=name, - message=TriggerAction(workflow_instance_id=instance_id), - ) - - @message_router - async def process_agent_response(self, message: AgentTaskResponse): - """ - Handles incoming AgentTaskResponse events and re-raises them into the workflow. - """ - workflow_instance_id = getattr(message, "workflow_instance_id", None) - if not workflow_instance_id: - logger.error("Missing workflow_instance_id on AgentTaskResponse; ignoring.") - return - # Log the received response - logger.debug( - f"{self.name} received response for workflow {workflow_instance_id}" - ) - logger.debug(f"Full response: {message}") - # Raise a workflow event with the Agent's Task Response - self.raise_workflow_event( - instance_id=workflow_instance_id, - event_name="AgentTaskResponse", - data=message, - ) diff --git a/dapr_agents/workflow/orchestrators/roundrobin.py b/dapr_agents/workflow/orchestrators/roundrobin.py deleted file mode 100644 index 8ec38d53..00000000 --- a/dapr_agents/workflow/orchestrators/roundrobin.py +++ /dev/null @@ -1,241 +0,0 @@ -import logging -from datetime import timedelta -from typing import Any, Dict, Optional - -from dapr.ext.workflow import DaprWorkflowContext -from pydantic import BaseModel, Field - -from dapr_agents.types import BaseMessage -from dapr_agents.workflow.decorators import message_router, task, workflow -from dapr_agents.workflow.orchestrators.base import OrchestratorWorkflowBase - -logger = logging.getLogger(__name__) - - -class AgentTaskResponse(BaseMessage): - """ - Represents a response message from an agent after completing a task. - """ - - workflow_instance_id: Optional[str] = Field( - default=None, description="Dapr workflow instance id from source if available" - ) - - -class BroadcastMessage(BaseMessage): - """ - Represents a broadcast message from an agent. - """ - - -class TriggerAction(BaseModel): - """ - Represents a message used to trigger an agent's activity within the workflow. - """ - - task: Optional[str] = Field( - None, - description="The specific task to execute. If not provided, the agent will act " - "based on its memory or predefined behavior.", - ) - workflow_instance_id: Optional[str] = Field( - default=None, description="Dapr workflow instance id from source if available" - ) - - -class RoundRobinOrchestrator(OrchestratorWorkflowBase): - """ - Implements a round-robin workflow where agents take turns performing tasks. - Iterates for up to `self.max_iterations` turns, then returns the last reply. - """ - - def model_post_init(self, __context: Any) -> None: - """ - Initializes and configures the round-robin workflow. - """ - self._workflow_name = "RoundRobinWorkflow" - super().model_post_init(__context) - - @workflow(name="RoundRobinWorkflow") - # TODO: add retry policies on activities. - def main_workflow(self, ctx: DaprWorkflowContext, input: TriggerAction) -> str: - """ - Drives the round-robin loop in up to `max_iterations` turns. - - Args: - ctx (DaprWorkflowContext): Workflow context. - input (TriggerAction): Contains the initial `task`. - - Returns: - str: The final message content when the workflow terminates. - """ - # Step 1: Extract task and instance ID from input - task = input.get("task") - instance_id = ctx.instance_id - final_output: Optional[str] = None - - # Loop from 1..max_iterations - for turn in range(1, self.max_iterations + 1): - if not ctx.is_replaying: - logger.info( - f"Round-robin turn {turn}/{self.max_iterations} " - f"(Instance ID: {instance_id})" - ) - - # Step 2: On turn 1, process input and broadcast message - if turn == 1: - message = yield ctx.call_activity( - self.process_input, input={"task": task} - ) - if not ctx.is_replaying: - logger.info( - f"Initial message from {message['role']} -> {self.name}" - ) - yield ctx.call_activity( - self.broadcast_message_to_agents, input={"message": message} - ) - - # Step 3: Select next speaker in round-robin order - speaker = yield ctx.call_activity( - self.select_next_speaker, input={"turn": turn} - ) - if not ctx.is_replaying: - logger.info(f"Selected agent {speaker} for turn {turn}") - - # Step 4: Trigger that agent - yield ctx.call_activity( - self.trigger_agent, - input={"name": speaker, "instance_id": instance_id}, - ) - - # Step 5: Wait for agent response or timeout - if not ctx.is_replaying: - logger.debug("Waiting for agent response...") - event_data = ctx.wait_for_external_event("AgentTaskResponse") - timeout_task = ctx.create_timer(timedelta(seconds=self.timeout)) - any_results = yield self.when_any([event_data, timeout_task]) - - # Step 6: Handle result or timeout - if any_results == timeout_task: - if not ctx.is_replaying: - logger.warning( - f"Turn {turn}: response timed out " - f"(Instance ID: {instance_id})" - ) - result = { - "name": "timeout", - "content": "Timeout occurred. Continuing...", - } - else: - result = yield event_data - if not ctx.is_replaying: - logger.info(f"{result['name']} -> {self.name}") - - # Step 7: If this is the last allowed turn, capture and break - if turn == self.max_iterations: - if not ctx.is_replaying: - logger.info( - f"Turn {turn}: max iterations reached (Instance ID: {instance_id})." - ) - final_output = result["content"] - break - - # Otherwise, feed into next iteration - task = result["content"] - - # Sanity check: final_output must be set - if final_output is None: - raise RuntimeError( - "RoundRobinWorkflow completed without producing final_output" - ) - - return final_output - - @task - async def process_input(self, task: str) -> Dict[str, Any]: - """ - Processes the input message for the workflow. - - Args: - task (str): The user-provided input task. - Returns: - dict: Serialized UserMessage with the content. - """ - return {"role": "user", "name": self.name, "content": task} - - @task - async def broadcast_message_to_agents(self, message: Dict[str, Any]): - """ - Broadcasts a message to all agents. - - Args: - message (Dict[str, Any]): The message content and additional metadata. - """ - # Format message for broadcasting - task_message = BroadcastMessage(**message) - # Send broadcast message - await self.broadcast_message(message=task_message, exclude_orchestrator=True) - - @task - async def select_next_speaker(self, turn: int) -> str: - """ - Selects the next speaker in round-robin order. - - Args: - turn (int): The current turn number (1-based). - Returns: - str: The name of the selected agent. - """ - agents_metadata = self.get_agents_metadata(exclude_orchestrator=True) - if not agents_metadata: - logger.warning("No agents available for selection.") - raise ValueError("Agents metadata is empty. Cannot select next speaker.") - - agent_names = list(agents_metadata.keys()) - next_speaker = agent_names[(turn - 1) % len(agent_names)] - return next_speaker - - @task - async def trigger_agent(self, name: str, instance_id: str) -> None: - """ - Triggers the specified agent to perform its activity. - - Args: - name (str): Name of the agent to trigger. - instance_id (str): Workflow instance ID for context. - """ - logger.info(f"Triggering agent {name} (Instance ID: {instance_id})") - await self.send_message_to_agent( - name=name, - message=TriggerAction(workflow_instance_id=instance_id), - ) - - @message_router - async def process_agent_response(self, message: AgentTaskResponse): - """ - Processes agent response messages sent directly to the agent's topic. - - Args: - message (AgentTaskResponse): The agent's response containing task results. - - Returns: - None: The function raises a workflow event with the agent's response. - """ - workflow_instance_id = getattr(message, "workflow_instance_id", None) - - if not workflow_instance_id: - logger.error( - f"{self.name} received an agent response without a valid workflow_instance_id. Ignoring." - ) - return - # Log the received response - logger.debug( - f"{self.name} received response for workflow {workflow_instance_id}" - ) - logger.debug(f"Full response: {message}") - # Raise a workflow event with the Agent's Task Response - self.raise_workflow_event( - instance_id=workflow_instance_id, - event_name="AgentTaskResponse", - data=message, - ) diff --git a/dapr_agents/workflow/runners/__init__.py b/dapr_agents/workflow/runners/__init__.py new file mode 100644 index 00000000..077c38cd --- /dev/null +++ b/dapr_agents/workflow/runners/__init__.py @@ -0,0 +1,4 @@ +from .base import WorkflowRunner +from .agent import AgentRunner + +__all__ = ["WorkflowRunner", "AgentRunner"] diff --git a/dapr_agents/workflow/runners/agent.py b/dapr_agents/workflow/runners/agent.py new file mode 100644 index 00000000..4778147f --- /dev/null +++ b/dapr_agents/workflow/runners/agent.py @@ -0,0 +1,319 @@ +from __future__ import annotations + +import asyncio +import concurrent.futures +import logging +import threading +from typing import Any, Callable, Dict, Literal, Optional, TypeVar, Union + +from fastapi import FastAPI + +from dapr_agents.types.workflow import PubSubRouteSpec +from dapr_agents.workflow.runners.base import WorkflowRunner +from dapr_agents.workflow.utils.core import get_decorated_methods +from dapr_agents.workflow.utils.registration import ( + register_http_routes, + register_message_routes, +) + +logger = logging.getLogger(__name__) + +R = TypeVar("R") + + +def workflow_entry(func: Callable[..., R]) -> Callable[..., R]: + """ + Mark a method/function as the workflow entrypoint for an Agent. + + This decorator does not wrap the function; it simply annotates the callable + with `_is_workflow_entry = True` so AgentRunner can discover it on the agent + instance via reflection. + + Usage: + class MyAgent: + @workflow_entry + def my_workflow(self, ctx: DaprWorkflowContext, wf_input: dict) -> str: + ... + + Returns: + The same callable (unmodified), with an identifying attribute. + """ + setattr(func, "_is_workflow_entry", True) # type: ignore[attr-defined] + return func + + +class AgentRunner(WorkflowRunner): + """ + Runner specialized for Agent classes. + """ + + def __init__( + self, + *, + name: str = "agent-runner", + wf_client=None, + timeout_in_seconds: int = 600, + auto_install_signals: bool = False, + ) -> None: + """ + Initialize an AgentRunner. + + Args: + name: Logical name used in logs (defaults to "agent-runner"). + wf_client: Optional injected DaprWorkflowClient. If omitted, a new one is created. + timeout_in_seconds: Default timeout used when waiting for workflow completion. + auto_install_signals: If True, installs SIGINT/SIGTERM handlers automatically + when used as a context manager (with/async with) and removes them on exit. + """ + super().__init__( + name=name, + wf_client=wf_client, + timeout_in_seconds=timeout_in_seconds, + auto_install_signals=auto_install_signals, + ) + + async def run( + self, + agent: Any, + payload: Optional[Union[str, Dict[str, Any]]] = None, + *, + instance_id: Optional[str] = None, + wait: bool = True, + timeout_in_seconds: Optional[int] = None, + fetch_payloads: bool = True, + log: bool = True, + ) -> Union[str, Optional[str]]: + """ + Run an Agent's workflow entry. + + Args: + agent: Agent instance containing exactly one bound method marked with `@workflow_entry`. + payload: Workflow input (JSON-serializable dict or string). + instance_id: Workflow instance id; if omitted, a new UUID is generated. + wait: If True, wait for completion and return serialized output; otherwise return instance id immediately. + timeout_in_seconds: Max time to wait when wait=True. If omitted (Runner's timeout), defaults to the runner's configured timeout. + Ignored when wait=False. + fetch_payloads: Whether to fetch input/output payloads when waiting. + log: If True, log the final outcome (sync if `wait=True`, background if `wait=False`). + + Returns: + - If `wait=False`: the workflow instance id (str). + - If `wait=True`: the serialized output string, or `None` on timeout/error. + + Raises: + RuntimeError: If zero or multiple entry methods are found on the Agent. + """ + logger.debug( + "[%s] Start run: agent=%s payload=%s wait=%s timeout=%s", + self._name, + type(agent).__name__, + payload, + wait, + timeout_in_seconds, + ) + + entry = self.discover_entry(agent) + logger.debug("[%s] Discovered workflow entry: %s", self._name, entry.__name__) + + return await self.run_workflow_async( + entry, + payload, + instance_id=instance_id, + timeout_in_seconds=timeout_in_seconds, + fetch_payloads=fetch_payloads, + detach=not wait, + log=log, + ) + + def run_sync( + self, + agent: Any, + payload: Optional[Union[str, Dict[str, Any]]] = None, + *, + instance_id: Optional[str] = None, + timeout_in_seconds: Optional[int] = None, + fetch_payloads: bool = True, + log: bool = True, + ) -> Optional[str]: + """ + Synchronously run an Agent's workflow entry and wait for completion. + + Args: + agent: Agent instance containing exactly one bound method marked with `@workflow_entry`. + payload: Workflow input (JSON-serializable dict or string). + instance_id: Workflow instance id; if omitted, a new UUID is generated. + timeout_in_seconds: Max time to wait when wait=True. If omitted (Runner's timeout), defaults to the runner's configured timeout. + Ignored when wait=False. + fetch_payloads: Whether to fetch input/output payloads when waiting. + log: If True, log the final outcome. + + Returns: + Serialized output string, or `None` on timeout/error. + """ + coro = self.run( + agent, + payload, + instance_id=instance_id, + wait=True, + timeout_in_seconds=timeout_in_seconds, + fetch_payloads=fetch_payloads, + log=log, + ) + try: + asyncio.get_running_loop() + return self._run_coro_in_new_loop_thread(coro) + except RuntimeError: + return asyncio.run(coro) + + def discover_entry(self, agent: Any) -> Callable[..., Any]: + """ + Locate exactly one bound method on `agent` marked with `@workflow_entry`. + + Returns: + The bound method to schedule. + + Raises: + RuntimeError: If zero or multiple @workflow_entry methods are found. + """ + candidates: list[Callable[..., Any]] = [] + for attr in dir(agent): + fn = getattr(agent, attr) + if callable(fn) and getattr(fn, "_is_workflow_entry", False): + # Ensure it's bound to THIS instance (not a function on the class) + if getattr(fn, "__self__", None) is agent: + candidates.append(fn) + + if not candidates: + raise RuntimeError("Agent has no @workflow_entry method.") + if len(candidates) > 1: + names = ", ".join(getattr(c, "__name__", "") for c in candidates) + raise RuntimeError(f"Agent has multiple @workflow_entry methods: {names}") + return candidates[0] + + @staticmethod + def _run_coro_in_new_loop_thread( + coro: "asyncio.Future[R] | asyncio.coroutines.Coroutine[Any, Any, R]", + ) -> R: + """ + Execute an async coroutine in a brand-new event loop on a background thread, + then return its result to the current thread (which may already be running a loop). + + This enables `run_sync` to work in notebooks and ASGI servers. + + Args: + coro: The coroutine to run. + + Returns: + The coroutine's result, or raises its exception. + """ + fut: "concurrent.futures.Future[R]" = concurrent.futures.Future() + + def _runner() -> None: + loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(loop) + result = loop.run_until_complete(coro) + fut.set_result(result) + except Exception as exc: # noqa: BLE001 + fut.set_exception(exc) + finally: + try: + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + loop.close() + + t = threading.Thread(target=_runner, daemon=True) + t.start() + return fut.result() + + def register_routes( + self, + agent: Any, + *, + fastapi_app: Optional[FastAPI] = None, + delivery_mode: Literal["sync", "async"] = "sync", + queue_maxsize: int = 1024, + await_result: bool = False, + await_timeout: Optional[int] = None, + fetch_payloads: bool = True, + log_outcome: bool = False, + ) -> None: + """ + Register message/HTTP routes for a single durable agent instance. + + Args: + agent: The agent instance whose routes should be registered. + fastapi_app: Optional FastAPI app to register HTTP routes on. If omitted, no HTTP routes are registered. + delivery_mode: "sync" or "async" delivery for message handlers. + queue_maxsize: Max size of internal message queues. + await_result: If True, message handlers will await workflow results. + await_timeout: Max time to wait for workflow results when `await_result=True`. If omitted (None), waits indefinitely. + fetch_payloads: Whether to fetch input/output payloads for awaited workflows. + log_outcome: Whether to log the final outcome of awaited workflows. + """ + config = getattr(agent, "pubsub_config", None) + if config is None: + logger.debug( + "[%s] Agent %s has no pubsub_config; skipping pub/sub route registration.", + self._name, + getattr(agent, "name", agent), + ) + else: + specs = self._build_pubsub_specs(agent, config) + if specs: + self._ensure_dapr_client() + + if not self._wired_pubsub and self._dapr_client is not None: + closers = register_message_routes( + routes=specs, + dapr_client=self._dapr_client, + delivery_mode=delivery_mode, + queue_maxsize=queue_maxsize, + wf_client=self._wf_client, + await_result=await_result, + await_timeout=await_timeout, + fetch_payloads=fetch_payloads, + log_outcome=log_outcome, + ) + self._pubsub_closers.extend(closers) + self._wired_pubsub = True + + if fastapi_app is not None and not self._wired_http: + register_http_routes( + app=fastapi_app, + targets=[agent], + routes=None, + ) + self._wired_http = True + + def _build_pubsub_specs(self, agent: Any, config: Any) -> list[PubSubRouteSpec]: + handlers = get_decorated_methods(agent, "_is_message_handler") + if not handlers: + return [] + + specs: list[PubSubRouteSpec] = [] + for _, handler in handlers.items(): + meta = getattr(handler, "_message_router_data", {}) + is_broadcast = meta.get("is_broadcast", False) + topic: Optional[str] = ( + config.broadcast_topic if is_broadcast else config.agent_topic + ) + if not topic: + kind = "broadcast" if is_broadcast else "direct" + raise ValueError( + f"AgentPubSubConfig missing topic for {kind} handler {handler.__name__}" + ) + + schemas = meta.get("message_schemas") or [] + message_model = schemas[0] if schemas else None + + specs.append( + PubSubRouteSpec( + pubsub_name=config.pubsub_name, + topic=topic, + handler_fn=handler, + message_model=message_model, + ) + ) + + return specs diff --git a/dapr_agents/workflow/runners/base.py b/dapr_agents/workflow/runners/base.py new file mode 100644 index 00000000..97e06e3a --- /dev/null +++ b/dapr_agents/workflow/runners/base.py @@ -0,0 +1,643 @@ +from __future__ import annotations + +import asyncio +import functools +import logging +import threading +import uuid +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Union, + overload, +) + +from dapr.clients import DaprClient +from dapr.ext.workflow import DaprWorkflowClient +from dapr.ext.workflow.workflow_state import WorkflowState +from fastapi import FastAPI + +from dapr_agents.types.workflow import HttpRouteSpec, PubSubRouteSpec +from dapr_agents.utils.signal.mixin import SignalMixin +from dapr_agents.workflow.utils.registration import ( + register_http_routes, + register_message_routes, +) + +logger = logging.getLogger(__name__) + + +class WorkflowRunner(SignalMixin): + """ + Host around DaprWorkflowClient with workflow scheduling and optional route wiring. + + Provides: + • Sync/async workflow scheduling + completion waiting. + • Optional wiring for `@message_router` (pub/sub) and `@http_router` (FastAPI). + • Graceful shutdown and OS signal handling. + """ + + def __init__( + self, + *, + name: str = "dapr-workflow-app", + wf_client: Optional[DaprWorkflowClient] = None, + timeout_in_seconds: int = 600, + auto_install_signals: bool = False, + dapr_client: Optional[DaprClient] = None, + ) -> None: + """ + Initialize the runner. + + Args: + name: Logical name used in logs. + wf_client: Existing DaprWorkflowClient. If omitted, a new client is created and owned. + timeout_in_seconds: Default timeout when waiting for completion. + auto_install_signals: Install SIGINT/SIGTERM handlers on context entry. + dapr_client: Optional ready-to-use DaprClient. If omitted, a default one is created/owned. + + Returns: + None + """ + super().__init__() + self._name = name + self._wf_client: DaprWorkflowClient = wf_client or DaprWorkflowClient() + self._wf_client_owned = wf_client is None + self._timeout_in_seconds = timeout_in_seconds + self._client_lock = threading.Lock() + self._auto_install_signals = auto_install_signals + self._signals_installed_by_us = False + + # Router wiring state + self._dapr_client: Optional[DaprClient] = dapr_client + self._dapr_client_owned: bool = dapr_client is None + self._pubsub_closers: List[Callable[[], None]] = [] + self._wired_pubsub = False + self._wired_http = False + + def __enter__(self) -> "WorkflowRunner": + """ + Enter a synchronous context. + + Returns: + Self, so callers can schedule workflows or register routes. + """ + if self._auto_install_signals: + self.install_signal_handlers() + self._signals_installed_by_us = True + return self + + def __exit__(self, exc_type, exc, tb) -> None: + """ + Exit a synchronous context, removing signal handlers and shutting down. + + Args: + exc_type: Exception type if any. + exc: Exception instance if any. + tb: Traceback if any. + + Returns: + None + """ + if exc_type: + logger.error( + "[%s] Context exited with exception", + self._name, + exc_info=(exc_type, exc, tb), + ) + else: + logger.debug("[%s] Context exited cleanly (sync).", self._name) + + try: + if self._signals_installed_by_us: + self.remove_signal_handlers() + self._signals_installed_by_us = False + finally: + self.shutdown() + + async def __aenter__(self) -> "WorkflowRunner": + """ + Enter an asynchronous context. + + Returns: + Self, so callers can schedule workflows or register routes. + """ + if self._auto_install_signals: + self.install_signal_handlers() + self._signals_installed_by_us = True + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + """ + Exit an asynchronous context, removing signal handlers and shutting down. + + Args: + exc_type: Exception type if any. + exc: Exception instance if any. + tb: Traceback if any. + + Returns: + None + """ + if exc_type: + logger.error( + "[%s] Async context exited with exception", + self._name, + exc_info=(exc_type, exc, tb), + ) + else: + logger.debug("[%s] Async context exited cleanly.", self._name) + + try: + await self.graceful_shutdown() + finally: + if self._signals_installed_by_us: + self.remove_signal_handlers() + self._signals_installed_by_us = False + self.shutdown() + + def __del__(self) -> None: + """Best-effort GC close; never raises.""" + try: + self.shutdown() + except Exception: + pass + + def shutdown(self) -> None: + """ + Unwire subscriptions and close owned clients. + + Returns: + None + """ + try: + self.unwire_pubsub() + finally: + self._close_dapr_client() + self._close_wf_client() + + def _ensure_dapr_client(self) -> None: + """ + Ensure a DaprClient exists; create and own a default client if absent. + + Returns: + None + """ + if self._dapr_client is None: + self._dapr_client = DaprClient() + self._dapr_client_owned = True + + def _close_dapr_client(self) -> None: + """ + Close the DaprClient if it is owned by this runner. + + Returns: + None + """ + if self._dapr_client is not None and self._dapr_client_owned: + try: + self._dapr_client.close() + except Exception: + logger.debug("Ignoring error closing DaprClient", exc_info=True) + self._dapr_client = None + self._dapr_client_owned = False + + def _close_wf_client(self) -> None: + """ + Close the DaprWorkflowClient if it is owned by this runner. + + Returns: + None + """ + if self._wf_client is not None and self._wf_client_owned: + try: + self._wf_client.close() + except Exception: + logger.debug( + "Ignoring error while closing DaprWorkflowClient", exc_info=True + ) + + def register_routes( + self, + *, + targets: Optional[Iterable[Any]] = None, + routes: Optional[Iterable[Union[PubSubRouteSpec, HttpRouteSpec]]] = None, + delivery_mode: Literal["sync", "async"] = "sync", + deduper: Optional[Any] = None, + subscribe: Optional[Callable[..., Callable[[], None]]] = None, + await_result: bool = False, + await_timeout: Optional[int] = None, + fetch_payloads: bool = True, + log_outcome: bool = False, + fastapi_app: Optional[FastAPI] = None, + ) -> None: + """ + Wire routes in one of two modes (mutually exclusive): + + 1) Discovery mode: provide `targets` to auto-discover `@message_router` + and `@http_router` handlers. + 2) Explicit mode: provide `routes` (list of PubSubRouteSpec | HttpRouteSpec) + to create subscriptions/endpoints directly. In explicit mode, if a spec + omits `message_model`/`request_model` and `handler_fn` is decorated, the + decorator's schema is used; otherwise `dict`. + + Args: + targets: Instances/functions to scan for decorator metadata. (Discovery mode) + routes: Explicit route specs to wire. (Explicit mode) + delivery_mode: "sync" blocks Dapr thread; "async" enqueues to a worker. + deduper: Optional idempotency backend with `seen(key)` / `mark(key)`. + subscribe: Optional custom subscriber (defaults to DaprClient.subscribe_with_handler). + await_result: If True (sync only), wait for completion and ACK/NACK accordingly. + await_timeout: Timeout (seconds) for completion wait when `await_result=True`. + fetch_payloads: Include payloads when waiting for completion. + log_outcome: Log final workflow outcome (awaited or detached). + fastapi_app: If provided, mount HTTP endpoints. + + Returns: + None + + Raises: + ValueError: If both `targets` and `routes` are provided, or neither is provided. + RuntimeError: If `delivery_mode="async"` without a running event loop. + """ + self._ensure_dapr_client() + + use_targets = targets is not None + use_routes = routes is not None + + if use_targets and use_routes: + raise ValueError( + "Provide either `targets` (discovery) OR `routes` (explicit), not both." + ) + if not use_targets and not use_routes: + raise ValueError( + "You must provide `targets` (discovery) OR `routes` (explicit)." + ) + + # ---- Discovery mode (targets) ---- + if use_targets: + if not self._wired_pubsub and self._dapr_client is not None: + closers = register_message_routes( + dapr_client=self._dapr_client, + targets=targets or [], + routes=None, + delivery_mode=delivery_mode, + deduper=deduper, + subscribe=subscribe, + wf_client=self._wf_client, + await_result=await_result, + await_timeout=await_timeout, + fetch_payloads=fetch_payloads, + log_outcome=log_outcome, + ) + self._pubsub_closers.extend(closers) + self._wired_pubsub = True + + if fastapi_app is not None and not self._wired_http: + register_http_routes( + app=fastapi_app, + targets=targets or [], + routes=None, + ) + self._wired_http = True + return + + # ---- Explicit mode (routes) ---- + specs = list(routes or []) + pubsub_specs = [r for r in specs if isinstance(r, PubSubRouteSpec)] + http_specs = [r for r in specs if isinstance(r, HttpRouteSpec)] + + if pubsub_specs and not self._wired_pubsub and self._dapr_client is not None: + closers = register_message_routes( + routes=pubsub_specs, + dapr_client=self._dapr_client, + delivery_mode=delivery_mode, + deduper=deduper, + subscribe=subscribe, + wf_client=self._wf_client, + await_result=await_result, + await_timeout=await_timeout, + fetch_payloads=fetch_payloads, + log_outcome=log_outcome, + ) + self._pubsub_closers.extend(closers) + self._wired_pubsub = True + + if http_specs and fastapi_app is not None and not self._wired_http: + register_http_routes( + routes=http_specs, + app=fastapi_app, + ) + self._wired_http = True + + def unwire_pubsub(self) -> None: + """ + Unsubscribe all pub/sub handlers wired by this runner. + + Returns: + None + """ + for close in self._pubsub_closers: + try: + close() + except Exception: + logger.exception("Error while closing subscription") + self._pubsub_closers.clear() + self._wired_pubsub = False + + # -------------------- workflow scheduling APIs ---------------------- + + def workflow_client(self) -> DaprWorkflowClient: + """ + Get the underlying DaprWorkflowClient. + + Returns: + DaprWorkflowClient: The active workflow client. + """ + return self._wf_client + + def run_workflow( + self, + workflow: Callable[..., Any], + payload: Optional[Union[str, Dict[str, Any]]] = None, + instance_id: Optional[str] = None, + ) -> str: + """ + Schedule a registered workflow and return its instance id. + + Args: + workflow: Callable pointing to a registered workflow. + payload: Workflow input (dict or serialized string). + instance_id: Optional explicit instance id; autogenerated if omitted. + + Returns: + str: The new workflow instance id. + + Raises: + ValueError: If `workflow` is not callable. + Exception: Any error bubbling up from the Dapr client while scheduling. + """ + if not callable(workflow): + raise ValueError("workflow must be a callable (already registered).") + + chosen_id = instance_id or uuid.uuid4().hex + logger.debug( + "[%s] Scheduling workflow %s id=%s", + self._name, + getattr(workflow, "__name__", workflow), + chosen_id, + ) + + try: + with self._client_lock: + result = self._wf_client.schedule_new_workflow( + workflow=workflow, + input=payload, + instance_id=chosen_id, + ) + logger.debug("[%s] Scheduled workflow id=%s", self._name, result) + return result + except Exception as e: + logger.error("[%s] Failed to schedule workflow: %s", self._name, str(e)) + raise + + @overload + async def run_workflow_async( + self, + workflow: Callable[..., Any], + payload: Optional[Union[str, Dict[str, Any]]] = ..., + instance_id: Optional[str] = ..., + *, + timeout_in_seconds: Optional[int] = ..., + fetch_payloads: bool = ..., + detach: Literal[True], + log: bool = ..., + ) -> str: + ... + + @overload + async def run_workflow_async( + self, + workflow: Callable[..., Any], + payload: Optional[Union[str, Dict[str, Any]]] = ..., + instance_id: Optional[str] = ..., + *, + timeout_in_seconds: Optional[int] = ..., + fetch_payloads: bool = ..., + detach: Literal[False] = ..., + log: bool = ..., + ) -> Optional[str]: + ... + + async def run_workflow_async( + self, + workflow: Callable[..., Any], + payload: Optional[Union[str, Dict[str, Any]]] = None, + instance_id: Optional[str] = None, + *, + timeout_in_seconds: Optional[int] = None, + fetch_payloads: bool = True, + detach: bool = False, + log: bool = False, + ) -> Union[str, Optional[str]]: + """ + Schedule a workflow and optionally wait for completion. + + Args: + workflow: Callable pointing to a registered workflow. + payload: Workflow input (dict or serialized string). + instance_id: Optional explicit instance id; autogenerated if omitted. + timeout_in_seconds: Wait timeout when `detach=False` (defaults to Runner's timeout). + fetch_payloads: Include payloads when waiting for completion. + detach: If True, return instance id immediately; otherwise, wait and return output. + log: If True, log the final outcome (COMPLETED/FAILED). + + Returns: + str | None: + - If `detach=True`: the workflow instance id. + - If `detach=False`: the serialized output string, or `None` on timeout/error. + """ + schedule = functools.partial(self.run_workflow, workflow, payload, instance_id) + instance = await asyncio.to_thread(schedule) + + effective_timeout = timeout_in_seconds or self._timeout_in_seconds + + if detach: + logger.info("[%s] Running in detached mode", self._name) + if log: + asyncio.create_task( + self._await_and_log_state( + instance, effective_timeout, fetch_payloads + ) + ) + return instance + + logger.info("[%s] Waiting for workflow completion...", self._name) + state = await self._await_state(instance, effective_timeout, fetch_payloads) + + if log: + self._log_state(instance, state) + return getattr(state, "serialized_output", None) if state else None + + def wait_for_workflow_completion( + self, + instance_id: str, + *, + fetch_payloads: bool = True, + timeout_in_seconds: Optional[int] = None, + ) -> Optional[WorkflowState]: + """ + Block until a workflow completes and return its final state. + + Args: + instance_id: Workflow instance id to wait on. + fetch_payloads: Include payloads in the returned state. + timeout_in_seconds: Per-call timeout (defaults to Runner's timeout). + + Returns: + WorkflowState | None: Final state, or None on timeout/error. + """ + effective_timeout = timeout_in_seconds or self._timeout_in_seconds + try: + with self._client_lock: + return self._wf_client.wait_for_workflow_completion( + instance_id, + fetch_payloads=fetch_payloads, + timeout_in_seconds=effective_timeout, + ) + except Exception as exc: + logger.error("Error while waiting for %s completion: %s", instance_id, exc) + return None + + # ----------------------- internal helpers --------------------------- + + async def _await_state( + self, + instance_id: str, + timeout_in_seconds: int, + fetch_payloads: bool, + ) -> Optional[WorkflowState]: + """ + Await a workflow's completion using a thread offload. + + Args: + instance_id: Workflow instance id. + timeout_in_seconds: Timeout in seconds. + fetch_payloads: Include payloads. + + Returns: + WorkflowState | None: Final state, or None on timeout/error. + """ + + def _wait() -> Optional[WorkflowState]: + with self._client_lock: + return self._wf_client.wait_for_workflow_completion( + instance_id, + fetch_payloads=fetch_payloads, + timeout_in_seconds=timeout_in_seconds, + ) + + return await asyncio.to_thread(_wait) + + async def _await_and_log_state( + self, + instance_id: str, + timeout_in_seconds: int, + fetch_payloads: bool, + ) -> None: + """ + Await and log a workflow's final state (fire-and-forget). + + Args: + instance_id: Workflow instance id. + timeout_in_seconds: Timeout in seconds. + fetch_payloads: Include payloads. + + Returns: + None + """ + try: + state = await self._await_state( + instance_id, timeout_in_seconds, fetch_payloads + ) + self._log_state(instance_id, state) + except Exception: + logger.exception( + "[%s] %s: error while monitoring workflow outcome", + self._name, + instance_id, + ) + + def _log_state(self, instance_id: str, state: Optional[WorkflowState]) -> None: + """ + Compact logger for final workflow state. + + Args: + instance_id: Workflow instance id. + state: Final state (may be None on timeout/error). + + Returns: + None + """ + if not state: + logger.warning( + "[%s] %s: no state returned (timeout or missing).", + self._name, + instance_id, + ) + return + + status = getattr(state.runtime_status, "name", str(state.runtime_status)) + if status == "COMPLETED": + logger.info( + "[%s] %s completed. Final Output=%s", + self._name, + instance_id, + getattr(state, "serialized_output", None), + ) + return + + fd = getattr(state, "failure_details", None) + if fd: + logger.error( + "[%s] %s: FAILED. type=%s message=%s\n%s", + self._name, + instance_id, + getattr(fd, "error_type", None), + getattr(fd, "message", None), + getattr(fd, "stack_trace", "") or "", + ) + else: + logger.error( + "[%s] %s: finished with status=%s. custom_status=%s", + self._name, + instance_id, + status, + getattr(state, "serialized_custom_status", None), + ) + + # ----------------------- admin utilities ---------------------------- + + def terminate_workflow( + self, + instance_id: str, + *, + output: Optional[Any] = None, + ) -> None: + """ + Terminate a running workflow. + + Args: + instance_id: Workflow instance ID to terminate. + output: Optional output to set for the terminated workflow. + + Returns: + None + """ + with self._client_lock: + self._wf_client.terminate_workflow(instance_id=instance_id, output=output) diff --git a/dapr_agents/workflow/task.py b/dapr_agents/workflow/task.py deleted file mode 100644 index 34f212ab..00000000 --- a/dapr_agents/workflow/task.py +++ /dev/null @@ -1,412 +0,0 @@ -import asyncio -import inspect -import logging -from dataclasses import is_dataclass -from functools import update_wrapper -from types import SimpleNamespace -from typing import Any, Callable, Dict, List, Literal, Optional - -from dapr.ext.workflow import WorkflowActivityContext -from pydantic import BaseModel, ConfigDict, Field - -from dapr_agents.agents.base import AgentBase -from dapr_agents.llm.chat import ChatClientBase -from dapr_agents.llm.openai import OpenAIChatClient -from dapr_agents.llm.utils import StructureHandler -from dapr_agents.prompt.utils.chat import ChatPromptHelper -from dapr_agents.types import BaseMessage, UserMessage, LLMChatResponse - -logger = logging.getLogger(__name__) - - -class WorkflowTask(BaseModel): - """ - Encapsulates task logic for execution by an LLM, agent, or Python function. - - Supports both synchronous and asynchronous tasks, with optional output validation - using Pydantic models or specified return types. - """ - - func: Optional[Callable] = Field( - None, description="The original function to be executed, if provided." - ) - description: Optional[str] = Field( - None, description="A description template for the task, used with LLM or agent." - ) - agent: Optional[AgentBase] = Field( - None, description="The agent used for task execution, if applicable." - ) - llm: Optional[ChatClientBase] = Field( - None, description="The LLM client for executing the task, if applicable." - ) - include_chat_history: Optional[bool] = Field( - False, - description="Whether to include past conversation history in the LLM call.", - ) - workflow_app: Optional[Any] = Field( - None, description="Reference to the WorkflowApp instance." - ) - structured_mode: Literal["json", "function_call"] = Field( - default="json", - description="Structured response mode for LLM output. Valid values: 'json', 'function_call'.", - ) - task_kwargs: Dict[str, Any] = Field( - default_factory=dict, - exclude=True, - description="Additional keyword arguments passed via the @task decorator.", - ) - - # Initialized during setup - signature: Optional[inspect.Signature] = Field( - None, init=False, description="The signature of the provided function." - ) - - model_config = ConfigDict(arbitrary_types_allowed=True) - - def model_post_init(self, __context: Any) -> None: - """ - Post-initialization to set up function signatures and default LLM clients. - """ - # Default to OpenAIChatClient if prompt‐based but no llm provided - if self.description and not self.llm: - try: - self.llm = OpenAIChatClient() - except Exception as e: - logger.warning( - f"Could not create default OpenAI client: {e}. Task will require explicit LLM." - ) - self.llm = None - - if self.func: - # Preserve name / docs for stack traces - try: - update_wrapper(self, self.func) - except AttributeError: - # If the function doesn't have the expected attributes, skip update_wrapper - logger.debug( - f"Could not update wrapper for function {self.func}, skipping" - ) - pass - - # Capture signature for input / output handling - self.signature = inspect.signature(self.func) if self.func else None - - # Honor any structured_mode override - if not self.structured_mode and "structured_mode" in self.task_kwargs: - self.structured_mode = self.task_kwargs["structured_mode"] - - # Proceed with base model setup - super().model_post_init(__context) - - async def __call__(self, ctx: WorkflowActivityContext, payload: Any = None) -> Any: - """ - Executes the task, routing to agent, LLM, or pure-Python logic. - - Dispatches to Python, Agent, or LLM paths and validates output. - - Args: - ctx (WorkflowActivityContext): The workflow execution context. - payload (Any): The task input. - - Returns: - Any: The result of the task. - """ - # Prepare input dict - data = self._normalize_input(payload) if payload is not None else {} - func_name = getattr(self.func, "__name__", "unknown_function") - logger.info(f"Executing task '{func_name}'") - logger.debug(f"Executing task '{func_name}' with input {data!r}") - - try: - executor = self._choose_executor() - if executor in ("agent", "llm"): - if executor == "llm" and not self.description: - raise ValueError("LLM tasks require a description template") - elif executor == "agent": - # For agents, prefer string input for natural conversation - if self.description: - # Use description template with parameter substitution - prompt = self.format_description(self.description, data) - else: - # Pass string input naturally for direct agent conversation - prompt = self._format_natural_agent_input(payload, data) - else: - # LLM with description - prompt = self.format_description(self.description, data) - raw = await self._run_via_ai(prompt, executor) - else: - raw = await self._run_python(data) - - validated = await self._validate_output(raw) - return validated - - except Exception: - func_name = getattr(self.func, "__name__", "unknown_function") - logger.exception(f"Error in task '{func_name}'") - raise - - def _choose_executor(self) -> Literal["agent", "llm", "python"]: - """ - Pick execution path. - - Returns: - One of "agent", "llm", or "python". - - Raises: - ValueError: If no valid executor is configured. - """ - if self.agent: - return "agent" - if self.llm: - return "llm" - if self.func: - return "python" - raise ValueError("No execution path found for this task") - - async def _run_python(self, data: dict) -> Any: - """ - Invoke the Python function directly. - - Args: - data: Keyword arguments for the function. - - Returns: - The function's return value. - """ - logger.debug("Invoking regular Python function") - if asyncio.iscoroutinefunction(self.func): - return await self.func(**data) - else: - return self.func(**data) - - async def _run_via_ai(self, prompt: Any, executor: Literal["agent", "llm"]) -> Any: - """ - Run the prompt through an Agent or LLM. - - Args: - prompt: The prompt data - string for LLM, string/dict/Any for agent. - executor: "agent" or "llm". - - Returns: - Raw result from the AI path. - """ - logger.debug(f"Invoking task via {executor.upper()}") - logger.debug(f"Invoking task with prompt: {prompt!r}") - if executor == "agent": - # Agents can handle string, dict, or other input types - result = await self.agent.run(prompt) - else: - # LLM expects a string prompt - if not isinstance(prompt, str): - raise ValueError( - f"LLM executor requires string prompt, got {type(prompt)}" - ) - result = await self._invoke_llm(prompt) - return self._convert_result(result) - - async def _invoke_llm(self, prompt: str) -> Any: - """ - Build messages and call the LLM client. - - Args: - prompt: The formatted prompt string. - - Returns: - LLM-generated result. - """ - # Gather history if needed - history: List[BaseMessage] = [] - if self.include_chat_history and self.workflow_app: - logger.debug("Retrieving chat history") - history_dicts = self.workflow_app.get_chat_history() - history = ChatPromptHelper.normalize_chat_messages(history_dicts) - - messages: List[BaseMessage] = history + [UserMessage(prompt)] - params: Dict[str, Any] = {"messages": messages} - - # Add structured formatting if return type is a Pydantic model - if ( - self.signature - and self.signature.return_annotation is not inspect.Signature.empty - ): - model_cls = StructureHandler.resolve_response_model( - self.signature.return_annotation - ) - if model_cls: - params["response_format"] = self.signature.return_annotation - params["structured_mode"] = self.structured_mode - - logger.debug(f"LLM call params: {params}") - return self.llm.generate(**params) - - def _normalize_input(self, raw_input: Any) -> dict: - """ - Normalize various input types into a dict. - - Args: - raw_input: Dataclass, SimpleNamespace, single value, or dict. - - Returns: - A dict suitable for function invocation. - - Raises: - ValueError: If signature is missing when wrapping a single value. - """ - if is_dataclass(raw_input): - return raw_input.__dict__ - if isinstance(raw_input, SimpleNamespace): - return vars(raw_input) - if not isinstance(raw_input, dict): - # wrap single argument - if not self.signature or len(self.signature.parameters) == 0: - # No signature or no parameters - return empty dict for consistency - return {} - name = next(iter(self.signature.parameters)) - return {name: raw_input} - return raw_input - - async def _validate_output(self, result: Any) -> Any: - """ - Await and validate the result against return-type model. - - Args: - result: Raw result from executor. - - Returns: - Validated/transformed result. - """ - if asyncio.iscoroutine(result): - result = await result - - if ( - not self.signature - or self.signature.return_annotation is inspect.Signature.empty - ): - return result - - return StructureHandler.validate_against_signature( - result, self.signature.return_annotation - ) - - def _convert_result(self, result: Any) -> Any: - """ - Unwrap AI return types into plain Python. - - Args: - result: One of: - - LLMChatResponse - - BaseModel (Pydantic) - - List[BaseModel] - - primitive (str/int/etc) or dict - - Returns: - • str (assistant content) when `LLMChatResponse` - • dict when a single BaseModel - • List[dict] when a list of BaseModels - • otherwise, the raw `result` - """ - # 1) Unwrap our unified LLMChatResponse → return the assistant's text - if isinstance(result, LLMChatResponse): - logger.debug("Extracted message content from LLMChatResponse.") - msg = result.get_message() - return getattr(msg, "content", None) - - # 2) Single Pydantic model → dict - if isinstance(result, BaseModel): - logger.debug("Converting Pydantic model to dictionary.") - return result.model_dump() - - # 3) List of Pydantic models → list of dicts - if isinstance(result, list) and all(isinstance(x, BaseModel) for x in result): - logger.debug("Converting list of Pydantic models to list of dictionaries.") - return [x.model_dump() for x in result] - - # 4) Fallback: primitive, dict, etc. - logger.info("Returning final task result.") - return result - - def format_description(self, template: str, data: dict) -> str: - """ - Interpolate inputs into the prompt template. - - Args: - template: The `{}`-style template string. - data: Mapping of variable names to values. - - Returns: - The fully formatted prompt. - """ - if self.signature: - bound = self.signature.bind(**data) - bound.apply_defaults() - return template.format(**bound.arguments) - return template.format(**data) - - def _format_natural_agent_input(self, payload: Any, data: dict) -> str: - """ - Format input for natural agent conversation. - Favors string input over dictionary for better agent interaction. - - Args: - payload: The original raw payload from the workflow - data: The normalized dictionary version - - Returns: - String input for natural agent conversation - """ - if payload is None: - return "" - - # If payload is already a simple string/number, use it directly - if isinstance(payload, (str, int, float, bool)): - return str(payload) - - # If we have function parameters, format them naturally - if data and len(data) == 1: - # Single parameter: extract the value - value = next(iter(data.values())) - return str(value) if value is not None else "" - elif data: - # Multiple parameters: format as natural text - parts = [] - for key, value in data.items(): - if value is not None: - parts.append(f"{key}: {value}") - return "\n".join(parts) - else: - # Fallback to string representation of payload - return str(payload) - - -class TaskWrapper: - """ - A wrapper for WorkflowTask that preserves callable behavior and attributes like __name__. - """ - - def __init__(self, task_instance: WorkflowTask, name: str): - """ - Initialize the TaskWrapper. - - Args: - task_instance (WorkflowTask): The task instance to wrap. - name (str): The task name. - """ - self.task_instance = task_instance - self.__name__ = name - self.__doc__ = getattr(task_instance.func, "__doc__", None) - self.__module__ = getattr(task_instance.func, "__module__", None) - - def __call__(self, *args, **kwargs): - """ - Delegate the call to the wrapped WorkflowTask instance. - """ - return self.task_instance(*args, **kwargs) - - def __getattr__(self, item): - """ - Delegate attribute access to the wrapped task. - """ - return getattr(self.task_instance, item) - - def __repr__(self): - return f"" diff --git a/dapr_agents/workflow/utils/core.py b/dapr_agents/workflow/utils/core.py index fe0744ab..762ed99b 100644 --- a/dapr_agents/workflow/utils/core.py +++ b/dapr_agents/workflow/utils/core.py @@ -1,12 +1,34 @@ +import asyncio import inspect import logging -from dataclasses import is_dataclass +import signal +from dataclasses import asdict, is_dataclass +from typing import Any, Callable, Dict, Type + from pydantic import BaseModel -from typing import Any, Callable, Dict logger = logging.getLogger(__name__) +async def wait_for_shutdown() -> None: + """Block until Ctrl+C or SIGTERM is received.""" + loop = asyncio.get_running_loop() + stop = asyncio.Event() + + def _set_stop(*_: object) -> None: + stop.set() + + try: + loop.add_signal_handler(signal.SIGINT, _set_stop) + loop.add_signal_handler(signal.SIGTERM, _set_stop) + except (NotImplementedError, RuntimeError): + # Windows fallback where add_signal_handler is unavailable + signal.signal(signal.SIGINT, lambda *_: _set_stop()) + signal.signal(signal.SIGTERM, lambda *_: _set_stop()) + + await stop.wait() + + def is_pydantic_model(obj: Any) -> bool: """Check if the given type is a subclass of Pydantic's BaseModel.""" return isinstance(obj, type) and issubclass(obj, BaseModel) @@ -68,3 +90,76 @@ def get_decorated_methods(instance: Any, attribute_name: str) -> Dict[str, Calla logger.debug(f"Discovered decorated method: {name}") return discovered + + +def to_payload(model: Any) -> Dict[str, Any]: + """ + Convert supported message instances to a JSON-serializable dict. + + Supports: + - dict: returned as a shallow copy + - dataclass: converted via asdict() + - Pydantic v2: model_dump(exclude_none=True) + - Pydantic v1: dict(exclude_none=True) + + Falls back to dict() if possible, otherwise raises. + """ + if isinstance(model, dict): + return dict(model) + + if is_dataclass(model): + return asdict(model) + + # Pydantic v2 + if hasattr(model, "model_dump"): + return model.model_dump(exclude_none=True) + + # Pydantic v1 + if hasattr(model, "dict"): + return model.dict(exclude_none=True) + + try: + return dict(model) + except Exception as e: + raise TypeError( + f"Unsupported message payload type for serialization: {type(model)!r}" + ) from e + + +def coerce_to_model(model_type: Type[Any], value: Any) -> Any: + """ + Best-effort coercion of `value` into `model_type` where possible. + + Args: + model_type: Expected model class/type. + value: Incoming value. + + Returns: + Any: Value coerced to `model_type` when feasible; original value otherwise. + """ + if model_type is Any or model_type is dict: + return value + + try: + if isinstance(value, model_type): + return value + except TypeError: + # model_type may be typing constructs (e.g., typing.Dict) that break isinstance + pass + + if hasattr(model_type, "model_validate"): + return model_type.model_validate(value) + + if hasattr(model_type, "parse_obj"): + return model_type.parse_obj(value) + + if is_dataclass(model_type): + try: + if isinstance(value, model_type): + return value + except TypeError: + pass + if isinstance(value, dict): + return model_type(**value) + + return value diff --git a/dapr_agents/workflow/utils/messaging.py b/dapr_agents/workflow/utils/messaging.py deleted file mode 100644 index 880c99d9..00000000 --- a/dapr_agents/workflow/utils/messaging.py +++ /dev/null @@ -1,153 +0,0 @@ -import logging -from dataclasses import is_dataclass -from typing import Any, Tuple, Type, Union, Optional, get_args, get_origin -from dapr.common.pubsub.subscription import SubscriptionMessage -from dapr_agents.types.message import EventMessageMetadata -from dapr_agents.workflow.utils.core import is_pydantic_model, is_supported_model - -logger = logging.getLogger(__name__) - - -def extract_message_models(type_hint: Any) -> list[type]: - """ - Extracts one or more message types from a type hint. - - Supports: - - Single type hint: `MyMessage` - - Union types: `Union[MessageA, MessageB]` - - Fallback to empty list if not valid - """ - if type_hint is None: - return [] - - origin = get_origin(type_hint) - if origin is Union: - return list(get_args(type_hint)) - else: - return [type_hint] - - -def extract_cloudevent_data( - message: Union[SubscriptionMessage, dict], -) -> Tuple[dict, dict]: - """ - Extracts CloudEvent metadata and raw payload data from a SubscriptionMessage or dict. - - Args: - message (Union[SubscriptionMessage, dict]): The raw message received from pub/sub. - - Returns: - Tuple[dict, dict]: (event_data, metadata) where event_data is the message payload, and - metadata is the parsed CloudEvent metadata as a dictionary. - - Raises: - ValueError: If message type is unsupported. - """ - if isinstance(message, SubscriptionMessage): - metadata = EventMessageMetadata( - id=message.id(), - datacontenttype=message.data_content_type(), - pubsubname=message.pubsub_name(), - source=message.source(), - specversion=message.spec_version(), - time=None, - topic=message.topic(), - traceid=None, - traceparent=None, - type=message.type(), - tracestate=None, - headers=message.extensions(), - ).model_dump() - event_data = message.data() - - elif isinstance(message, dict): - metadata = EventMessageMetadata( - id=message.get("id"), - datacontenttype=message.get("datacontenttype"), - pubsubname=message.get("pubsubname"), - source=message.get("source"), - specversion=message.get("specversion"), - time=message.get("time"), - topic=message.get("topic"), - traceid=message.get("traceid"), - traceparent=message.get("traceparent"), - type=message.get("type"), - tracestate=message.get("tracestate"), - headers=message.get("extensions", {}), - ).model_dump() - event_data = message.get("data", {}) - - else: - raise ValueError(f"Unexpected message type: {type(message)}") - - return event_data, metadata - - -def validate_message_model(model: Type[Any], event_data: dict) -> Any: - """ - Validates and parses event data against the provided message model. - - Args: - model (Type[Any]): The message model class. - event_data (dict): The raw event payload data. - - Returns: - Any: An instance of the message model (or raw dict if `model` is `dict`). - - Raises: - TypeError: If the model is not supported. - ValueError: If model validation fails. - """ - if not is_supported_model(model): - raise TypeError(f"Unsupported model type: {model}") - - try: - logger.info(f"Validating payload with model '{model.__name__}'...") - - if model is dict: - return event_data - elif is_dataclass(model): - return model(**event_data) - elif is_pydantic_model(model): - return model.model_validate(event_data) - - except Exception as e: - logger.error(f"Message validation failed for model '{model.__name__}': {e}") - raise ValueError(f"Message validation failed: {e}") - - -def parse_cloudevent( - message: Union[SubscriptionMessage, dict], model: Optional[Type[Any]] = None -) -> Tuple[Any, dict]: - """ - Parses and validates a CloudEvent from a SubscriptionMessage or dict. - - This combines both metadata extraction and message model validation for direct use. - - Args: - message (Union[SubscriptionMessage, dict]): The incoming pub/sub message. - model (Optional[Type[Any]]): The schema used to validate the message body. - - Returns: - Tuple[Any, dict]: The validated message (or raw dict) and its metadata. - - Raises: - ValueError: If metadata or validation fails. - """ - try: - event_data, metadata = extract_cloudevent_data(message) - - if model is None: - raise ValueError("Message validation failed: No model provided.") - - validated_message = validate_message_model(model, event_data) - - logger.info("Message successfully parsed and validated") - logger.debug(f"Data: {validated_message}") - logger.debug(f"metadata: {metadata}") - - return validated_message, metadata - - except Exception as e: - logger.error(f"Failed to parse CloudEvent: {e}", exc_info=True) - raise ValueError(f"Invalid CloudEvent: {str(e)}") diff --git a/dapr_agents/workflow/utils/pubsub.py b/dapr_agents/workflow/utils/pubsub.py new file mode 100644 index 00000000..88e28333 --- /dev/null +++ b/dapr_agents/workflow/utils/pubsub.py @@ -0,0 +1,267 @@ +from __future__ import annotations + +import json +import logging +from dataclasses import asdict, is_dataclass +from typing import Any, Callable, Dict, Mapping, MutableMapping, Optional, Union + +from dapr.aio.clients import DaprClient +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +JsonDict = Dict[str, Any] +AgentsMetadata = Mapping[str, Mapping[str, Any]] + + +async def serialize_message(message: Any) -> str: + """ + Serialize an arbitrary message payload into a JSON string. + + Args: + message: The message payload to serialize. Expected to be JSON serializable. + + Returns: + A JSON string representation of the payload (empty object for ``None``). + + Raises: + ValueError: If the payload cannot be serialized to JSON. + """ + try: + return json.dumps(message if message is not None else {}) + except TypeError as exc: # pragma: no cover - mirrors previous behaviour + logger.error("Failed to serialize message %r: %s", message, exc) + raise ValueError(f"Message contains non-serializable data: {exc}") from exc + + +async def publish_message( + *, + pubsub_name: str, + topic_name: str, + message: Any, + metadata: Optional[JsonDict] = None, + default_pubsub: Optional[str] = None, + client_factory: Callable[[], DaprClient] = DaprClient, + logger_: logging.Logger = logger, +) -> None: + """ + Publish a raw JSON-serializable payload to a Dapr pub/sub topic. + + Args: + pubsub_name: Pub/Sub component to target. Falls back to ``default_pubsub`` if falsy. + topic_name: Destination topic. + message: Payload to publish (will be serialized via :func:`serialize_message`). + metadata: Optional CloudEvent metadata. + default_pubsub: Component used when ``pubsub_name`` is falsy. + client_factory: Callable returning an async Dapr client (primarily for testing). + logger_: Logger used for diagnostic output. + """ + json_body = await serialize_message(message) + target_pubsub = pubsub_name or default_pubsub + if not target_pubsub: + raise ValueError("pubsub_name or default_pubsub must be provided.") + + meta = metadata or {} + + try: + async with client_factory() as client: + await client.publish_event( + pubsub_name=target_pubsub, + topic_name=topic_name, + data=json_body, + data_content_type="application/json", + publish_metadata=meta, + ) + logger_.debug( + "Published message to pubsub=%s topic=%s metadata=%s payload=%s", + target_pubsub, + topic_name, + meta, + json_body, + ) + except Exception as exc: # pragma: no cover - network failures + logger_.error( + "Error publishing message to pubsub=%s topic=%s: %s", + target_pubsub, + topic_name, + exc, + exc_info=True, + ) + raise + + +async def publish_event_message( + *, + topic_name: str, + pubsub_name: str, + source: str, + message: Union[BaseModel, JsonDict, Any], + message_type: Optional[str] = None, + metadata: Optional[JsonDict] = None, + default_pubsub: Optional[str] = None, + client_factory: Callable[[], DaprClient] = DaprClient, + logger_: logging.Logger = logger, +) -> None: + """ + Publish a CloudEvent-style payload to a topic with convenient schema support. + + Args: + topic_name: Destination topic. + pubsub_name: Pub/Sub component to use. + source: Logical message source (used for CloudEvent metadata). + message: Payload as Pydantic model, dataclass, dict, or JSON string. + message_type: Optional CloudEvent type override. + metadata: Additional metadata entries merged with CloudEvent defaults. + default_pubsub: Component to use when ``pubsub_name`` is falsy. + client_factory: Callable returning an async Dapr client. + logger_: Logger used for diagnostics. + + Raises: + ValueError: For unsupported payload types or missing ``message_type`` on dict payloads. + """ + if isinstance(message, BaseModel): + message_type = message_type or message.__class__.__name__ + payload = message.model_dump() + elif isinstance(message, dict): + if not message_type: + raise ValueError( + "message_type must be provided when message is a dictionary." + ) + payload = message + elif is_dataclass(message): + message_type = message_type or message.__class__.__name__ + payload = asdict(message) + else: + raise ValueError( + "Message must be a Pydantic BaseModel, dataclass, or dictionary.", + ) + + combined_metadata: MutableMapping[str, Any] = { + "cloudevent.type": message_type, + "cloudevent.source": source, + } + if metadata: + combined_metadata.update(metadata) + + logger_.debug( + "%s publishing event type=%s to topic=%s metadata=%s", + source, + message_type, + topic_name, + dict(combined_metadata), + ) + + await publish_message( + pubsub_name=pubsub_name, + topic_name=topic_name, + message=payload, + metadata=dict(combined_metadata), + default_pubsub=default_pubsub, + client_factory=client_factory, + logger_=logger_, + ) + + logger_.info("%s published '%s' to topic '%s'.", source, message_type, topic_name) + + +async def broadcast_message( + *, + message: Union[BaseModel, JsonDict], + broadcast_topic: Optional[str], + message_bus: str, + source: str, + agents_metadata: AgentsMetadata, + exclude_orchestrator: bool = False, + metadata: Optional[JsonDict] = None, + client_factory: Callable[[], DaprClient] = DaprClient, + logger_: logging.Logger = logger, +) -> None: + """ + Broadcast a message to every agent in the supplied metadata mapping. + + Args: + message: Payload to publish (Pydantic model or dict). + broadcast_topic: Topic used for team broadcasts; if falsy the call is ignored. + message_bus: Default pub/sub component for broadcasts. + source: Emitting agent/service name. + agents_metadata: Mapping of agent name -> metadata (requires ``topic_name`` & ``pubsub_name``). + exclude_orchestrator: Skip agents flagged with ``orchestrator=True``. + metadata: Additional CloudEvent metadata. + client_factory: Callable returning an async Dapr client. + logger_: Logger used for diagnostics. + """ + if not broadcast_topic: + logger_.info("%s has no broadcast topic; skipping broadcast.", source) + return + + recipients = { + name: meta + for name, meta in agents_metadata.items() + if not (exclude_orchestrator and meta.get("orchestrator")) + } + if not recipients: + logger_.warning("No agents available for broadcast from %s.", source) + return + + await publish_event_message( + topic_name=broadcast_topic, + pubsub_name=message_bus, + source=source, + message=message, + metadata=metadata, + default_pubsub=message_bus, + client_factory=client_factory, + logger_=logger_, + ) + logger_.debug("%s broadcasted message to %d agents.", source, len(recipients)) + + +async def send_message_to_agent( + *, + target_agent: str, + message: Union[BaseModel, JsonDict], + agents_metadata: AgentsMetadata, + source: str, + metadata: Optional[JsonDict] = None, + client_factory: Callable[[], DaprClient] = DaprClient, + logger_: logging.Logger = logger, +) -> None: + """ + Send a direct message to a single agent using its registry metadata. + + Args: + target_agent: Logical agent name to address. + message: Payload as Pydantic model or dict. + agents_metadata: Mapping of agent metadata (must include ``topic_name`` & ``pubsub_name``). + source: Name of the sender (used in CloudEvent metadata). + metadata: Additional CloudEvent metadata. + client_factory: Callable returning an async Dapr client. + logger_: Logger used for diagnostics. + """ + meta = agents_metadata.get(target_agent) + if not meta: + logger_.warning( + "Target '%s' is not registered; skipping message.", target_agent + ) + return + + topic = meta.get("topic_name") + pubsub_name = meta.get("pubsub_name") + if not topic or not pubsub_name: + logger_.warning( + "Agent '%s' metadata missing topic_name/pubsub_name; skipping message.", + target_agent, + ) + return + + await publish_event_message( + topic_name=str(topic), + pubsub_name=str(pubsub_name), + source=source, + message=message, + metadata=metadata, + default_pubsub=str(pubsub_name), + client_factory=client_factory, + logger_=logger_, + ) + logger_.debug("Sent message from %s to agent %s.", source, target_agent) diff --git a/dapr_agents/workflow/utils/registration.py b/dapr_agents/workflow/utils/registration.py index 9a0f2dce..4627f684 100644 --- a/dapr_agents/workflow/utils/registration.py +++ b/dapr_agents/workflow/utils/registration.py @@ -3,154 +3,731 @@ import asyncio import inspect import logging -from typing import Any, Callable, Iterable, List, Optional, Type +from dataclasses import asdict, dataclass, is_dataclass +from typing import ( + Any, + Callable, + Iterable, + List, + Literal, + Optional, + Protocol, + Type, +) +import dapr.ext.workflow as wf from dapr.clients import DaprClient from dapr.clients.grpc._response import TopicEventResponse from dapr.common.pubsub.subscription import SubscriptionMessage +from dapr.ext.workflow.workflow_state import WorkflowState +from fastapi import Body, FastAPI +from fastapi.responses import JSONResponse, Response -from dapr_agents.workflow.utils.messaging import ( +from dapr_agents.types.workflow import HttpRouteSpec, PubSubRouteSpec +from dapr_agents.workflow.utils.routers import ( extract_cloudevent_data, + parse_http_json, validate_message_model, ) logger = logging.getLogger(__name__) -def register_message_handlers( - targets: Iterable[Any], - dapr_client: DaprClient, +class DedupeBackend(Protocol): + """Idempotency backend contract (best-effort duplicate detection).""" + + def seen(self, key: str) -> bool: + ... + + def mark(self, key: str) -> None: + ... + + +SubscribeFn = Callable[..., Callable[[], None]] +SchedulerFn = Callable[[Callable[..., Any], dict], Optional[str]] + + +@dataclass +class _MessageRouteBinding: + handler: Callable[..., Any] + schemas: List[Type[Any]] + pubsub: str + topic: str + dead_letter_topic: Optional[str] + name: str + + +@dataclass +class _HttpRouteBinding: + handler: Callable[..., Any] + schemas: List[Type[Any]] + method: str + path: str + summary: Optional[str] + tags: List[str] + response_model: Optional[Type[Any]] + name: str + + +def _resolve_loop( + loop: Optional[asyncio.AbstractEventLoop], +) -> asyncio.AbstractEventLoop: + if loop is not None: + return loop + try: + return asyncio.get_running_loop() + except RuntimeError: + return asyncio.get_event_loop() + + +def _iter_decorated(target: Any, attr: str): + meta = getattr(target, attr, None) + if callable(target) and meta: + yield None, target, meta + return + for name in dir(target): + fn = getattr(target, name) + meta = getattr(fn, attr, None) + if callable(fn) and meta: + yield target, fn, meta + + +def _collect_message_bindings( *, - loop: Optional[asyncio.AbstractEventLoop] = None, + targets: Optional[Iterable[Any]], + routes: Optional[Iterable[PubSubRouteSpec]], +) -> List[_MessageRouteBinding]: + bindings: List[_MessageRouteBinding] = [] + + if targets: + for target in targets: + for owner, handler, meta in _iter_decorated(target, "_message_router_data"): + bound = ( + handler + if owner is None + else handler.__get__(owner, owner.__class__) + ) + schemas = list(meta.get("message_schemas") or [dict]) + pubsub = meta.get("pubsub") + topic = meta.get("topic") + if not pubsub or not topic: + raise ValueError( + f"@message_router '{getattr(bound, '__name__', bound)}' is missing pubsub/topic. " + "Provide them in the decorator or register via PubSubRouteSpec with explicit values." + ) + bindings.append( + _MessageRouteBinding( + handler=bound, + schemas=schemas, + pubsub=pubsub, + topic=topic, + dead_letter_topic=meta.get("dead_letter_topic"), + name=getattr(bound, "__name__", str(bound)), + ) + ) + + if routes: + for spec in routes: + bound = spec.handler_fn + meta = getattr(bound, "_message_router_data", None) + if spec.message_model is not None: + schemas = [spec.message_model] + elif meta and meta.get("message_schemas"): + schemas = list(meta.get("message_schemas")) + else: + schemas = [dict] + bindings.append( + _MessageRouteBinding( + handler=bound, + schemas=schemas, + pubsub=spec.pubsub_name, + topic=spec.topic, + dead_letter_topic=spec.dead_letter_topic, + name=getattr(bound, "__name__", str(bound)), + ) + ) + + return bindings + + +def _collect_http_bindings( + *, + targets: Optional[Iterable[Any]], + routes: Optional[Iterable[HttpRouteSpec]], +) -> List[_HttpRouteBinding]: + bindings: List[_HttpRouteBinding] = [] + + if targets: + for target in targets: + for owner, handler, meta in _iter_decorated(target, "_http_route_data"): + bound = ( + handler + if owner is None + else handler.__get__(owner, owner.__class__) + ) + schemas = list(meta.get("request_schemas") or [dict]) + bindings.append( + _HttpRouteBinding( + handler=bound, + schemas=schemas, + method=meta["method"], + path=meta["path"], + summary=meta.get("summary"), + tags=list(meta.get("tags") or []), + response_model=meta.get("response_model"), + name=getattr(bound, "__name__", str(bound)), + ) + ) + + if routes: + for spec in routes: + bound = spec.handler_fn + meta = getattr(bound, "_http_route_data", None) + if spec.request_model is not None: + schemas = [spec.request_model] + elif meta and meta.get("request_schemas"): + schemas = list(meta.get("request_schemas")) + else: + schemas = [dict] + + summary = ( + spec.summary + if spec.summary is not None + else (meta.get("summary") if meta else None) + ) + tags = list(spec.tags or (meta.get("tags") if meta else []) or []) + response_model = ( + spec.response_model + if spec.response_model is not None + else (meta.get("response_model") if meta else None) + ) + method = spec.method or (meta.get("method") if meta else "POST") + + bindings.append( + _HttpRouteBinding( + handler=bound, + schemas=schemas, + method=method, + path=spec.path, + summary=summary, + tags=tags, + response_model=response_model, + name=getattr(bound, "__name__", str(bound)), + ) + ) + + return bindings + + +def _subscribe_message_bindings( + bindings: List[_MessageRouteBinding], + *, + dapr_client: DaprClient, + loop: Optional[asyncio.AbstractEventLoop], + delivery_mode: Literal["sync", "async"], + queue_maxsize: int, + deduper: Optional[DedupeBackend], + subscribe: Optional[SubscribeFn], + scheduler: Optional[SchedulerFn], + wf_client: Optional[wf.DaprWorkflowClient], + await_result: bool, + await_timeout: Optional[int], + fetch_payloads: bool, + log_outcome: bool, ) -> List[Callable[[], None]]: - """Discover and subscribe handlers decorated with `@message_router`. + if not bindings: + return [] - Scans each target: - - If the target itself is a decorated function (has `_message_router_data`), it is registered. - - If the target is an object, all its attributes are scanned for decorated callables. + loop = _resolve_loop(loop) + if subscribe is None: + subscribe = dapr_client.subscribe_with_handler # type: ignore[assignment] + if delivery_mode not in ("sync", "async"): + raise ValueError("delivery_mode must be 'sync' or 'async'") - Subscriptions use Dapr's streaming API (`subscribe_with_handler`) which invokes your handler - on a background thread. This function returns a list of "closer" callables. Invoking a closer - will unsubscribe the corresponding handler. + queue: Optional[asyncio.Queue] = None + worker_tasks: List[asyncio.Task] = [] - Args: - targets: Functions and/or instances to inspect for `_message_router_data`. - dapr_client: Active Dapr client used to create subscriptions. - loop: Event loop to await async handlers. If omitted, uses the running loop - or falls back to `asyncio.get_event_loop()`. + if delivery_mode == "async": + if not loop or not loop.is_running(): + raise RuntimeError( + "delivery_mode='async' requires an active running event loop." + ) + queue = asyncio.Queue(maxsize=max(1, queue_maxsize)) - Returns: - A list of callables. Each callable, when invoked, closes the associated subscription. - """ - # Resolve loop strategy once up front. - if loop is None: + _wf_client = wf_client or wf.DaprWorkflowClient() + + def _default_scheduler( + workflow_callable: Callable[..., Any], wf_input: dict + ) -> Optional[str]: try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.get_event_loop() + import json - closers: List[Callable[[], None]] = [] + logger.debug( + "➡️ Scheduling workflow: %s | input=%s", + getattr(workflow_callable, "__name__", str(workflow_callable)), + json.dumps(wf_input, ensure_ascii=False, indent=2), + ) + except Exception: + logger.warning("Could not serialize wf_input for logging", exc_info=True) + return _wf_client.schedule_new_workflow( + workflow=workflow_callable, input=wf_input + ) - def _iter_handlers(obj: Any): - """Yield (owner, fn) pairs for decorated handlers on `obj`. + _scheduler: SchedulerFn = scheduler or _default_scheduler - If `obj` is itself a decorated function, yield (None, obj). - If `obj` is an instance, scan its attributes for decorated callables. - """ - meta = getattr(obj, "_message_router_data", None) - if callable(obj) and meta: - yield None, obj + def _log_state(instance_id: str, state: Optional[WorkflowState]) -> None: + if not state: + logger.warning("[wf] %s: no state (timeout/missing).", instance_id) return + status = getattr(state.runtime_status, "name", str(state.runtime_status)) + if status == "COMPLETED": + if log_outcome: + logger.info( + "[wf] %s COMPLETED output=%s", + instance_id, + getattr(state, "serialized_output", None), + ) + return + failure = getattr(state, "failure_details", None) + if failure: + logger.error( + "[wf] %s FAILED type=%s message=%s\n%s", + instance_id, + getattr(failure, "error_type", None), + getattr(failure, "message", None), + getattr(failure, "stack_trace", "") or "", + ) + else: + logger.error( + "[wf] %s finished with status=%s custom_status=%s", + instance_id, + status, + getattr(state, "serialized_custom_status", None), + ) + + def _wait_for_completion(instance_id: str) -> Optional[WorkflowState]: + try: + return _wf_client.wait_for_workflow_completion( + instance_id, + fetch_payloads=fetch_payloads, + timeout_in_seconds=await_timeout, + ) + except Exception: + logger.exception("[wf] %s: error while waiting for completion", instance_id) + return None - for name in dir(obj): - fn = getattr(obj, name) - if callable(fn) and getattr(fn, "_message_router_data", None): - yield obj, fn + async def _await_and_log(instance_id: str) -> None: + state = await asyncio.to_thread(_wait_for_completion, instance_id) + _log_state(instance_id, state) + + async def _schedule( + bound_workflow: Callable[..., Any], parsed: Any + ) -> TopicEventResponse: + try: + metadata: Optional[dict] = None + if isinstance(parsed, dict): + wf_input = dict(parsed) + metadata = wf_input.get("_message_metadata") + elif hasattr(parsed, "model_dump"): + metadata = getattr(parsed, "_message_metadata", None) + wf_input = parsed.model_dump() + elif is_dataclass(parsed): + metadata = getattr(parsed, "_message_metadata", None) + wf_input = asdict(parsed) + else: + metadata = getattr(parsed, "_message_metadata", None) + wf_input = {"data": parsed} - for target in targets: - for owner, handler in _iter_handlers(target): - meta = getattr(handler, "_message_router_data") - schemas: List[Type[Any]] = meta.get("message_schemas") or [] + if metadata: + wf_input["_message_metadata"] = dict(metadata) - # Bind method to instance if needed (descriptor protocol). - bound = ( - handler if owner is None else handler.__get__(owner, owner.__class__) + instance_id = await asyncio.to_thread(_scheduler, bound_workflow, wf_input) + logger.info( + "Scheduled workflow=%s instance=%s", + getattr(bound_workflow, "__name__", str(bound_workflow)), + instance_id, ) - async def _invoke( - bound_handler: Callable[..., Any], - parsed: Any, - ) -> TopicEventResponse: - """Invoke the user handler (sync or async) and normalize the result.""" - result = bound_handler(parsed) - if inspect.iscoroutine(result): - result = await result - if isinstance(result, TopicEventResponse): - return result - # Treat any truthy/None return as success unless user explicitly returns a response. - return TopicEventResponse("success") - - def _make_handler( - bound_handler: Callable[..., Any], - ) -> Callable[[SubscriptionMessage], TopicEventResponse]: - """Create a Dapr-compatible handler for a single decorated function.""" - - def handler_fn(message: SubscriptionMessage) -> TopicEventResponse: + if await_result and delivery_mode == "sync": + state = await asyncio.to_thread(_wait_for_completion, instance_id) + _log_state(instance_id, state) + if state and getattr(state.runtime_status, "name", "") == "COMPLETED": + return TopicEventResponse("success") + return TopicEventResponse("retry") + + asyncio.create_task(_await_and_log(instance_id)) + return TopicEventResponse("success") + except Exception: + logger.exception("Workflow scheduling failed; requesting retry.") + return TopicEventResponse("retry") + + if queue is not None: + + async def _worker() -> None: + while True: + workflow_callable, payload = await queue.get() + try: + await _schedule(workflow_callable, payload) + except Exception: + logger.exception("Async worker crashed while scheduling workflow.") + finally: + queue.task_done() + + for _ in range(max(1, len(bindings))): + worker_tasks.append(loop.create_task(_worker())) # type: ignore[union-attr] + + # ---------------- NEW: group by (pubsub, topic) and build ONE composite handler per topic ------------- + from collections import defaultdict + + grouped: dict[tuple[str, str], list[_MessageRouteBinding]] = defaultdict(list) + for b in bindings: + grouped[(b.pubsub, b.topic)].append(b) + + def _composite_handler_fn( + group: list[_MessageRouteBinding], + ) -> Callable[[SubscriptionMessage], TopicEventResponse]: + # Flatten a plan: [(binding, model), ...] preserving declaration order + plan: list[tuple[_MessageRouteBinding, Type[Any]]] = [] + for b in group: + for m in b.schemas or [dict]: + plan.append((b, m)) + + def handler(message: SubscriptionMessage) -> TopicEventResponse: + try: + event_data, metadata = extract_cloudevent_data(message) + + # Optional: simple idempotency hook + if deduper is not None: + candidate_id = (metadata or {}).get( + "id" + ) or f"{group[0].topic}:{hash(str(event_data))}" + try: + if deduper.seen(candidate_id): + logger.info( + "Duplicate detected id=%s topic=%s; dropping.", + candidate_id, + group[0].topic, + ) + return TopicEventResponse("success") + deduper.mark(candidate_id) + except Exception: + logger.debug("Dedupe backend error; continuing.", exc_info=True) + + # (Optional) fast-path by CloudEvent type == model name (if publisher sets ce-type) + ce_type = (metadata or {}).get("type") + ordered_iter = plan + if ce_type: + preferred = [ + pair + for pair in plan + if getattr(pair[1], "__name__", "") == ce_type + ] + if preferred: + # Try preferred models first, then the rest + tail = [pair for pair in plan if pair not in preferred] + ordered_iter = preferred + tail + + # Try to validate against each model and dispatch to its handler + for binding, model in ordered_iter: try: - # 1) Extract payload + CloudEvent metadata (bytes/str/dict are also supported by the extractor) - event_data, metadata = extract_cloudevent_data(message) - - # 2) Validate against the first matching schema (or dict as fallback) - parsed = None - for model in schemas or [dict]: - try: - parsed = validate_message_model(model, event_data) - break - except Exception: - # Try the next schema; log at debug for signal without noise. - logger.debug( - "Schema %r did not match payload; trying next.", - model, - exc_info=True, - ) - continue - - if parsed is None: - # Permanent schema mismatch → drop (DLQ if configured by Dapr) - logger.warning( - "No matching schema for message on topic %r; dropping. Raw payload: %r", - meta["topic"], - event_data, + payload = ( + event_data + if isinstance(event_data, dict) + else {"data": event_data} + ) + parsed = validate_message_model(model, payload) + # attach metadata + try: + if isinstance(parsed, dict): + parsed["_message_metadata"] = metadata + else: + setattr(parsed, "_message_metadata", metadata) + except Exception: + logger.debug( + "Could not attach _message_metadata; continuing.", + exc_info=True, ) - return TopicEventResponse("drop") - # 3) Attach CE metadata for downstream consumers - if isinstance(parsed, dict): - parsed["_message_metadata"] = metadata - else: - setattr(parsed, "_message_metadata", metadata) + # enqueue/schedule to the right handler + if delivery_mode == "async": + assert queue is not None + loop.call_soon_threadsafe( + queue.put_nowait, (binding.handler, parsed) + ) # type: ignore[union-attr] + return TopicEventResponse("success") - # 4) Bridge worker thread → event loop if loop and loop.is_running(): fut = asyncio.run_coroutine_threadsafe( - _invoke(bound_handler, parsed), loop + _schedule(binding.handler, parsed), loop ) return fut.result() - return asyncio.run(_invoke(bound_handler, parsed)) + + return asyncio.run(_schedule(binding.handler, parsed)) except Exception: - # Transient failure (I/O, handler crash, etc.) → retry - logger.exception("Message handler error; requesting retry.") - return TopicEventResponse("retry") + # Not a match for this model → keep trying + continue - return handler_fn + # No model matched for this topic → drop (or switch to "retry" if you prefer) + logger.warning( + "No matching schema for topic=%r; dropping. raw=%r", + group[0].topic, + event_data, + ) + return TopicEventResponse("drop") - close_fn = dapr_client.subscribe_with_handler( - pubsub_name=meta["pubsub"], - topic=meta["topic"], - handler_fn=_make_handler(bound), - dead_letter_topic=meta.get("dead_letter_topic"), - ) - closers.append(close_fn) + except Exception: + logger.exception("Message handler error; requesting retry.") + return TopicEventResponse("retry") + + return handler + + closers: List[Callable[[], None]] = [] + + # subscribe one composite handler per (pubsub, topic) + for (pubsub_name, topic_name), group in grouped.items(): + handler_fn = _composite_handler_fn(group) + close_fn = subscribe( # type: ignore[misc] + pubsub_name=pubsub_name, + topic=topic_name, + handler_fn=handler_fn, + dead_letter_topic=group[0].dead_letter_topic, + ) + logger.info( + "Subscribed COMPOSITE(%d handlers) to pubsub=%s topic=%s (delivery=%s await=%s)", + len(group), + pubsub_name, + topic_name, + delivery_mode, + await_result, + ) + closers.append(close_fn) + + if worker_tasks: + + def _make_cancel_all(tasks: List[asyncio.Task]) -> Callable[[], None]: + def _cancel() -> None: + for task in tasks: + try: + task.cancel() + except Exception: + logger.debug("Error cancelling worker task.", exc_info=True) + + return _cancel + + closers.append(_make_cancel_all(worker_tasks)) + + return closers + + +def _mount_http_bindings( + bindings: List[_HttpRouteBinding], + *, + app: FastAPI, + loop: Optional[asyncio.AbstractEventLoop], +) -> List[Callable[[], None]]: + if not bindings: + return [] + + _ = _resolve_loop( + loop + ) # Parity with message registrar; FastAPI does not require it yet. + closers: List[Callable[[], None]] = [] + + async def _invoke(bound_handler: Callable[..., Any], parsed: Any) -> Any: + result = bound_handler(parsed) + if inspect.iscoroutine(result): + result = await result + return result + + for binding in bindings: + _schemas = binding.schemas or [dict] + _method = binding.method + _path = binding.path + _summary = binding.summary + _tags = list(binding.tags) + _response_model = binding.response_model + _name = binding.name + _handler = binding.handler + + def _make_endpoint( + *, + bound_handler: Callable[..., Any], + schemas_b: List[Type[Any]], + method_b: str, + path_b: str, + name_b: str, + ) -> Callable[..., Any]: + async def endpoint(body: Any = Body(...)) -> Any: + try: + parsed = None + matched_model: Optional[Type[Any]] = None + for model in schemas_b: + try: + candidate, _ = parse_http_json( + body, model=model, attach_metadata=False + ) + parsed = candidate + matched_model = model + break + except Exception: + logger.debug( + "HTTP schema %r did not match; trying next.", + model, + exc_info=True, + ) + + if parsed is None: + return JSONResponse( + status_code=422, + content={ + "detail": "Request body did not match any expected schema" + }, + ) + + if matched_model is not None: + logger.debug( + "Validated HTTP request for %s %s with model=%s", + method_b, + path_b, + getattr(matched_model, "__name__", str(matched_model)), + ) + + result = await _invoke(bound_handler, parsed) + + if isinstance(result, Response): + return result + if isinstance( + result, (dict, list, str, int, float, bool, type(None)) + ): + return result + return JSONResponse(content=result) + + except Exception: + logger.exception("HTTP handler error for %s %s.", method_b, path_b) + return JSONResponse( + status_code=500, content={"detail": "Internal Server Error"} + ) + + endpoint.__name__ = f"{name_b}_endpoint" + return endpoint + + endpoint = _make_endpoint( + bound_handler=_handler, + schemas_b=_schemas, + method_b=_method, + path_b=_path, + name_b=_name, + ) + + app.add_api_route( + _path, + endpoint, + methods=[_method], + summary=_summary, + tags=_tags, + response_model=_response_model, + ) + + closers.append(lambda: None) + logger.info("Mounted HTTP route %s %s -> %s", _method, _path, _name) return closers + + +def register_message_routes( + *, + dapr_client: DaprClient, + targets: Optional[Iterable[Any]] = None, + routes: Optional[Iterable[PubSubRouteSpec]] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + delivery_mode: Literal["sync", "async"] = "sync", + queue_maxsize: int = 1024, + deduper: Optional[DedupeBackend] = None, + subscribe: Optional[SubscribeFn] = None, + scheduler: Optional[SchedulerFn] = None, + wf_client: Optional[wf.DaprWorkflowClient] = None, + await_result: bool = False, + await_timeout: Optional[int] = None, + fetch_payloads: bool = True, + log_outcome: bool = True, +) -> List[Callable[[], None]]: + """ + Register workflow-backed pub/sub routes via decorator discovery and/or explicit specs. + + Args: + dapr_client: Active Dapr client used to create subscriptions. + targets: Objects/functions containing `@message_router` callables to auto-discover. + routes: Explicit `PubSubRouteSpec` entries to register. + loop: Event loop used to await async work (required for `delivery_mode="async"`). + delivery_mode: `"sync"` blocks the Dapr thread; `"async"` enqueues onto a worker queue. + queue_maxsize: Max in-flight messages when `delivery_mode="async"`. + deduper: Optional idempotency backend keyed by CloudEvent id/hash. + subscribe: Optional override for `dapr_client.subscribe_with_handler`. + scheduler: Optional `(callable, input_dict) -> instance_id` function. + wf_client: Reused `DaprWorkflowClient` for scheduling/waiting. + await_result: If `True` (sync only), wait for workflow completion and request retry on failure. + await_timeout: Optional wait timeout in seconds. + fetch_payloads: Include workflow payloads when waiting for completion. + log_outcome: Log COMPLETED/FAILED status (either inline or via detached task). + + Returns: + List of closers that unsubscribe handlers and cancel async workers. + """ + if targets is None and routes is None: + raise ValueError( + "Provide `targets` and/or `routes` when registering message routes." + ) + + bindings = _collect_message_bindings(targets=targets, routes=routes) + if not bindings: + logger.info("No message routes discovered.") + return [] + + return _subscribe_message_bindings( + bindings, + dapr_client=dapr_client, + loop=loop, + delivery_mode=delivery_mode, + queue_maxsize=queue_maxsize, + deduper=deduper, + subscribe=subscribe, + scheduler=scheduler, + wf_client=wf_client, + await_result=await_result, + await_timeout=await_timeout, + fetch_payloads=fetch_payloads, + log_outcome=log_outcome, + ) + + +def register_http_routes( + *, + app: FastAPI, + targets: Optional[Iterable[Any]] = None, + routes: Optional[Iterable[HttpRouteSpec]] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> List[Callable[[], None]]: + """ + Mount FastAPI endpoints from `@http_router` targets and/or explicit `HttpRouteSpec` entries. + + Args: + app: FastAPI application to register routes on. + targets: Objects/functions containing decorated HTTP handlers to auto-discover. + routes: Explicit HTTP specs to mount. + loop: Optional loop reference (retained for symmetry/future async needs). + + Returns: + List of no-op closers (API symmetry with message registrar). + """ + if targets is None and routes is None: + raise ValueError( + "Provide `targets` and/or `routes` when registering HTTP routes." + ) + + bindings = _collect_http_bindings(targets=targets, routes=routes) + if not bindings: + logger.info("No HTTP routes discovered.") + return [] + + return _mount_http_bindings(bindings, app=app, loop=loop) diff --git a/dapr_agents/workflow/utils/routers.py b/dapr_agents/workflow/utils/routers.py index 20a04ed2..14f00369 100644 --- a/dapr_agents/workflow/utils/routers.py +++ b/dapr_agents/workflow/utils/routers.py @@ -10,46 +10,34 @@ from dapr.common.pubsub.subscription import SubscriptionMessage from dapr_agents.types.message import EventMessageMetadata -from dapr_agents.workflow.utils.core import is_pydantic_model, is_supported_model +from dapr_agents.workflow.utils.core import is_supported_model logger = logging.getLogger(__name__) def extract_message_models(type_hint: Any) -> list[type]: - """Normalize a message type hint into a concrete list of classes. + """ + Turn a single class or a Union[...] into a list of concrete classes (filters None/Any). - Supports: - - Single class: `MyMessage` → `[MyMessage]` - - Union: `Union[Foo, Bar]` or `Foo | Bar` → `[Foo, Bar]` - - Optional: `Optional[Foo]` (i.e., `Union[Foo, None]`) → `[Foo]` + Args: + type_hint (Any): + The type hint to extract classes from. - Notes: - - Forward refs should be resolved by the caller (e.g., via `typing.get_type_hints`). - - Non-class entries (e.g., `None`, `typing.Any`) are filtered out. - - Returns an empty list when the hint isn't a usable class or union of classes. + Returns: + list[type]: A list of concrete classes extracted from the type hint. """ if type_hint is None: return [] - origin = get_origin(type_hint) - if origin in (Union, types.UnionType): # handle both `Union[...]` and `A | B` + if origin in (Union, types.UnionType): return [ t for t in get_args(type_hint) if t is not NoneType and isinstance(t, type) ] - return [type_hint] if isinstance(type_hint, type) else [] def _maybe_json_loads(payload: Any, content_type: Optional[str]) -> Any: - """ - Best-effort JSON parsing based on content type and payload shape. - - - If payload is `dict`/`list` → return as-is. - - If bytes/str and content-type hints JSON (or text looks like JSON) → parse to Python. - - Otherwise → return the original payload. - - This helper is intentionally forgiving; callers should validate downstream. - """ + """Best-effort: parse JSON by content-type hint or shape; otherwise return original value.""" try: if isinstance(payload, (dict, list)): return payload @@ -74,23 +62,66 @@ def _maybe_json_loads(payload: Any, content_type: Optional[str]) -> Any: return payload +def _maybe_json_body(body: Any) -> Any: + """HTTP helper: parse str/bytes into JSON once; otherwise return as-is.""" + if isinstance(body, (bytes, str)): + try: + return json.loads(body) + except Exception: + return body + return body + + +def validate_message_model(model: Type[Any], event_data: dict) -> Any: + """ + Validate/coerce event_data into model (dict, dataclass, Pydantic v1/v2). + + Args: + model (Type[Any]): + The model class to validate against. + event_data (dict): + The event data to validate. + + Returns: + Any: The validated/coerced message instance. + """ + if not is_supported_model(model): + raise TypeError(f"Unsupported model type: {model!r}") + + try: + logger.info("Validating payload with model '%s'...", model.__name__) + + if model is dict: + return event_data + + if is_dataclass(model): + return model(**event_data) + + if hasattr(model, "model_validate"): # Pydantic v2 + return model.model_validate(event_data) + + if hasattr(model, "parse_obj"): # Pydantic v1 + return model.parse_obj(event_data) + + raise TypeError(f"Unsupported model type: {model!r}") + + except Exception as e: + logger.error("Message validation failed for model '%s': %s", model.__name__, e) + raise ValueError(f"Message validation failed: {e}") from e + + def extract_cloudevent_data( message: Union[SubscriptionMessage, dict, bytes, str], -) -> Tuple[dict, dict]: +) -> Tuple[Any, dict]: """ - Extract CloudEvent metadata and payload (attempting JSON parsing when appropriate). + Extract CloudEvent .data and metadata from Dapr SubscriptionMessage or similar shapes. - Accepts: - - `SubscriptionMessage` (Dapr SDK) - - `dict` (raw CloudEvent envelope) - - `bytes`/`str` (data-only; metadata is synthesized) + Args: + message (Union[SubscriptionMessage, dict, bytes, str]): + The incoming CloudEvent message from Dapr pub/sub. Returns: - (event_data, metadata) as dictionaries. `event_data` may be non-dict JSON - (e.g., list) if the payload is an array; callers expecting dicts should handle it. - - Raises: - ValueError: For unsupported `message` types. + Tuple[Any, dict]: A tuple containing the event data and its metadata. """ if isinstance(message, SubscriptionMessage): content_type = message.data_content_type() @@ -102,7 +133,7 @@ def extract_cloudevent_data( pubsubname=message.pubsub_name(), source=message.source(), specversion=message.spec_version(), - time=None, # not always populated by SDK + time=None, topic=message.topic(), traceid=None, traceparent=None, @@ -131,7 +162,6 @@ def extract_cloudevent_data( ).model_dump() elif isinstance(message, (bytes, str)): - # No CloudEvent envelope; treat payload as data-only and synthesize minimal metadata. content_type = "application/json" event_data = _maybe_json_loads(message, content_type) metadata = EventMessageMetadata( @@ -154,75 +184,80 @@ def extract_cloudevent_data( if not isinstance(event_data, dict): logger.debug( - "Event data is not a dict (type=%s); value=%r", type(event_data), event_data + "CloudEvent data is not a dict (type=%s); value=%r", + type(event_data), + event_data, ) return event_data, metadata -def validate_message_model(model: Type[Any], event_data: dict) -> Any: +def parse_cloudevent( + message: Union[SubscriptionMessage, dict, bytes, str], + model: Optional[Type[Any]] = None, +) -> Tuple[Any, dict]: """ - Validate and coerce `event_data` into `model`. + Parse a pub/sub CloudEvent and validate its `.data` against model. - Supports: - - dict: returns `event_data` unchanged - - dataclass: constructs the dataclass - - Pydantic v2 model: uses `model_validate` + Args: + message (Union[SubscriptionMessage, dict, bytes, str]): + The incoming CloudEvent message from Dapr pub/sub. + model (Optional[Type[Any]], optional): + The model class to validate the event data against. Defaults to None. - Raises: - TypeError: If the model is not a supported kind. - ValueError: If validation/construction fails. + Returns: + Tuple[Any, dict]: A tuple containing the validated message and its metadata. """ - if not is_supported_model(model): - raise TypeError(f"Unsupported model type: {model!r}") - try: - logger.info(f"Validating payload with model '{model.__name__}'...") + if model is None: + raise ValueError("Message validation failed: No model provided.") + + event_data, metadata = extract_cloudevent_data(message) + if not isinstance(event_data, dict): + event_data = {"data": event_data} + + validated_message = validate_message_model(model, event_data) + logger.info("CloudEvent successfully parsed and validated") + logger.debug("Data: %r", validated_message) + logger.debug("metadata: %r", metadata) + return validated_message, metadata - if model is dict: - return event_data - if is_dataclass(model): - return model(**event_data) - if is_pydantic_model(model): - return model.model_validate(event_data) - raise TypeError(f"Unsupported model type: {model!r}") except Exception as e: - logger.error(f"Message validation failed for model '{model.__name__}': {e}") - raise ValueError(f"Message validation failed: {e}") + logger.error("Failed to parse CloudEvent: %s", e, exc_info=True) + raise ValueError(f"Invalid CloudEvent: {str(e)}") from e -def parse_cloudevent( - message: Union[SubscriptionMessage, dict, bytes, str], +def parse_http_json( + body: Any, model: Optional[Type[Any]] = None, + *, + attach_metadata: bool = False, ) -> Tuple[Any, dict]: """ - Parse a CloudEvent-like input and validate its payload against ``model``. + Parse a plain JSON HTTP body and validate against model (no CloudEvent semantics). Args: - message (Union[SubscriptionMessage, dict, bytes, str]): Incoming message; can be a Dapr ``SubscriptionMessage``, a raw - CloudEvent ``dict``, or bare ``bytes``/``str`` payloads. - model (Optional[Type[Any]]): Schema for payload validation (required). + body (Any): + The incoming HTTP request body. + model (Optional[Type[Any]], optional): + The model class to validate the body against. Defaults to None. + attach_metadata (bool, optional): + Whether to attach empty metadata dict. Defaults to False. Returns: Tuple[Any, dict]: A tuple containing the validated message and its metadata. - - Raises: - ValueError: If no model is provided or validation fails. """ - try: - event_data, metadata = extract_cloudevent_data(message) + if model is None: + raise ValueError("Message validation failed: No model provided.") - if model is None: - raise ValueError("Message validation failed: No model provided.") - - validated_message = validate_message_model(model, event_data) - - logger.info("Message successfully parsed and validated") - logger.debug(f"Data: {validated_message}") - logger.debug(f"metadata: {metadata}") - - return validated_message, metadata + payload = _maybe_json_body(body) + if isinstance(payload, dict): + event_data = payload + else: + event_data = {"data": payload} - except Exception as e: - logger.error(f"Failed to parse CloudEvent: {e}", exc_info=True) - raise ValueError(f"Invalid CloudEvent: {str(e)}") + validated = validate_message_model(model, event_data) + metadata: dict = {} if attach_metadata else {} + logger.info("HTTP JSON successfully parsed and validated (no CloudEvent semantics)") + logger.debug("Data: %r", validated) + return validated, metadata diff --git a/docs/development/README.md b/docs/development/README.md index 05af012d..4556a7b4 100644 --- a/docs/development/README.md +++ b/docs/development/README.md @@ -152,6 +152,19 @@ tox -e type 2. every time i have a .run() or invoke new workflow via curl or pubsub then a new workflow instance id will be created. If there is an inflight workflow already then it will be resumed, and the new one will be created. 3. Trace ID = workflow ID and make the tracing pick up from where it left off. +### Internal class structuring/setup +When to use Pydantic vs dataclasses: +- Use Pydantic for: + - Data crossing trust boundaries or is persisted: API payloads, pub/sub messages, persisted state (workflow state, timeline messages, trigger/broadcast schemas, tool records, etc.). + - Schemas requiring coercion, validation, or versioned migrations. +- Use dataclasses for: + - Agent construction knobs you pass in code (ie agent config classes). + - Dependency injection of services/stores/policies and behavior hooks. + +Mental model: +- Think “config vs data”: + - Config you wire at construction time → dataclasses. + - Data the system processes/persists at runtime → Pydantic. ## Contributing to Dapr Agents Quickstarts diff --git a/quickstarts/03-agent-tool-call/weather_agent.py b/quickstarts/03-agent-tool-call/weather_agent.py index 3617e467..c4e26f2f 100644 --- a/quickstarts/03-agent-tool-call/weather_agent.py +++ b/quickstarts/03-agent-tool-call/weather_agent.py @@ -1,6 +1,6 @@ import asyncio from weather_tools import tools -from dapr_agents import Agent +from dapr_agents.agents.standalone import Agent from dotenv import load_dotenv from dapr_agents.llm.openai import OpenAIChatClient diff --git a/quickstarts/10-test-durable-agent/README.md b/quickstarts/10-test-durable-agent/README.md new file mode 100644 index 00000000..715af9ab --- /dev/null +++ b/quickstarts/10-test-durable-agent/README.md @@ -0,0 +1,35 @@ +# Test New DurableAgent (Pub/Sub → Workflow) + +## Prerequisites + +- Python 3.10+ +- Dapr CLI & Docker +- OpenAI-compatible API key (for the `llm_activity` decorators) + +## Setup + +Install dependencies in your virtual environment (see repository root instructions) and ensure Dapr is initialised: + +```bash +dapr init +``` + +Provide your OpenAI key via `.env` or by editing `components/openai.yaml`, identical to the earlier quickstart. + +## Run the app + +```bash +# Terminal 1 – run the workflow app +dapr run \ + --app-id blog-app-agent \ + --resources-path ./components \ + -- python app.py + +# Terminal 2 – publish a message to start the workflow +dapr run \ + --app-id blog-app-client \ + --resources-path ./components \ + -- python message_client.py +``` + +You should see the workflow started and the generated blog post appear in the app logs. \ No newline at end of file diff --git a/quickstarts/10-test-durable-agent/app.py b/quickstarts/10-test-durable-agent/app.py new file mode 100644 index 00000000..1084c82e --- /dev/null +++ b/quickstarts/10-test-durable-agent/app.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import asyncio +import logging + +from dotenv import load_dotenv + +from dapr_agents.agents.durable import DurableAgent +from dapr_agents.agents.configs import ( + AgentMemoryConfig, + AgentPubSubConfig, + AgentRegistryConfig, + AgentStateConfig, +) +from dapr_agents.agents.prompting import AgentProfileConfig +from dapr_agents.llm.openai import OpenAIChatClient +from dapr_agents.memory import ConversationDaprStateMemory +from dapr_agents.storage.daprstores.stateservice import StateStoreService +from dapr_agents.workflow.runners import AgentRunner +from dapr_agents.workflow.utils.core import wait_for_shutdown + +# Load environment +load_dotenv() + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s [%(name)s] %(message)s", +) +logger = logging.getLogger("durable-agent-app") + + +async def main() -> None: + """ + Minimal durable-agent quickstart. + + Notes: + - If you do not specify a custom state/message model, AgentStateConfig defaults to + AgentWorkflowState / AgentWorkflowMessage under the hood. + """ + agent_name = "blog-agent" + + # --- Pub/Sub & State/Registry wiring ------------------------------------- + pubsub_config = AgentPubSubConfig( + pubsub_name="messagepubsub", + agent_topic="blog.requests", + broadcast_topic="agents.broadcast", + ) + + state_config = AgentStateConfig( + store=StateStoreService(store_name="workflowstatestore", key_prefix="blog:") + # No default_state/state_model_cls/message_model_cls → uses defaults. + ) + + registry_config = AgentRegistryConfig( + store=StateStoreService(store_name="agentregistrystore"), + team_name="bloggers", + ) + + # --- Profile / Prompting --------------------------------------------------- + profile = AgentProfileConfig( + name="Blog Agent", + role="AI Blogger", + goal="Write engaging blog updates and summarize content for readers.", + instructions=[ + "Summarize new information clearly in 2-3 sentences.", + "Recommend one follow-up topic the audience might enjoy.", + ], + style_guidelines=[ + "Use friendly, professional tone.", + "Avoid technical jargon unless asked.", + ], + ) + + # --- Memory (Dapr-backed conversation history) ----------------------------- + memory_config = AgentMemoryConfig( + store=ConversationDaprStateMemory( + store_name="memorystore", + session_id=f"{agent_name}-session", + ) + ) + + # --- LLM Client ------------------------------------------------------------ + llm = OpenAIChatClient() + + # --- Assemble durable agent ------------------------------------------------ + agent = DurableAgent( + profile_config=profile, + pubsub_config=pubsub_config, + state_config=state_config, + registry_config=registry_config, + memory_config=memory_config, + llm=llm, + ) + agent.start() + + # --- HTTP runner for workflow endpoints ----------------------------------- + runner = AgentRunner() + try: + runner.register_routes(agent) + await wait_for_shutdown() + finally: + runner.shutdown() + agent.stop() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/quickstarts/10-test-durable-agent/components/memorystore.yaml b/quickstarts/10-test-durable-agent/components/memorystore.yaml new file mode 100644 index 00000000..7e3634c4 --- /dev/null +++ b/quickstarts/10-test-durable-agent/components/memorystore.yaml @@ -0,0 +1,10 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: memorystore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 \ No newline at end of file diff --git a/quickstarts/10-test-durable-agent/components/openai.yaml b/quickstarts/10-test-durable-agent/components/openai.yaml new file mode 100644 index 00000000..7c518fb2 --- /dev/null +++ b/quickstarts/10-test-durable-agent/components/openai.yaml @@ -0,0 +1,14 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: openai +spec: + type: conversation.openai + version: v1 + metadata: + - name: key + value: "{{OPENAI_API_KEY}}" + - name: model + value: gpt-5-mini + - name: temperature + value: 1 diff --git a/quickstarts/10-test-durable-agent/components/pubsub.yaml b/quickstarts/10-test-durable-agent/components/pubsub.yaml new file mode 100644 index 00000000..6bd05aca --- /dev/null +++ b/quickstarts/10-test-durable-agent/components/pubsub.yaml @@ -0,0 +1,10 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: messagepubsub +spec: + type: pubsub.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 diff --git a/quickstarts/10-test-durable-agent/components/registrystore.yaml b/quickstarts/10-test-durable-agent/components/registrystore.yaml new file mode 100644 index 00000000..3bfd3ef4 --- /dev/null +++ b/quickstarts/10-test-durable-agent/components/registrystore.yaml @@ -0,0 +1,12 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: agentregistrystore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: keyPrefix + value: none \ No newline at end of file diff --git a/quickstarts/10-test-durable-agent/components/statestore.yaml b/quickstarts/10-test-durable-agent/components/statestore.yaml new file mode 100644 index 00000000..2fc32cd0 --- /dev/null +++ b/quickstarts/10-test-durable-agent/components/statestore.yaml @@ -0,0 +1,12 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: workflowstatestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: actorStateStore + value: "true" \ No newline at end of file diff --git a/quickstarts/10-test-durable-agent/message_client.py b/quickstarts/10-test-durable-agent/message_client.py new file mode 100644 index 00000000..46a66a29 --- /dev/null +++ b/quickstarts/10-test-durable-agent/message_client.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import os +import random +import signal +import sys +import uuid +from typing import Any, Dict + +from dapr.clients import DaprClient + +# --------------------------- +# Config via environment vars +# --------------------------- +PUBSUB_NAME = os.getenv("PUBSUB_NAME", "messagepubsub") +TOPIC_NAME = os.getenv("TOPIC_NAME", "blog.requests") +BLOG_TOPIC = os.getenv("BLOG_TOPIC", "AI Agents") # used when RAW_DATA is not provided +RAW_DATA = os.getenv("RAW_DATA") # if set, must be a JSON object (string) +CONTENT_TYPE = os.getenv("CONTENT_TYPE", "application/json") +CE_TYPE = os.getenv("CLOUDEVENT_TYPE") # optional CloudEvent 'type' metadata +CLIENT_SOURCE = os.getenv("CLIENT_SOURCE", "quickstarts.message_client") + +# Publish behavior +PUBLISH_ONCE = os.getenv("PUBLISH_ONCE", "true").lower() in {"1", "true", "yes"} +INTERVAL_SEC = float(os.getenv("INTERVAL_SEC", "0")) # used when PUBLISH_ONCE=false +MAX_ATTEMPTS = int(os.getenv("MAX_ATTEMPTS", "8")) +INITIAL_DELAY = float(os.getenv("INITIAL_DELAY", "0.5")) +BACKOFF_FACTOR = float(os.getenv("BACKOFF_FACTOR", "2.0")) +JITTER_FRAC = float(os.getenv("JITTER_FRAC", "0.2")) + +# Optional warmup (give sidecar/broker a moment) +STARTUP_DELAY = float(os.getenv("STARTUP_DELAY", "1.0")) + +logger = logging.getLogger("publisher") + + +async def _backoff_sleep(delay: float, jitter: float, factor: float) -> float: + """Sleep for ~delay seconds with ±jitter% randomness, then return the next delay.""" + actual = max(0.0, delay * (1 + random.uniform(-jitter, jitter))) + if actual: + await asyncio.sleep(actual) + return delay * factor + + +def _build_payload() -> Dict[str, Any]: + """ + Build the JSON payload: + - if RAW_DATA is set → parse as JSON (must be an object) + - else → default TriggerAction with a blog prompt + """ + if RAW_DATA: + try: + data = json.loads(RAW_DATA) + except Exception as exc: # noqa: BLE001 + raise ValueError(f"Invalid RAW_DATA JSON: {exc}") from exc + if not isinstance(data, dict): + raise ValueError("RAW_DATA must be a JSON object") + payload = dict(data) + else: + payload = { + "task": f"Write a concise paragraph about '{BLOG_TOPIC}' and suggest a follow-up topic.", + } + + # Derive task from other fields if still missing + if "task" not in payload: + topic = payload.get("topic") or BLOG_TOPIC + payload["task"] = f"Write a helpful response about '{topic}'." + + # Ensure metadata hints are attached for downstream tracing + metadata = dict(payload.get("metadata") or {}) + metadata.setdefault("source", CLIENT_SOURCE) + metadata.setdefault("request_id", str(uuid.uuid4())) + payload["metadata"] = metadata + + return payload + + +def _encode_payload(payload: Dict[str, Any]) -> bytes: + """Encode the payload as UTF-8 JSON bytes.""" + return json.dumps(payload, ensure_ascii=False).encode("utf-8") + + +async def publish_once(client: DaprClient, payload: Dict[str, Any]) -> None: + """Publish once with retries and exponential backoff.""" + delay = INITIAL_DELAY + body = _encode_payload(payload) + + for attempt in range(1, MAX_ATTEMPTS + 1): + try: + logger.info("publish attempt %d → %s/%s", attempt, PUBSUB_NAME, TOPIC_NAME) + client.publish_event( + pubsub_name=PUBSUB_NAME, + topic_name=TOPIC_NAME, + data=body, + data_content_type=CONTENT_TYPE, + publish_metadata=({"cloudevent.type": CE_TYPE} if CE_TYPE else None), + ) + logger.info("published successfully") + return + except Exception as exc: # noqa: BLE001 + logger.warning("publish failed: %s", exc) + if attempt == MAX_ATTEMPTS: + raise + logger.info("retrying in ~%.2fs …", delay) + delay = await _backoff_sleep(delay, JITTER_FRAC, BACKOFF_FACTOR) + + +async def main() -> int: + logging.basicConfig(level=logging.INFO) + stop_event = asyncio.Event() + + # Signal-aware shutdown + loop = asyncio.get_running_loop() + + def _stop(*_: object) -> None: + stop_event.set() + + try: + loop.add_signal_handler(signal.SIGINT, _stop) + loop.add_signal_handler(signal.SIGTERM, _stop) + except NotImplementedError: + signal.signal(signal.SIGINT, lambda *_: _stop()) + signal.signal(signal.SIGTERM, lambda *_: _stop()) + + # Optional warmup + if STARTUP_DELAY > 0: + await asyncio.sleep(STARTUP_DELAY) + + try: + with DaprClient() as client: + if PUBLISH_ONCE: + payload = _build_payload() + logger.info("payload: %s", payload) + await publish_once(client, payload) + # brief wait so logs flush nicely under dapr + await asyncio.sleep(0.2) + return 0 + + # periodic mode + if INTERVAL_SEC <= 0: + logger.error("INTERVAL_SEC must be > 0 when PUBLISH_ONCE=false") + return 2 + + logger.info("starting periodic publisher every %.2fs", INTERVAL_SEC) + while not stop_event.is_set(): + try: + payload = _build_payload() + logger.info("payload: %s", payload) + await publish_once(client, payload) + except Exception as exc: # noqa: BLE001 + logger.error("giving up after %d attempts: %s", MAX_ATTEMPTS, exc) + + # wait for next tick or shutdown + try: + await asyncio.wait_for(stop_event.wait(), timeout=INTERVAL_SEC) + except asyncio.TimeoutError: + pass + + logger.info("shutdown requested; exiting") + return 0 + + except KeyboardInterrupt: + return 130 + except Exception as exc: # noqa: BLE001 + logger.exception("fatal error: %s", exc) + return 1 + + +if __name__ == "__main__": + sys.exit(asyncio.run(main())) diff --git a/quickstarts/10-test-durable-agent/models.py b/quickstarts/10-test-durable-agent/models.py new file mode 100644 index 00000000..4ce41e17 --- /dev/null +++ b/quickstarts/10-test-durable-agent/models.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class StartBlogMessage(BaseModel): + topic: str = Field(min_length=1, description="Blog topic/title") diff --git a/quickstarts/11-test-multi-agents/README.md b/quickstarts/11-test-multi-agents/README.md new file mode 100644 index 00000000..a5d56219 --- /dev/null +++ b/quickstarts/11-test-multi-agents/README.md @@ -0,0 +1,32 @@ +# Test New DurableAgents and Orchestrators + +## Prerequisites + +- Python 3.10+ +- Dapr CLI & Docker +- OpenAI-compatible API key (for the `llm_activity` decorators) + +## Setup + +Install dependencies in your virtual environment (see repository root instructions) and ensure Dapr is initialised: + +```bash +dapr init +``` + +Provide your OpenAI key via `.env` or by editing `components/openai.yaml`, identical to the earlier quickstart. + +## Run the app + +```bash +# Terminal 1 – run the multi-agent system +dapr run -f dapr-llm.yaml +``` + +You should see the workflow started and the generated blog post appear in the app logs. + +## How it maps to the new host + +- `RouteSpec` couples pub/sub metadata with the Pydantic message schema and the registered workflow name. +- Workflows/activities are registered once when constructing `DaprWorkflowApp`; the host handles playback-safe scheduling and lifecycle. +- The code under `workflow.py` is identical to quickstart 04—only the hosting logic changed. diff --git a/quickstarts/11-test-multi-agents/components/memorystore.yaml b/quickstarts/11-test-multi-agents/components/memorystore.yaml new file mode 100644 index 00000000..7e3634c4 --- /dev/null +++ b/quickstarts/11-test-multi-agents/components/memorystore.yaml @@ -0,0 +1,10 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: memorystore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 \ No newline at end of file diff --git a/quickstarts/11-test-multi-agents/components/openai.yaml b/quickstarts/11-test-multi-agents/components/openai.yaml new file mode 100644 index 00000000..7c518fb2 --- /dev/null +++ b/quickstarts/11-test-multi-agents/components/openai.yaml @@ -0,0 +1,14 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: openai +spec: + type: conversation.openai + version: v1 + metadata: + - name: key + value: "{{OPENAI_API_KEY}}" + - name: model + value: gpt-5-mini + - name: temperature + value: 1 diff --git a/quickstarts/11-test-multi-agents/components/pubsub.yaml b/quickstarts/11-test-multi-agents/components/pubsub.yaml new file mode 100644 index 00000000..6bd05aca --- /dev/null +++ b/quickstarts/11-test-multi-agents/components/pubsub.yaml @@ -0,0 +1,10 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: messagepubsub +spec: + type: pubsub.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 diff --git a/quickstarts/11-test-multi-agents/components/registrystore.yaml b/quickstarts/11-test-multi-agents/components/registrystore.yaml new file mode 100644 index 00000000..3bfd3ef4 --- /dev/null +++ b/quickstarts/11-test-multi-agents/components/registrystore.yaml @@ -0,0 +1,12 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: agentregistrystore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: keyPrefix + value: none \ No newline at end of file diff --git a/quickstarts/11-test-multi-agents/components/statestore.yaml b/quickstarts/11-test-multi-agents/components/statestore.yaml new file mode 100644 index 00000000..2fc32cd0 --- /dev/null +++ b/quickstarts/11-test-multi-agents/components/statestore.yaml @@ -0,0 +1,12 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: workflowstatestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: actorStateStore + value: "true" \ No newline at end of file diff --git a/quickstarts/11-test-multi-agents/dapr-llm.yaml b/quickstarts/11-test-multi-agents/dapr-llm.yaml new file mode 100644 index 00000000..f5c9b6b6 --- /dev/null +++ b/quickstarts/11-test-multi-agents/dapr-llm.yaml @@ -0,0 +1,29 @@ +# https://docs.dapr.io/developing-applications/local-development/multi-app-dapr-run/multi-app-template/#template-properties +version: 1 +common: + resourcesPath: ./components + logLevel: info + appLogDestination: console + daprdLogDestination: console + +apps: +- appID: FrodoApp + appDirPath: ./services/frodo/ + command: ["python3", "app.py"] + +- appID: SamApp + appDirPath: ./services/sam/ + command: ["python3", "app.py"] + +- appID: GandalfApp + appDirPath: ./services/gandalf/ + command: ["python3", "app.py"] + +- appID: LLMOrchestratorApp + appDirPath: ./services/workflow-llm/ + command: ["python3", "app.py"] + + +- appID: ClientApp + appDirPath: ./services/client/ + command: ["python3", "pubsub_client.py", "--orchestrator", "llm"] \ No newline at end of file diff --git a/quickstarts/11-test-multi-agents/dapr-random.yaml b/quickstarts/11-test-multi-agents/dapr-random.yaml new file mode 100644 index 00000000..bf67462c --- /dev/null +++ b/quickstarts/11-test-multi-agents/dapr-random.yaml @@ -0,0 +1,24 @@ +# https://docs.dapr.io/developing-applications/local-development/multi-app-dapr-run/multi-app-template/#template-properties +version: 1 +common: + resourcesPath: ./components + logLevel: info + appLogDestination: console + daprdLogDestination: console + +apps: +- appID: FrodoApp + appDirPath: ./services/frodo/ + command: ["python3", "app.py"] + +- appID: SamApp + appDirPath: ./services/sam/ + command: ["python3", "app.py"] + +- appID: RandomApp + appDirPath: ./services/workflow-random/ + command: ["python3", "app.py"] + +- appID: ClientApp + appDirPath: ./services/client/ + command: ["python3", "pubsub_client.py", "--orchestrator", "random"] \ No newline at end of file diff --git a/quickstarts/11-test-multi-agents/dapr-roundrobin.yaml b/quickstarts/11-test-multi-agents/dapr-roundrobin.yaml new file mode 100644 index 00000000..e76bc67f --- /dev/null +++ b/quickstarts/11-test-multi-agents/dapr-roundrobin.yaml @@ -0,0 +1,24 @@ +# https://docs.dapr.io/developing-applications/local-development/multi-app-dapr-run/multi-app-template/#template-properties +version: 1 +common: + resourcesPath: ./components + logLevel: info + appLogDestination: console + daprdLogDestination: console + +apps: +- appID: FrodoApp + appDirPath: ./services/frodo/ + command: ["python3", "app.py"] + +- appID: SamApp + appDirPath: ./services/sam/ + command: ["python3", "app.py"] + +- appID: RoundRobinApp + appDirPath: ./services/workflow-roundrobin/ + command: ["python3", "app.py"] + +- appID: ClientApp + appDirPath: ./services/client/ + command: ["python3", "pubsub_client.py", "--orchestrator", "roundrobin"] \ No newline at end of file diff --git a/quickstarts/11-test-multi-agents/services/client/http_client.py b/quickstarts/11-test-multi-agents/services/client/http_client.py new file mode 100644 index 00000000..f39dd26f --- /dev/null +++ b/quickstarts/11-test-multi-agents/services/client/http_client.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +import requests +import time +import sys + + +if __name__ == "__main__": + status_url = "http://localhost:8004/status" + healthy = False + for attempt in range(1, 11): + try: + print(f"Attempt {attempt}...") + response = requests.get(status_url, timeout=5) + + if response.status_code == 200: + print("Workflow app is healthy!") + healthy = True + break + else: + print(f"Received status code {response.status_code}: {response.text}") + + except requests.exceptions.RequestException as e: + print(f"Request failed: {e}") + + attempt += 1 + print("Waiting 5s seconds before next health checkattempt...") + time.sleep(5) + + if not healthy: + print("Workflow app is not healthy!") + sys.exit(1) + + workflow_url = "http://localhost:8004/start-workflow" + task_payload = {"task": "How to get to Mordor? We all need to help!"} + + try: + print("Starting workflow...") + response = requests.post(workflow_url, json=task_payload, timeout=120) + + if response.status_code == 202: + print("Workflow started successfully!") + print("Check Phoenix UI for workflow progress.") + sys.exit(0) + else: + print(f"Received status code {response.status_code}: {response.text}") + sys.exit(1) + + except requests.exceptions.RequestException as e: + print(f"Request failed: {e}") + sys.exit(1) diff --git a/quickstarts/11-test-multi-agents/services/client/pubsub_client.py b/quickstarts/11-test-multi-agents/services/client/pubsub_client.py new file mode 100644 index 00000000..8409fbd1 --- /dev/null +++ b/quickstarts/11-test-multi-agents/services/client/pubsub_client.py @@ -0,0 +1,321 @@ +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import os +import random +import signal +import sys +from typing import Any, Dict, List + +from dapr.clients import DaprClient + +# --------------------------- +# Env defaults (CLI can override) +# --------------------------- +PUBSUB_NAME = os.getenv("PUBSUB_NAME", "messagepubsub") + +# Named topics (used when --orchestrator is provided) +RANDOM_TOPIC_DEFAULT = os.getenv( + "RANDOM_TOPIC", "fellowship.orchestrator.random.requests" +) +ROUNDROBIN_TOPIC_DEFAULT = os.getenv( + "ROUNDROBIN_TOPIC", "fellowship.orchestrator.roundrobin.requests" +) +LLM_TOPIC_DEFAULT = os.getenv("LLM_TOPIC", "llm.orchestrator.requests") + +# Legacy single-topic env (still honored if you pass --topic without a value) +ORCHESTRATOR_TOPIC_ENV = os.getenv("ORCHESTRATOR_TOPIC") + +RAW_DATA = os.getenv("RAW_DATA") +TASK_TEXT_DEFAULT = os.getenv( + "TASK_TEXT", + "Set the next step for the journey to Mordor. Consider safety, supplies, and stealth. Start from the Shire to Bree.", +) + +CONTENT_TYPE = os.getenv("CONTENT_TYPE", "application/json") +CLOUDEVENT_TYPE_DEFAULT = os.getenv("CLOUDEVENT_TYPE", "TriggerAction") + +PUBLISH_ONCE_DEFAULT = os.getenv("PUBLISH_ONCE", "true").lower() in {"1", "true", "yes"} +INTERVAL_SEC_DEFAULT = float(os.getenv("INTERVAL_SEC", "0")) +MAX_ATTEMPTS_DEFAULT = int(os.getenv("MAX_ATTEMPTS", "8")) +INITIAL_DELAY_DEFAULT = float(os.getenv("INITIAL_DELAY", "0.5")) +BACKOFF_FACTOR_DEFAULT = float(os.getenv("BACKOFF_FACTOR", "2.0")) +JITTER_FRAC_DEFAULT = float(os.getenv("JITTER_FRAC", "0.2")) +STARTUP_DELAY_DEFAULT = float(os.getenv("STARTUP_DELAY", "5.0")) + +logger = logging.getLogger("orchestrator_publisher") + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Publish TriggerAction messages to orchestrator(s)." + ) + # Orchestrator selection (mutually exclusive with explicit topics) + p.add_argument( + "--orchestrator", + choices=["random", "roundrobin", "llm"], + help="Route to a named orchestrator topic (or both).", + ) + p.add_argument( + "--topic", + action="append", + help="Explicit topic name (can be repeated). If provided, --orchestrator is ignored.", + ) + + # Message/data knobs + p.add_argument("--task", help="Task text to send (ignored if --raw JSON is used).") + p.add_argument("--raw", help="Raw JSON object to send as the event data.") + p.add_argument( + "--ce-type", default=CLOUDEVENT_TYPE_DEFAULT, help="CloudEvent type to set." + ) + p.add_argument("--content-type", default=CONTENT_TYPE, help="Content type.") + + # Behavior + p.add_argument("--pubsub", default=PUBSUB_NAME, help="Dapr pubsub name.") + p.add_argument( + "--once", + dest="publish_once", + action="store_true", + help="Publish once and exit.", + ) + p.add_argument( + "--loop", + dest="publish_once", + action="store_false", + help="Publish periodically.", + ) + p.set_defaults(publish_once=PUBLISH_ONCE_DEFAULT) + + p.add_argument( + "--interval", + type=float, + default=INTERVAL_SEC_DEFAULT, + help="Interval seconds when looping.", + ) + p.add_argument( + "--startup-delay", + type=float, + default=STARTUP_DELAY_DEFAULT, + help="Initial delay seconds.", + ) + p.add_argument( + "--max-attempts", + type=int, + default=MAX_ATTEMPTS_DEFAULT, + help="Max retry attempts.", + ) + p.add_argument( + "--initial-delay", + type=float, + default=INITIAL_DELAY_DEFAULT, + help="Initial backoff seconds.", + ) + p.add_argument( + "--backoff-factor", + type=float, + default=BACKOFF_FACTOR_DEFAULT, + help="Backoff multiplier.", + ) + p.add_argument( + "--jitter", + type=float, + default=JITTER_FRAC_DEFAULT, + help="± jitter fraction on backoff.", + ) + + # Topic defaults (so you can override per run) + p.add_argument( + "--random-topic", + default=RANDOM_TOPIC_DEFAULT, + help="Topic for random orchestrator.", + ) + p.add_argument( + "--roundrobin-topic", + default=ROUNDROBIN_TOPIC_DEFAULT, + help="Topic for round-robin orchestrator.", + ) + p.add_argument( + "--llm-topic", default=LLM_TOPIC_DEFAULT, help="Topic for LLM orchestrator." + ) + + return p.parse_args() + + +async def _backoff_sleep(delay: float, jitter: float, factor: float) -> float: + actual = max(0.0, delay * (1 + random.uniform(-jitter, jitter))) + if actual: + await asyncio.sleep(actual) + return delay * factor + + +def _build_payload(args: argparse.Namespace) -> Dict[str, Any]: + if args.raw or RAW_DATA: + raw = args.raw or RAW_DATA + try: + data = json.loads(raw) + except Exception as exc: # noqa: BLE001 + raise ValueError(f"Invalid --raw/RAW_DATA JSON: {exc}") from exc + if not isinstance(data, dict): + raise ValueError("--raw/RAW_DATA must be a JSON object") + return data + return {"task": args.task or TASK_TEXT_DEFAULT} + + +def _encode_payload(payload: Dict[str, Any]) -> bytes: + return json.dumps(payload, ensure_ascii=False).encode("utf-8") + + +def _resolve_topics(args: argparse.Namespace) -> List[str]: + # If explicit topics were passed, use them directly + if args.topic: + return args.topic + + # Otherwise use orchestrator selection + if args.orchestrator == "random": + return [args.random_topic] + if args.orchestrator == "roundrobin": + return [args.roundrobin_topic] + if args.orchestrator == "llm": + return [args.llm_topic] + + # Fallbacks: + # 1) Legacy single-topic env if set + if ORCHESTRATOR_TOPIC_ENV: + return [ORCHESTRATOR_TOPIC_ENV] + + # 2) Default to random + return [args.random_topic] + + +async def publish_once( + client: DaprClient, + payload: Dict[str, Any], + *, + pubsub_name: str, + topics: List[str], + content_type: str, + ce_type: str | None, + max_attempts: int, + initial_delay: float, + jitter: float, + backoff_factor: float, +) -> None: + delay = initial_delay + body = _encode_payload(payload) + + for attempt in range(1, max_attempts + 1): + try: + for topic in topics: + logger.info("publish attempt %d → %s/%s", attempt, pubsub_name, topic) + client.publish_event( + pubsub_name=pubsub_name, + topic_name=topic, + data=body, + data_content_type=content_type, + publish_metadata=( + {"cloudevent.type": ce_type} if ce_type else None + ), + ) + logger.info("published successfully to %s", topic) + return + except Exception as exc: # noqa: BLE001 + logger.warning("publish failed: %s", exc) + if attempt == max_attempts: + raise + logger.info("retrying in ~%.2fs …", delay) + delay = await _backoff_sleep(delay, jitter, backoff_factor) + + +async def main() -> int: + logging.basicConfig(level=logging.INFO) + args = parse_args() + stop_event = asyncio.Event() + + # Signal-aware shutdown + loop = asyncio.get_running_loop() + + def _stop(*_: object) -> None: + stop_event.set() + + try: + loop.add_signal_handler(signal.SIGINT, _stop) + loop.add_signal_handler(signal.SIGTERM, _stop) + except NotImplementedError: + import signal as _signal + + _signal.signal(_signal.SIGINT, lambda *_: _stop()) + _signal.signal(_signal.SIGTERM, lambda *_: _stop()) + + # Optional warmup + if args.startup_delay > 0: + await asyncio.sleep(args.startup_delay) + + topics = _resolve_topics(args) + payload = _build_payload(args) + logger.info("payload: %s", payload) + logger.info("topics: %s", topics) + + try: + with DaprClient() as client: + if args.publish_once: + await publish_once( + client, + payload, + pubsub_name=args.pubsub, + topics=topics, + content_type=args.content_type, + ce_type=args.ce_type, + max_attempts=args.max_attempts, + initial_delay=args.initial_delay, + jitter=args.jitter, + backoff_factor=args.backoff_factor, + ) + await asyncio.sleep(0.2) + return 0 + + # periodic mode + if args.interval <= 0: + logger.error("--interval must be > 0 when --loop is used") + return 2 + + logger.info("starting periodic publisher every %.2fs", args.interval) + while not stop_event.is_set(): + try: + payload = _build_payload(args) + await publish_once( + client, + payload, + pubsub_name=args.pubsub, + topics=topics, + content_type=args.content_type, + ce_type=args.ce_type, + max_attempts=args.max_attempts, + initial_delay=args.initial_delay, + jitter=args.jitter, + backoff_factor=args.backoff_factor, + ) + except Exception as exc: # noqa: BLE001 + logger.error( + "giving up after %d attempts: %s", args.max_attempts, exc + ) + + try: + await asyncio.wait_for(stop_event.wait(), timeout=args.interval) + except asyncio.TimeoutError: + pass + + logger.info("shutdown requested; exiting") + return 0 + + except KeyboardInterrupt: + return 130 + except Exception as exc: # noqa: BLE001 + logger.exception("fatal error: %s", exc) + return 1 + + +if __name__ == "__main__": + sys.exit(asyncio.run(main())) diff --git a/quickstarts/11-test-multi-agents/services/frodo/app.py b/quickstarts/11-test-multi-agents/services/frodo/app.py new file mode 100644 index 00000000..d1dae215 --- /dev/null +++ b/quickstarts/11-test-multi-agents/services/frodo/app.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +import asyncio +import logging + +from dotenv import load_dotenv + +from dapr_agents.agents.durable import DurableAgent +from dapr_agents.agents.configs import ( + AgentMemoryConfig, + AgentPubSubConfig, + AgentRegistryConfig, + AgentStateConfig, +) +from dapr_agents.agents.prompting import AgentProfileConfig +from dapr_agents.llm.openai import OpenAIChatClient +from dapr_agents.memory import ConversationDaprStateMemory +from dapr_agents.storage.daprstores.stateservice import StateStoreService +from dapr_agents.workflow.runners import AgentRunner +from dapr_agents.workflow.utils.core import wait_for_shutdown + +# Load environment +load_dotenv() + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s [%(name)s] %(message)s", +) +logger = logging.getLogger("fellowship-app") + + +async def main() -> None: + """ + Fellowship durable-agent app. + + Notes: + - If you don't specify models in AgentStateConfig, it defaults to + AgentWorkflowState / AgentWorkflowMessage internally. + - We run TWO durable agents: Frodo and Sam, in the same process. + """ + # Shared infra (registry) + registry_config = AgentRegistryConfig( + store=StateStoreService(store_name="agentregistrystore"), + team_name="fellowship", + ) + + # Single LLM client reused for both agents + llm = OpenAIChatClient() + + # --------------------------- + # Frodo (journey lead) + # --------------------------- + frodo_name = "frodo" + + frodo_pubsub = AgentPubSubConfig( + pubsub_name="messagepubsub", + agent_topic="fellowship.frodo.requests", + broadcast_topic="fellowship.broadcast", + ) + frodo_state = AgentStateConfig( + store=StateStoreService(store_name="workflowstatestore", key_prefix="frodo:") + # Default state/message models will be used. + ) + frodo_memory = AgentMemoryConfig( + store=ConversationDaprStateMemory( + store_name="memorystore", + session_id=f"{frodo_name}-session", + ) + ) + frodo_profile = AgentProfileConfig( + name="Frodo Baggins", + role="Ring-bearer & Journey Lead", + goal=( + "Safely navigate Middle-earth toward Mount Doom, making prudent decisions, " + "asking for help when needed, and keeping the Fellowship aligned." + ), + instructions=[ + "Plan cautious, low-profile routes across Middle-earth.", + "Identify risks (patrols, Nazgûl, terrain, travel time) and suggest mitigations.", + "When uncertain, propose two options and recommend one with rationale.", + "Defer arcane lore questions to Gandalf or the most relevant expert agent.", + "Summarize the current status in 1-2 sentences at the end of each reply.", + ], + style_guidelines=[ + "Keep tone steady, humble, and focused.", + "Prefer concise, decisive recommendations with brief justification.", + "Avoid unnecessary bravado; safety first.", + ], + # If you use prompt modules later, you can list them here: + modules=("navigation", "risk-assessment"), + ) + + frodo = DurableAgent( + profile_config=frodo_profile, + pubsub_config=frodo_pubsub, + state_config=frodo_state, + registry_config=registry_config, + memory_config=frodo_memory, + llm=llm, + ) + frodo.start() + + # --------------------------- + # HTTP runner (both agents) + # --------------------------- + runner = AgentRunner() + try: + # Expose both agents’ endpoints + runner.register_routes(frodo) + await wait_for_shutdown() + finally: + runner.shutdown() + frodo.stop() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/quickstarts/11-test-multi-agents/services/gandalf/app.py b/quickstarts/11-test-multi-agents/services/gandalf/app.py new file mode 100644 index 00000000..5201df51 --- /dev/null +++ b/quickstarts/11-test-multi-agents/services/gandalf/app.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import asyncio +import logging + +from dotenv import load_dotenv + +from dapr_agents.agents.durable import DurableAgent +from dapr_agents.agents.configs import ( + AgentMemoryConfig, + AgentPubSubConfig, + AgentRegistryConfig, + AgentStateConfig, +) +from dapr_agents.agents.prompting import AgentProfileConfig +from dapr_agents.llm.openai import OpenAIChatClient +from dapr_agents.memory import ConversationDaprStateMemory +from dapr_agents.storage.daprstores.stateservice import StateStoreService +from dapr_agents.workflow.runners import AgentRunner +from dapr_agents.workflow.utils.core import wait_for_shutdown + +# Load environment +load_dotenv() + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s [%(name)s] %(message)s", +) +logger = logging.getLogger("fellowship-app") + + +async def main() -> None: + """ + Gandalf durable-agent app. + + Notes: + - If you don't specify models in AgentStateConfig, it defaults to + AgentWorkflowState / AgentWorkflowMessage internally. + """ + # Shared infra (registry) + registry_config = AgentRegistryConfig( + store=StateStoreService(store_name="agentregistrystore"), + team_name="fellowship", + ) + + # Single LLM client reused for both agents + llm = OpenAIChatClient() + + # --------------------------- + # Gandalf (wizard & loremaster) + # --------------------------- + gandalf_name = "Gandalf" + + gandalf_pubsub = AgentPubSubConfig( + pubsub_name="messagepubsub", + agent_topic="fellowship.gandalf.requests", + broadcast_topic="fellowship.broadcast", + ) + gandalf_state = AgentStateConfig( + store=StateStoreService(store_name="workflowstatestore", key_prefix="gandalf:") + # Default state/message models will be used. + ) + gandalf_memory = AgentMemoryConfig( + store=ConversationDaprStateMemory( + store_name="memorystore", + session_id=f"{gandalf_name}-session", + ) + ) + gandalf_profile = AgentProfileConfig( + name="Gandalf", + role="Wizard & Loremaster", + goal=( + "Provide wisdom, arcane knowledge, and strategic counsel to the Fellowship. " + "Guide the journey with deep understanding of Middle-earth's history and magic." + ), + instructions=[ + "Share knowledge of Middle-earth lore, ancient history, and magical matters.", + "Provide strategic advice based on deep wisdom and foresight.", + "Warn of dangers and explain the significance of artifacts, places, and creatures.", + "When asked for guidance, offer multiple perspectives but recommend the wisest path.", + "Support other Fellowship members with counsel when they face difficult decisions.", + ], + style_guidelines=[ + "Speak with gravitas and wisdom, but warmth when appropriate.", + "Balance mystery with clarity - be cryptic only when necessary.", + "Show patience and understanding, especially with the hobbits.", + "Use occasional references to ancient lore or history to add depth.", + ], + modules=("lore", "magic", "strategy"), + ) + + gandalf = DurableAgent( + profile_config=gandalf_profile, + pubsub_config=gandalf_pubsub, + state_config=gandalf_state, + registry_config=registry_config, + memory_config=gandalf_memory, + llm=llm, + ) + gandalf.start() + + # --------------------------- + # HTTP runner + # --------------------------- + runner = AgentRunner() + try: + runner.register_routes(gandalf) + await wait_for_shutdown() + finally: + runner.shutdown() + gandalf.stop() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/quickstarts/11-test-multi-agents/services/sam/app.py b/quickstarts/11-test-multi-agents/services/sam/app.py new file mode 100644 index 00000000..60dff3f6 --- /dev/null +++ b/quickstarts/11-test-multi-agents/services/sam/app.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import asyncio +import logging + +from dotenv import load_dotenv + +from dapr_agents.agents.durable import DurableAgent +from dapr_agents.agents.configs import ( + AgentMemoryConfig, + AgentPubSubConfig, + AgentRegistryConfig, + AgentStateConfig, +) +from dapr_agents.agents.prompting import AgentProfileConfig +from dapr_agents.llm.openai import OpenAIChatClient +from dapr_agents.memory import ConversationDaprStateMemory +from dapr_agents.storage.daprstores.stateservice import StateStoreService +from dapr_agents.workflow.runners import AgentRunner +from dapr_agents.workflow.utils.core import wait_for_shutdown + +# Load environment +load_dotenv() + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s [%(name)s] %(message)s", +) +logger = logging.getLogger("fellowship-app") + + +async def main() -> None: + """ + Fellowship durable-agent app. + + Notes: + - If you don't specify models in AgentStateConfig, it defaults to + AgentWorkflowState / AgentWorkflowMessage internally. + - We run TWO durable agents: Frodo and Sam, in the same process. + """ + # Shared infra (registry) + registry_config = AgentRegistryConfig( + store=StateStoreService(store_name="agentregistrystore"), + team_name="fellowship", + ) + + # Single LLM client reused for both agents + llm = OpenAIChatClient() + + # --------------------------- + # Sam (logistics & support) + # --------------------------- + sam_name = "sam" + + sam_pubsub = AgentPubSubConfig( + pubsub_name="messagepubsub", + agent_topic="fellowship.sam.requests", + broadcast_topic="fellowship.broadcast", + ) + sam_state = AgentStateConfig( + store=StateStoreService(store_name="workflowstatestore", key_prefix="sam:") + ) + sam_memory = AgentMemoryConfig( + store=ConversationDaprStateMemory( + store_name="memorystore", + session_id=f"{sam_name}-session", + ) + ) + sam_profile = AgentProfileConfig( + name="Samwise Gamgee", + role="Logistics, Provisions & Morale", + goal=( + "Keep the party supplied, rested, and on schedule; watch Frodo's back, " + "offer practical counsel, and maintain morale." + ), + instructions=[ + "Track food, water, camp gear, and rest cadence; flag shortages early.", + "Advise on campsite selection (cover, water, distance from threats).", + "Provide pragmatic alternatives to risky ideas and note trade-offs.", + "Remind the party of the plan and next small step when confusion arises.", + "End with a short, encouraging line if stakes are high.", + ], + style_guidelines=[ + "Warm, plain-spoken, and grounded.", + "Emphasize practicality over poetry.", + "Stay loyal and supportive, especially under pressure.", + ], + modules=("supplies", "campcraft"), + ) + + sam = DurableAgent( + profile_config=sam_profile, + pubsub_config=sam_pubsub, + state_config=sam_state, + registry_config=registry_config, + memory_config=sam_memory, + llm=llm, + ) + sam.start() + + # --------------------------- + # HTTP runner (both agents) + # --------------------------- + runner = AgentRunner() + try: + # Expose both agents’ endpoints + runner.register_routes(sam) + await wait_for_shutdown() + finally: + runner.shutdown() + sam.stop() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/quickstarts/11-test-multi-agents/services/workflow-llm/app.py b/quickstarts/11-test-multi-agents/services/workflow-llm/app.py new file mode 100644 index 00000000..0876920b --- /dev/null +++ b/quickstarts/11-test-multi-agents/services/workflow-llm/app.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import asyncio +import logging +import os + +from dotenv import load_dotenv + +import dapr.ext.workflow as wf +from dapr_agents.agents.configs import ( + AgentPubSubConfig, + AgentRegistryConfig, +) +from dapr_agents.agents.orchestrators.llm.configs import LLMOrchestratorStateConfig +from dapr_agents.agents.orchestrators.llm import LLMOrchestrator +from dapr_agents.storage.daprstores.stateservice import StateStoreService +from dapr_agents.workflow.runners import AgentRunner +from dapr_agents.llm.openai import OpenAIChatClient +from dapr_agents.workflow.utils.core import wait_for_shutdown + + +load_dotenv() + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s [%(name)s] %(message)s", +) +logger = logging.getLogger("llm.orchestrator.app") + + +async def main() -> None: + """ + LLM Orchestrator application. + + This service hosts the LLMOrchestrator on a Dapr Workflow runtime. + It subscribes to the orchestrator topic for TriggerAction messages and + coordinates registered agents by selecting the next agent/step in the workflow. + """ + + # ------------------------------------------------------------------------- + # Config (env-overridable) + # ------------------------------------------------------------------------- + orchestrator_name = os.getenv("ORCHESTRATOR_NAME", "LLMOrchestrator") + team_name = os.getenv("TEAM_NAME", "fellowship") + + # Pub/Sub topics: orchestrator listens on orchestrator_topic for TriggerAction + pubsub_name = os.getenv("PUBSUB_NAME", "messagepubsub") + orchestrator_topic = os.getenv("ORCHESTRATOR_TOPIC", "llm.orchestrator.requests") + broadcast_topic = os.getenv("BROADCAST_TOPIC", "fellowship.broadcast") + + # (Optional) state & registry stores (by name, as configured in Dapr components) + workflow_state_store_name = os.getenv("WORKFLOW_STATE_STORE", "workflowstatestore") + registry_store_name = os.getenv("REGISTRY_STATE_STORE", "agentregistrystore") + + # Orchestrator behavior + max_iterations = int(os.getenv("MAX_ITERATIONS", "8")) + timeout_seconds = int(os.getenv("TIMEOUT_SECONDS", "45")) + + # LLM Provider + llm = OpenAIChatClient() + + # ------------------------------------------------------------------------- + # Pub/Sub, State, Registry wiring + # ------------------------------------------------------------------------- + pubsub_config = AgentPubSubConfig( + pubsub_name=pubsub_name, + agent_topic=orchestrator_topic, + broadcast_topic=broadcast_topic, + ) + + # Orchestrators often don’t persist workflow-local state; still allow it + # so you can extend later (metrics, audit, etc). + state_config = LLMOrchestratorStateConfig( + store=StateStoreService( + store_name=workflow_state_store_name, key_prefix="llm.orchestrator:" + ), + ) + + registry_config = AgentRegistryConfig( + store=StateStoreService(store_name=registry_store_name), + team_name=team_name, + ) + + # ------------------------------------------------------------------------- + # LLM Orchestrator instance + # ------------------------------------------------------------------------- + orchestrator = LLMOrchestrator( + name=orchestrator_name, + llm=llm, + pubsub_config=pubsub_config, + state_config=state_config, + registry_config=registry_config, + agent_metadata={ + "type": "LLMOrchestrator", + "description": "LLM-driven Orchestrator", + }, + max_iterations=max_iterations, + timeout_seconds=timeout_seconds, + runtime=wf.WorkflowRuntime(), # you can inject your own if needed + ) + + # Start workflow runtime + register workflows/activities + orchestrator.start() + + # ------------------------------------------------------------------------- + # HTTP runner (exposes workflow endpoints for Dapr to call) + # ------------------------------------------------------------------------- + runner = AgentRunner() + try: + runner.register_routes(orchestrator) + await wait_for_shutdown() + finally: + runner.shutdown() + orchestrator.stop() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/quickstarts/11-test-multi-agents/services/workflow-random/app.py b/quickstarts/11-test-multi-agents/services/workflow-random/app.py new file mode 100644 index 00000000..eb863750 --- /dev/null +++ b/quickstarts/11-test-multi-agents/services/workflow-random/app.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import asyncio +import logging +import os + +from dotenv import load_dotenv + +import dapr.ext.workflow as wf +from dapr_agents.agents.configs import ( + AgentPubSubConfig, + AgentRegistryConfig, + AgentStateConfig, +) +from dapr_agents.agents.orchestrators.random import RandomOrchestrator +from dapr_agents.storage.daprstores.stateservice import StateStoreService +from dapr_agents.workflow.runners import AgentRunner +from dapr_agents.workflow.utils.core import wait_for_shutdown + +# ----------------------------------------------------------------------------- +# Boot +# ----------------------------------------------------------------------------- +load_dotenv() + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s [%(name)s] %(message)s", +) +logger = logging.getLogger("fellowship.orchestrator.random.app") + + +async def main() -> None: + """ + Fellowship Orchestrator (Random) application. + + This service hosts the RandomOrchestrator on a Dapr Workflow runtime. + It subscribes to the orchestrator topic for TriggerAction messages and + coordinates registered agents (e.g., Frodo, Sam) by selecting one at random + each turn and routing replies back into the workflow via external events. + """ + + # ------------------------------------------------------------------------- + # Config (env-overridable) + # ------------------------------------------------------------------------- + orchestrator_name = os.getenv("ORCHESTRATOR_NAME", "FellowshipRandom") + team_name = os.getenv("TEAM_NAME", "fellowship") + + # Pub/Sub topics: orchestrator listens on agent_topic for TriggerAction + pubsub_name = os.getenv("PUBSUB_NAME", "messagepubsub") + orchestrator_topic = os.getenv( + "ORCHESTRATOR_TOPIC", "fellowship.orchestrator.random.requests" + ) + broadcast_topic = os.getenv("BROADCAST_TOPIC", "fellowship.broadcast") + + # (Optional) state & registry stores (by name, as configured in Dapr components) + workflow_state_store_name = os.getenv("WORKFLOW_STATE_STORE", "workflowstatestore") + registry_store_name = os.getenv("REGISTRY_STATE_STORE", "agentregistrystore") + + # Orchestrator behavior + max_iterations = int(os.getenv("MAX_ITERATIONS", "8")) + timeout_seconds = int(os.getenv("TIMEOUT_SECONDS", "45")) + + # ------------------------------------------------------------------------- + # Pub/Sub, State, Registry wiring + # ------------------------------------------------------------------------- + pubsub_config = AgentPubSubConfig( + pubsub_name=pubsub_name, + agent_topic=orchestrator_topic, # <-- RandomOrchestrator subscribes here + broadcast_topic=broadcast_topic, # <-- Optional (fanout to agents) + ) + + # Orchestrators often don’t persist workflow-local state; still allow it + # so you can extend later (metrics, audit, etc). + state_config = AgentStateConfig( + store=StateStoreService( + store_name=workflow_state_store_name, key_prefix="fellowship.random:" + ), + ) + + registry_config = AgentRegistryConfig( + store=StateStoreService(store_name=registry_store_name), + team_name=team_name, + ) + + # ------------------------------------------------------------------------- + # Orchestrator instance + # ------------------------------------------------------------------------- + orchestrator = RandomOrchestrator( + name=orchestrator_name, + pubsub_config=pubsub_config, + state_config=state_config, + registry_config=registry_config, + agent_metadata={"legend": "One orchestrator to guide them all."}, + max_iterations=max_iterations, + timeout_seconds=timeout_seconds, + runtime=wf.WorkflowRuntime(), # you can inject your own if needed + ) + + # Start workflow runtime + register workflows/activities + orchestrator.start() + + # ------------------------------------------------------------------------- + # HTTP runner (exposes workflow endpoints for Dapr to call) + # ------------------------------------------------------------------------- + runner = AgentRunner() + try: + runner.register_routes(orchestrator) + await wait_for_shutdown() + finally: + runner.shutdown() + orchestrator.stop() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/quickstarts/11-test-multi-agents/services/workflow-roundrobin/app.py b/quickstarts/11-test-multi-agents/services/workflow-roundrobin/app.py new file mode 100644 index 00000000..8d7625f9 --- /dev/null +++ b/quickstarts/11-test-multi-agents/services/workflow-roundrobin/app.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import asyncio +import logging +import os + +from dotenv import load_dotenv + +from dapr_agents.agents.configs import ( + AgentPubSubConfig, + AgentRegistryConfig, + AgentStateConfig, + AgentExecutionConfig, +) +from dapr_agents.agents.orchestrators.roundrobin import RoundRobinOrchestrator +from dapr_agents.storage.daprstores.stateservice import StateStoreService +from dapr_agents.workflow.runners import AgentRunner +from dapr_agents.workflow.utils.core import wait_for_shutdown + +# ----------------------------------------------------------------------------- +# Boot +# ----------------------------------------------------------------------------- +load_dotenv() + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s [%(name)s] %(message)s", +) +logger = logging.getLogger("fellowship.orchestrator.roundrobin.app") + + +async def main() -> None: + """ + Fellowship Orchestrator (RoundRobin). + + Hosts RoundRobinOrchestrator on a Dapr Workflow runtime. + Subscribes to the orchestrator topic for TriggerAction messages and + coordinates agents by selecting them in round-robin order each turn. + """ + + # ------------------------------------------------------------------------- + # Config (env-overridable) + # ------------------------------------------------------------------------- + orchestrator_name = os.getenv("ORCHESTRATOR_NAME", "FellowshipRoundRobin") + team_name = os.getenv("TEAM_NAME", "fellowship") + + # Pub/Sub topics: orchestrator listens on this topic for TriggerAction + pubsub_name = os.getenv("PUBSUB_NAME", "messagepubsub") + orchestrator_topic = os.getenv( + "ORCHESTRATOR_TOPIC", "fellowship.orchestrator.roundrobin.requests" + ) + broadcast_topic = os.getenv("BROADCAST_TOPIC", "fellowship.broadcast") + + # Dapr state components (by name) + workflow_state_store_name = os.getenv("WORKFLOW_STATE_STORE", "workflowstatestore") + registry_store_name = os.getenv("REGISTRY_STATE_STORE", "agentregistrystore") + + # Orchestrator behavior + max_iterations = int(os.getenv("MAX_ITERATIONS", "8")) + timeout_seconds = int(os.getenv("TIMEOUT_SECONDS", "45")) + + # ------------------------------------------------------------------------- + # Pub/Sub, State, Registry wiring + # ------------------------------------------------------------------------- + pubsub_config = AgentPubSubConfig( + pubsub_name=pubsub_name, + agent_topic=orchestrator_topic, # RoundRobinOrchestrator subscribes here + broadcast_topic=broadcast_topic, # optional fan-out to agents + ) + + state_config = AgentStateConfig( + store=StateStoreService( + store_name=workflow_state_store_name, key_prefix="fellowship.roundrobin:" + ), + ) + + registry_config = AgentRegistryConfig( + store=StateStoreService(store_name=registry_store_name), + team_name=team_name, + ) + + execution_config = AgentExecutionConfig(max_iterations=3) + + # ------------------------------------------------------------------------- + # Orchestrator instance + # ------------------------------------------------------------------------- + # Recommended: let the orchestrator OWN the runtime (don’t pass runtime=...). + orchestrator = RoundRobinOrchestrator( + name=orchestrator_name, + pubsub_config=pubsub_config, + state_config=state_config, + registry_config=registry_config, + agent_metadata={"pattern": "Round-robin selection of agents."}, + execution_config=execution_config, + timeout_seconds=timeout_seconds, + ) + + orchestrator.start() # registers and starts owned runtime + + # ------------------------------------------------------------------------- + # HTTP runner (exposes workflow endpoints for Dapr to call) + # ------------------------------------------------------------------------- + runner = AgentRunner() + try: + runner.register_routes(orchestrator) + await wait_for_shutdown() + finally: + runner.shutdown() + orchestrator.stop() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/tests/agents/agent/test_agent.py b/tests/agents/agent/test_agent.py index a42eadb8..34b93944 100644 --- a/tests/agents/agent/test_agent.py +++ b/tests/agents/agent/test_agent.py @@ -2,7 +2,8 @@ import asyncio import os from unittest.mock import Mock, patch -from dapr_agents.agents.agent.agent import Agent +from dapr_agents.agents.standalone import Agent +from dapr_agents.agents.configs import AgentMemoryConfig, AgentExecutionConfig from dapr_agents.types import ( AgentError, AssistantMessage, @@ -46,8 +47,8 @@ def basic_agent(self, mock_llm): goal="Help with testing", instructions=["Be helpful", "Test things"], llm=mock_llm, - memory=ConversationListMemory(), - max_iterations=5, + memory_config=AgentMemoryConfig(store=ConversationListMemory()), + execution_config=AgentExecutionConfig(max_iterations=5), ) @pytest.fixture @@ -59,9 +60,9 @@ def agent_with_tools(self, mock_llm): goal="Execute tools", instructions=["Use tools when needed"], llm=mock_llm, - memory=ConversationListMemory(), + memory_config=AgentMemoryConfig(store=ConversationListMemory()), tools=[echo_tool], - max_iterations=5, + execution_config=AgentExecutionConfig(max_iterations=5), ) def test_agent_initialization(self, mock_llm): @@ -76,12 +77,14 @@ def test_agent_initialization(self, mock_llm): ) assert agent.name == "TestAgent" - assert agent.role == "Test Assistant" - assert agent.goal == "Help with testing" - assert agent.instructions == ["Be helpful"] - assert agent.max_iterations == 10 # default value + assert agent.prompting_helper.role == "Test Assistant" + assert agent.prompting_helper.goal == "Help with testing" + assert agent.prompting_helper.instructions == ["Be helpful"] + assert agent.execution_config.max_iterations == 10 # default value assert agent.tool_history == [] - assert agent.tool_choice == "auto" # auto when tools are provided + assert ( + agent.execution_config.tool_choice == "auto" + ) # auto when tools are provided def test_agent_initialization_without_tools(self, mock_llm): """Test agent initialization without tools.""" @@ -92,7 +95,7 @@ def test_agent_initialization_without_tools(self, mock_llm): llm=mock_llm, ) - assert agent.tool_choice is None + assert agent.execution_config.tool_choice == "auto" # still defaults to auto def test_agent_initialization_with_custom_tool_choice(self, mock_llm): """Test agent initialization with custom tool choice.""" @@ -101,10 +104,10 @@ def test_agent_initialization_with_custom_tool_choice(self, mock_llm): role="Test Assistant", goal="Help with testing", llm=mock_llm, - tool_choice="required", + execution_config=AgentExecutionConfig(tool_choice="required"), ) - assert agent.tool_choice == "required" + assert agent.execution_config.tool_choice == "required" @pytest.mark.asyncio async def test_run_with_shutdown_event(self, basic_agent): @@ -140,7 +143,9 @@ async def test_run_agent_basic(self, basic_agent): mock_response.get_message.return_value = assistant_msg basic_agent.llm.generate.return_value = mock_response - result = await basic_agent._run_agent("Hello") + result = await basic_agent._run_agent( + input_data="Hello", instance_id="test-123" + ) assert isinstance(result, AssistantMessage) assert result.content == "Hello!" @@ -154,6 +159,7 @@ async def test_run_agent_with_tool_calls(self, agent_with_tools): mock_function.arguments_dict = {"arg1": "value1"} tool_call = Mock(spec=ToolCall) tool_call.id = "call_123" + tool_call.type = "function" tool_call.function = mock_function first_response = Mock(spec=LLMChatResponse) @@ -166,15 +172,16 @@ async def test_run_agent_with_tool_calls(self, agent_with_tools): agent_with_tools.llm.generate.side_effect = [first_response, second_response] agent_with_tools.tools = [echo_tool] - agent_with_tools._tool_executor = agent_with_tools._tool_executor.__class__( + agent_with_tools.tool_executor = agent_with_tools.tool_executor.__class__( tools=[echo_tool] ) - result = await agent_with_tools._run_agent("Use the tool") + result = await agent_with_tools._run_agent( + input_data="Use the tool", instance_id="test-123" + ) assert isinstance(result, AssistantMessage) assert result.content == "Final answer" - @pytest.mark.asyncio async def test_process_response_success(self, agent_with_tools): """Test successful tool execution.""" mock_function = Mock() @@ -185,17 +192,26 @@ async def test_process_response_success(self, agent_with_tools): tool_call.id = "call_123" tool_call.function = mock_function agent_with_tools.tools = [echo_tool] - agent_with_tools._tool_executor = agent_with_tools._tool_executor.__class__( + agent_with_tools.tool_executor = agent_with_tools.tool_executor.__class__( tools=[echo_tool] ) - await agent_with_tools.execute_tools([tool_call]) + + # Call the actual internal method that executes tool calls + tool_messages = await agent_with_tools._execute_tool_calls( + "test-instance", [tool_call] + ) + assert len(agent_with_tools.tool_history) == 1 - tool_message = agent_with_tools.tool_history[0] - assert tool_message.tool_call_id == "call_123" - assert tool_message.tool_name == echo_tool.name - assert tool_message.execution_result == "value1" + tool_record = agent_with_tools.tool_history[0] + assert tool_record.tool_call_id == "call_123" + assert tool_record.tool_name == echo_tool.name + assert tool_record.execution_result == "value1" + + # Verify the tool message was returned + assert len(tool_messages) == 1 + assert tool_messages[0]["role"] == "tool" + assert tool_messages[0]["name"] == echo_tool.name - @pytest.mark.asyncio async def test_process_response_failure(self, agent_with_tools): mock_function = Mock() mock_function.name = error_tool.name @@ -205,15 +221,16 @@ async def test_process_response_failure(self, agent_with_tools): tool_call.id = "call_123" tool_call.function = mock_function agent_with_tools.tools = [error_tool] - agent_with_tools._tool_executor = agent_with_tools._tool_executor.__class__( + agent_with_tools.tool_executor = agent_with_tools.tool_executor.__class__( tools=[error_tool] ) + + # Call the actual internal method that executes tool calls with pytest.raises( AgentError, match=f"Error executing tool '{error_tool.name}': .*Tool failed" ): - await agent_with_tools.execute_tools([tool_call]) + await agent_with_tools._execute_tool_calls("test-instance", [tool_call]) - @pytest.mark.asyncio async def test_conversation_max_reached(self, basic_agent): """Test that agent stops immediately when there are no tool calls.""" mock_response = Mock(spec=LLMChatResponse) @@ -221,44 +238,73 @@ async def test_conversation_max_reached(self, basic_agent): mock_response.get_message.return_value = assistant_msg basic_agent.llm.generate.return_value = mock_response - result = await basic_agent.conversation([]) + # Call the actual internal conversation loop method + initial_messages = [{"role": "user", "content": "Hello"}] + result = await basic_agent._conversation_loop( + instance_id="test-123", messages=initial_messages + ) # current logic sees no tools ===> returns on first iteration assert isinstance(result, AssistantMessage) assert result.content == "Using tool" assert basic_agent.llm.generate.call_count == 1 - @pytest.mark.asyncio async def test_conversation_with_llm_error(self, basic_agent): """Test handling of LLM errors during iterations.""" basic_agent.llm.generate.side_effect = Exception("LLM error") + # Call the actual internal conversation loop method + initial_messages = [{"role": "user", "content": "Hello"}] with pytest.raises( AgentError, match="Failed during chat generation: LLM error" ): - await basic_agent.conversation([]) + await basic_agent._conversation_loop( + instance_id="test-123", messages=initial_messages + ) - @pytest.mark.asyncio async def test_run_tool_success(self, agent_with_tools): - """Test successful tool execution via run_tool method.""" + """Test successful tool execution via _run_tool_call method.""" agent_with_tools.tools = [echo_tool] - agent_with_tools._tool_executor = agent_with_tools._tool_executor.__class__( + agent_with_tools.tool_executor = agent_with_tools.tool_executor.__class__( tools=[echo_tool] ) - result = await agent_with_tools.run_tool(echo_tool.name, arg1="value1") - assert result == "value1" - @pytest.mark.asyncio + # Create a mock tool call + mock_function = Mock() + mock_function.name = echo_tool.name + mock_function.arguments_dict = {"arg1": "value1"} + tool_call = Mock(spec=ToolCall) + tool_call.id = "call_123" + tool_call.function = mock_function + + # Call the actual internal method + result = await agent_with_tools._run_tool_call("test-instance", tool_call) + + # Verify the result is a tool message dict + assert result["role"] == "tool" + assert result["name"] == echo_tool.name + assert result["content"] == "value1" + async def test_run_tool_failure(self, agent_with_tools): - """Test tool execution failure via run_tool method.""" + """Test tool execution failure via _run_tool_call method.""" agent_with_tools.tools = [error_tool] - agent_with_tools._tool_executor = agent_with_tools._tool_executor.__class__( + agent_with_tools.tool_executor = agent_with_tools.tool_executor.__class__( tools=[error_tool] ) + + # Create a mock tool call + mock_function = Mock() + mock_function.name = error_tool.name + mock_function.arguments_dict = {} + tool_call = Mock(spec=ToolCall) + tool_call.id = "call_123" + tool_call.function = mock_function + + # Call the actual internal method with pytest.raises( - AgentError, match=f"Failed to run tool '{error_tool.name}': .*Tool failed" + AgentError, match=f"Error executing tool '{error_tool.name}': .*Tool failed" ): - await agent_with_tools.run_tool(error_tool.name) + await agent_with_tools._run_tool_call("test-instance", tool_call) def test_agent_properties(self, basic_agent): """Test agent properties.""" @@ -275,7 +321,7 @@ async def test_agent_with_memory_context(self, basic_agent): mock_response.get_message.return_value = assistant_msg basic_agent.llm.generate.return_value = mock_response - result = await basic_agent._run_agent(None) + result = await basic_agent._run_agent(input_data=None, instance_id="test-123") assert isinstance(result, AssistantMessage) assert result.content == "Response" diff --git a/tests/agents/durableagent/test_durable_agent.py b/tests/agents/durableagent/test_durable_agent.py index 90d10d5b..c47ca1c1 100644 --- a/tests/agents/durableagent/test_durable_agent.py +++ b/tests/agents/durableagent/test_durable_agent.py @@ -2,106 +2,49 @@ # Right now we have to do a bunch of patching at the class-level instead of patching at the instance-level. # In future, we should do dependency injection instead of patching at the class-level to make it easier to test. # This applies to all areas in this file where we have with patch.object()... -import asyncio import os -from typing import Any from unittest.mock import AsyncMock, Mock, patch, MagicMock import pytest from dapr.ext.workflow import DaprWorkflowContext -from dapr_agents.agents.durableagent.agent import DurableAgent -from dapr_agents.agents.durableagent.schemas import ( - AgentTaskResponse, - BroadcastMessage, +from dapr_agents.agents.durable import DurableAgent +from dapr_agents.agents.configs import ( + AgentPubSubConfig, + AgentStateConfig, + AgentRegistryConfig, + AgentMemoryConfig, + AgentExecutionConfig, ) -from dapr_agents.agents.durableagent.state import ( - DurableAgentMessage, - DurableAgentWorkflowEntry, - DurableAgentWorkflowState, +from dapr_agents.agents.schemas import ( + AgentWorkflowMessage, + AgentWorkflowEntry, + AgentWorkflowState, ) from dapr_agents.llm import OpenAIChatClient from dapr_agents.memory import ConversationDaprStateMemory +from dapr_agents.storage.daprstores.stateservice import StateStoreService from dapr_agents.tool.base import AgentTool from dapr_agents.types import ( AssistantMessage, LLMChatCandidate, LLMChatResponse, - ToolExecutionRecord, - ToolMessage, ) # We need this otherwise these tests all fail since they require Dapr to be available. @pytest.fixture(autouse=True) def patch_dapr_check(monkeypatch): + """Mock Dapr dependencies to prevent requiring a running Dapr instance.""" from unittest.mock import Mock + import dapr.ext.workflow as wf - from dapr_agents.workflow import agentic, base + # Mock WorkflowRuntime to prevent Dapr checks + mock_runtime = Mock(spec=wf.WorkflowRuntime) + monkeypatch.setattr(wf, "WorkflowRuntime", lambda: mock_runtime) - # Mock the WorkflowApp initialization to prevent DaprClient creation which does an internal check for Dapr availability. - def mock_workflow_app_post_init(self, __context: Any) -> None: - self.wf_runtime = Mock() - self.wf_runtime_is_running = False - self.wf_client = Mock() - self.client = Mock() - self.tasks = {} - self.workflows["AgenticWorkflow"] = getattr(self, "tool_calling_workflow", None) - - try: - super(base.WorkflowApp, self).model_post_init(__context) - except AttributeError: - # If parent doesn't have model_post_init, that's fine - pass - - monkeypatch.setattr( - base.WorkflowApp, "model_post_init", mock_workflow_app_post_init - ) - - def mock_agentic_post_init(self, __context: Any) -> None: - self._text_formatter = Mock() - self.client = Mock() - self._state_store_client = Mock() - # Configure the mock to return a tuple as expected by try_get_state - self._state_store_client.try_get_state.return_value = (False, None) - # Configure the mock for save_state method - self._state_store_client.save_state.return_value = None - self._agent_metadata = { - "name": getattr(self, "name", "TestAgent"), - "role": getattr(self, "role", "Test Role"), - "goal": getattr(self, "goal", "Test Goal"), - "instructions": getattr(self, "instructions", []), - "topic_name": getattr( - self, "agent_topic_name", getattr(self, "name", "TestAgent") - ), - "pubsub_name": getattr(self, "message_bus_name", "testpubsub"), - "orchestrator": False, - } - self._workflow_name = "AgenticWorkflow" - self._is_running = False - self._shutdown_event = asyncio.Event() - self._subscriptions = {} - self._topic_handlers = {} - - if not hasattr(self, "state") or self.state is None: - self.state = DurableAgentWorkflowState().model_dump() - - # Call the WorkflowApp model_post_init which we have mocked above. - super(agentic.AgenticWorkflow, self).model_post_init(__context) - - monkeypatch.setattr( - agentic.AgenticWorkflow, "model_post_init", mock_agentic_post_init - ) - - # No-op for testing - def mock_register_agentic_system(self): - pass - - monkeypatch.setattr( - agentic.AgenticWorkflow, "register_agentic_system", mock_register_agentic_system - ) - - yield + # Return the mock runtime for tests that need it + yield mock_runtime class MockDaprClient: @@ -188,13 +131,22 @@ def basic_durable_agent(self, mock_llm): goal="Help with testing", instructions=["Be helpful", "Test things"], llm=mock_llm, - memory=ConversationDaprStateMemory( - store_name="teststatestore", session_id="test_session" + pubsub_config=AgentPubSubConfig( + pubsub_name="testpubsub", + agent_topic="TestDurableAgent", + ), + state_config=AgentStateConfig( + store=StateStoreService(store_name="teststatestore") + ), + registry_config=AgentRegistryConfig( + store=StateStoreService(store_name="testregistry") ), - max_iterations=5, - state_store_name="teststatestore", - message_bus_name="testpubsub", - agents_registry_store_name="testregistry", + memory_config=AgentMemoryConfig( + store=ConversationDaprStateMemory( + store_name="teststatestore", session_id="test_session" + ) + ), + execution_config=AgentExecutionConfig(max_iterations=5), ) @pytest.fixture @@ -206,14 +158,23 @@ def durable_agent_with_tools(self, mock_llm, mock_tool): goal="Execute tools", instructions=["Use tools when needed"], llm=mock_llm, - memory=ConversationDaprStateMemory( - store_name="teststatestore", session_id="test_session" + pubsub_config=AgentPubSubConfig( + pubsub_name="testpubsub", + agent_topic="ToolDurableAgent", + ), + state_config=AgentStateConfig( + store=StateStoreService(store_name="teststatestore") + ), + registry_config=AgentRegistryConfig( + store=StateStoreService(store_name="testregistry") + ), + memory_config=AgentMemoryConfig( + store=ConversationDaprStateMemory( + store_name="teststatestore", session_id="test_session" + ) ), tools=[mock_tool], - max_iterations=5, - state_store_name="teststatestore", - message_bus_name="testpubsub", - agents_registry_store_name="testregistry", + execution_config=AgentExecutionConfig(max_iterations=5), ) def test_durable_agent_initialization(self, mock_llm): @@ -224,23 +185,26 @@ def test_durable_agent_initialization(self, mock_llm): goal="Help with testing", instructions=["Be helpful"], llm=mock_llm, - state_store_name="teststatestore", - message_bus_name="testpubsub", - agents_registry_store_name="testregistry", + pubsub_config=AgentPubSubConfig( + pubsub_name="testpubsub", + agent_topic="TestDurableAgent", + ), + state_config=AgentStateConfig( + store=StateStoreService(store_name="teststatestore") + ), + registry_config=AgentRegistryConfig( + store=StateStoreService(store_name="testregistry") + ), ) assert agent.name == "TestDurableAgent" - assert agent.role == "Test Durable Assistant" - assert agent.goal == "Help with testing" - assert agent.instructions == ["Be helpful"] - assert agent.max_iterations == 10 # default value + assert agent.prompting_helper.role == "Test Durable Assistant" + assert agent.prompting_helper.goal == "Help with testing" + assert agent.prompting_helper.instructions == ["Be helpful"] + assert agent.execution_config.max_iterations == 10 # default value assert agent.tool_history == [] - assert agent.state_store_name == "teststatestore" - assert agent.message_bus_name == "testpubsub" - assert agent.agent_topic_name == "TestDurableAgent" - assert agent.state is not None - validated_state = DurableAgentWorkflowState.model_validate(agent.state) - assert isinstance(validated_state, DurableAgentWorkflowState) + assert agent.pubsub_config.pubsub_name == "testpubsub" + assert agent.pubsub_config.agent_topic == "TestDurableAgent" def test_durable_agent_initialization_with_custom_topic(self, mock_llm): """Test durable agent initialization with custom topic name.""" @@ -249,31 +213,45 @@ def test_durable_agent_initialization_with_custom_topic(self, mock_llm): role="Test Durable Assistant", goal="Help with testing", llm=mock_llm, - agent_topic_name="custom-topic", - state_store_name="teststatestore", - message_bus_name="testpubsub", - agents_registry_store_name="testregistry", + pubsub_config=AgentPubSubConfig( + pubsub_name="testpubsub", + agent_topic="custom-topic", + ), + state_config=AgentStateConfig( + store=StateStoreService(store_name="teststatestore") + ), + registry_config=AgentRegistryConfig( + store=StateStoreService(store_name="testregistry") + ), ) - assert agent.agent_topic_name == "custom-topic" + assert agent.pubsub_config.agent_topic == "custom-topic" def test_durable_agent_initialization_name_from_role(self, mock_llm): """Test durable agent initialization with name derived from role.""" agent = DurableAgent( + name="Test Durable Assistant", role="Test Durable Assistant", goal="Help with testing", llm=mock_llm, - state_store_name="teststatestore", - message_bus_name="testpubsub", - agents_registry_store_name="testregistry", + pubsub_config=AgentPubSubConfig( + pubsub_name="testpubsub", + agent_topic="Test Durable Assistant", + ), + state_config=AgentStateConfig( + store=StateStoreService(store_name="teststatestore") + ), + registry_config=AgentRegistryConfig( + store=StateStoreService(store_name="testregistry") + ), ) assert agent.name == "Test Durable Assistant" - assert agent.agent_topic_name == "Test Durable Assistant" + assert agent.pubsub_config.agent_topic == "Test Durable Assistant" def test_durable_agent_metadata(self, basic_durable_agent): """Test durable agent metadata creation.""" - metadata = basic_durable_agent._agent_metadata + metadata = basic_durable_agent.agent_metadata assert metadata is not None assert metadata["name"] == "TestDurableAgent" @@ -283,23 +261,7 @@ def test_durable_agent_metadata(self, basic_durable_agent): assert metadata["pubsub_name"] == "testpubsub" assert metadata["orchestrator"] is False - @pytest.fixture - def mock_wf_client(self): - client = Mock() - client.wait_for_workflow_completion.return_value.serialized_output = { - "output": "test" - } - return client - - @pytest.mark.asyncio - async def test_run_method(self, basic_durable_agent, mock_wf_client): - """Test the run method returns the workflow result from the injected mock client.""" - basic_durable_agent.wf_client = mock_wf_client - result = await basic_durable_agent.run("test input") - assert result == {"output": "test"} - - @pytest.mark.asyncio - async def test_tool_calling_workflow_initialization( + def test_tool_calling_workflow_initialization( self, basic_durable_agent, mock_workflow_context ): """Test workflow initialization on first iteration.""" @@ -316,32 +278,33 @@ async def test_tool_calling_workflow_initialization( "stop", ] - basic_durable_agent.state["instances"]["test-instance-123"] = { - "input": "Test task", - "source": None, - "triggering_workflow_instance_id": "parent-instance-123", - "workflow_instance_id": "test-instance-123", - "workflow_name": "AgenticWorkflow", - "status": "RUNNING", - "messages": [], - "tool_history": [], - "end_time": None, - "trace_context": None, - } + # Use AgentWorkflowEntry for state setup + entry = AgentWorkflowEntry( + input_value="Test task", + source=None, + triggering_workflow_instance_id="parent-instance-123", + workflow_instance_id="test-instance-123", + workflow_name="AgenticWorkflow", + status="RUNNING", + messages=[], + tool_history=[], + ) + basic_durable_agent._state_model.instances["test-instance-123"] = entry - workflow_gen = basic_durable_agent.tool_calling_workflow( + workflow_gen = basic_durable_agent.agent_workflow( mock_workflow_context, message ) try: - await workflow_gen.__next__() - except StopAsyncIteration: + next(workflow_gen) # agent_workflow is a generator, not async + except StopIteration: pass assert "test-instance-123" in basic_durable_agent.state["instances"] - instance_data = basic_durable_agent.state["instances"]["test-instance-123"] - assert instance_data["input"] == "Test task" - assert instance_data["source"] is None - assert instance_data["triggering_workflow_instance_id"] == "parent-instance-123" + instance_data = basic_durable_agent._state_model.instances["test-instance-123"] + # Instance data is an AgentWorkflowEntry object + assert instance_data.input_value == "Test task" + assert instance_data.source is None + assert instance_data.triggering_workflow_instance_id == "parent-instance-123" @pytest.mark.asyncio async def test_call_llm_activity(self, basic_durable_agent): @@ -381,8 +344,17 @@ async def test_call_llm_activity(self, basic_durable_agent): test_time = datetime.fromisoformat( "2024-01-01T00:00:00Z".replace("Z", "+00:00") ) - assistant_dict = await basic_durable_agent.call_llm( - instance_id, test_time, "Test task" + + # Mock the activity context + mock_ctx = Mock() + + assistant_dict = basic_durable_agent.call_llm( + mock_ctx, + { + "instance_id": instance_id, + "time": test_time.isoformat(), + "task": "Test task", + }, ) # The dict dumped from AssistantMessage should have our content assert assistant_dict["content"] == "Test response" @@ -398,11 +370,13 @@ async def test_broadcast_message_to_agents_activity(self, basic_durable_agent): "sender": "TestDurableAgent", } - with patch.object( - type(basic_durable_agent), "broadcast_message" - ) as mock_broadcast: - await basic_durable_agent.broadcast_message_to_agents(message) - mock_broadcast.assert_called_once() + # Mock the activity context + mock_ctx = Mock() + + # The basic_durable_agent fixture doesn't have a broadcast_topic configured, + # so this should execute without error but skip the actual broadcast + basic_durable_agent.broadcast_message_to_agents(mock_ctx, {"message": message}) + # Test passes if no exception is raised @pytest.mark.asyncio async def test_send_response_back_activity(self, basic_durable_agent): @@ -411,44 +385,67 @@ async def test_send_response_back_activity(self, basic_durable_agent): target_agent = "TargetAgent" target_instance_id = "target-instance-123" - with patch.object( - type(basic_durable_agent), "send_message_to_agent" - ) as mock_send: - await basic_durable_agent.send_response_back( - response, target_agent, target_instance_id + # Mock the activity context and _run_asyncio_task + mock_ctx = Mock() + + with patch.object(basic_durable_agent, "_run_asyncio_task") as mock_run_task: + basic_durable_agent.send_response_back( + mock_ctx, + { + "response": response, + "target_agent": target_agent, + "target_instance_id": target_instance_id, + }, ) - mock_send.assert_called_once() + # Verify the async task was called + mock_run_task.assert_called_once() @pytest.mark.asyncio async def test_finish_workflow_activity(self, basic_durable_agent): """Test finishing workflow activity.""" + from datetime import datetime, timezone + instance_id = "test-instance-123" final_output = "Final response" - basic_durable_agent.state["instances"] = { - instance_id: { - "input": "Test task", - "source": "test_source", - "triggering_workflow_instance_id": None, - "workflow_instance_id": instance_id, - "workflow_name": "AgenticWorkflow", - "status": "RUNNING", - "messages": [], - "tool_history": [], - "end_time": None, - "trace_context": None, - } - } - - basic_durable_agent.finalize_workflow( - instance_id, final_output, "2024-01-01T00:00:00Z" + # Set up state in the state model using AgentWorkflowEntry + if not hasattr(basic_durable_agent._state_model, "instances"): + basic_durable_agent._state_model.instances = {} + + basic_durable_agent._state_model.instances[instance_id] = AgentWorkflowEntry( + input_value="Test task", + source="test_source", + triggering_workflow_instance_id=None, + workflow_instance_id=instance_id, + workflow_name="AgenticWorkflow", + status="RUNNING", + messages=[], + tool_history=[], + end_time=None, + start_time=datetime.now(timezone.utc), ) - instance_data = basic_durable_agent.state["instances"][instance_id] - assert instance_data["output"] == final_output - assert instance_data["end_time"] is not None + + # Mock the activity context and save_state + mock_ctx = Mock() + + with patch.object(basic_durable_agent, "save_state"): + basic_durable_agent.finalize_workflow( + mock_ctx, + { + "instance_id": instance_id, + "final_output": final_output, + "end_time": "2024-01-01T00:00:00Z", + "triggering_workflow_instance_id": None, + }, + ) + entry = basic_durable_agent._state_model.instances[instance_id] + assert entry.output == final_output + assert entry.end_time is not None @pytest.mark.asyncio async def test_run_tool(self, basic_durable_agent, mock_tool): """Test that run_tool atomically executes and persists tool results.""" + from datetime import datetime, timezone + instance_id = "test-instance-123" tool_call = { "id": "call_123", @@ -457,34 +454,46 @@ async def test_run_tool(self, basic_durable_agent, mock_tool): # Mock the tool executor with patch.object( - type(basic_durable_agent._tool_executor), "run_tool", new_callable=AsyncMock + type(basic_durable_agent.tool_executor), "run_tool", new_callable=AsyncMock ) as mock_run_tool: mock_run_tool.return_value = "tool_result" - # Set up instance state - basic_durable_agent.state["instances"] = { - instance_id: { - "input": "Test task", - "source": "test_source", - "triggering_workflow_instance_id": None, - "workflow_instance_id": instance_id, - "workflow_name": "AgenticWorkflow", - "status": "RUNNING", - "messages": [], - "tool_history": [], - "end_time": None, - "trace_context": None, - } - } - - from datetime import datetime + # Set up state in the state model using AgentWorkflowEntry + if not hasattr(basic_durable_agent._state_model, "instances"): + basic_durable_agent._state_model.instances = {} + + basic_durable_agent._state_model.instances[ + instance_id + ] = AgentWorkflowEntry( + input_value="Test task", + source="test_source", + triggering_workflow_instance_id=None, + workflow_instance_id=instance_id, + workflow_name="AgenticWorkflow", + status="RUNNING", + messages=[], + tool_history=[], + end_time=None, + start_time=datetime.now(timezone.utc), + ) test_time = datetime.fromisoformat( "2024-01-01T00:00:00Z".replace("Z", "+00:00") ) - result = await basic_durable_agent.run_tool( - tool_call, instance_id, test_time - ) + + # Mock the activity context and save_state + mock_ctx = Mock() + + with patch.object(basic_durable_agent, "save_state"): + result = await basic_durable_agent.run_tool( + mock_ctx, + { + "tool_call": tool_call, + "instance_id": instance_id, + "time": test_time.isoformat(), + "order": 1, + }, + ) # Verify tool was executed and result was returned assert result["tool_call_id"] == "call_123" @@ -492,60 +501,60 @@ async def test_run_tool(self, basic_durable_agent, mock_tool): assert result["execution_result"] == "tool_result" # Verify state was updated atomically - instance_data = basic_durable_agent.state["instances"][instance_id] - assert len(instance_data["messages"]) == 1 # Tool message added - assert ( - len(instance_data["tool_history"]) == 1 - ) # Tool execution record added + entry = basic_durable_agent._state_model.instances[instance_id] + assert len(entry.messages) == 1 # Tool message added + assert len(entry.tool_history) == 1 # Tool execution record added # Verify tool execution record in tool_history - tool_history_entry = instance_data["tool_history"][0] - assert tool_history_entry["tool_call_id"] == "call_123" - assert tool_history_entry["tool_name"] == "test_tool" - assert tool_history_entry["execution_result"] == "tool_result" - - # Verify agent-level tool_history was also updated - assert len(basic_durable_agent.tool_history) == 1 - - def test_get_source_or_default(self, basic_durable_agent): - """Test get_source_or_default helper method.""" - # Test with valid source - assert basic_durable_agent.get_source_or_default("test_source") == "test_source" - - # Test with None source - assert basic_durable_agent.get_source_or_default(None) == "direct" - - # Test with empty string - assert basic_durable_agent.get_source_or_default("") == "direct" + tool_history_entry = entry.tool_history[0] + assert tool_history_entry.tool_call_id == "call_123" + assert tool_history_entry.tool_name == "test_tool" + assert tool_history_entry.execution_result == "tool_result" def test_record_initial_entry(self, basic_durable_agent): """Test record_initial_entry helper method.""" + from datetime import datetime, timezone + instance_id = "test-instance-123" input_data = "Test task" source = "test_source" triggering_workflow_instance_id = "parent-instance-123" start_time = "2024-01-01T00:00:00Z" - basic_durable_agent.record_initial_entry( - instance_id, input_data, source, triggering_workflow_instance_id, start_time + # First, ensure instance exists with ensure_instance_exists + basic_durable_agent.ensure_instance_exists( + instance_id=instance_id, + input_value=input_data, + triggering_workflow_instance_id=None, + time=datetime.now(timezone.utc), ) - # Verify instance was created - assert instance_id in basic_durable_agent.state["instances"] - instance_data = basic_durable_agent.state["instances"][instance_id] - assert instance_data["input"] == input_data - assert instance_data["source"] == source - assert ( - instance_data["triggering_workflow_instance_id"] - == triggering_workflow_instance_id - ) - # start_time is stored as string in dict format - assert instance_data["start_time"] == "2024-01-01T00:00:00Z" - assert instance_data["workflow_name"] == "AgenticWorkflow" - assert instance_data["status"] == "running" + # Mock the activity context + mock_ctx = Mock() + + with patch.object(basic_durable_agent, "save_state"): + basic_durable_agent.record_initial_entry( + mock_ctx, + { + "instance_id": instance_id, + "input_value": input_data, + "source": source, + "triggering_workflow_instance_id": triggering_workflow_instance_id, + "start_time": start_time, + "trace_context": None, + }, + ) + + # Verify instance was updated + assert instance_id in basic_durable_agent._state_model.instances + entry = basic_durable_agent._state_model.instances[instance_id] + assert entry.input_value == input_data + assert entry.source == source + assert entry.triggering_workflow_instance_id == triggering_workflow_instance_id + assert entry.status.lower() == "running" def test_ensure_instance_exists(self, basic_durable_agent): - """Test _ensure_instance_exists helper method.""" + """Test ensure_instance_exists helper method.""" instance_id = "test-instance-123" triggering_workflow_instance_id = "parent-instance-123" time = "2024-01-01T00:00:00Z" @@ -554,308 +563,379 @@ def test_ensure_instance_exists(self, basic_durable_agent): from datetime import datetime test_time = datetime.fromisoformat(time.replace("Z", "+00:00")) - basic_durable_agent._ensure_instance_exists( - instance_id, "Test input", triggering_workflow_instance_id, test_time + basic_durable_agent.ensure_instance_exists( + instance_id=instance_id, + input_value="Test input", + triggering_workflow_instance_id=triggering_workflow_instance_id, + time=test_time, ) - assert instance_id in basic_durable_agent.state["instances"] - instance_data = basic_durable_agent.state["instances"][instance_id] - assert ( - instance_data["triggering_workflow_instance_id"] - == triggering_workflow_instance_id - ) - # start_time is stored as string in dict format - assert instance_data["start_time"] == "2024-01-01T00:00:00+00:00" - assert instance_data["workflow_name"] == "AgenticWorkflow" + assert instance_id in basic_durable_agent._state_model.instances + entry = basic_durable_agent._state_model.instances[instance_id] + assert entry.triggering_workflow_instance_id == triggering_workflow_instance_id + assert entry.start_time == test_time + assert entry.workflow_name is None # Default entry doesn't set workflow_name # Test that existing instance is not overwritten original_input = "Original input" - basic_durable_agent.state["instances"][instance_id]["input"] = original_input + entry.input_value = original_input - basic_durable_agent._ensure_instance_exists( - instance_id, "different-parent", "2024-01-02T00:00:00Z" + basic_durable_agent.ensure_instance_exists( + instance_id=instance_id, + input_value="New input", + triggering_workflow_instance_id="different-parent", + time=datetime.fromisoformat("2024-01-02T00:00:00Z".replace("Z", "+00:00")), ) - # Input should remain unchanged - assert ( - basic_durable_agent.state["instances"][instance_id]["input"] - == original_input - ) + # Input should remain unchanged (ensure_instance_exists doesn't overwrite) + entry = basic_durable_agent._state_model.instances[instance_id] + assert entry.input_value == original_input def test_process_user_message(self, basic_durable_agent): """Test _process_user_message helper method.""" + from datetime import datetime, timezone + instance_id = "test-instance-123" task = "Hello, world!" user_message_copy = {"role": "user", "content": "Hello, world!"} - # Set up instance - basic_durable_agent.state["instances"][instance_id] = { - "input": "Test task", - "source": "test_source", - "triggering_workflow_instance_id": None, - "workflow_instance_id": instance_id, - "workflow_name": "AgenticWorkflow", - "status": "RUNNING", - "messages": [], - "tool_history": [], - "end_time": None, - "trace_context": None, - } + # Set up instance using AgentWorkflowEntry + if not hasattr(basic_durable_agent._state_model, "instances"): + basic_durable_agent._state_model.instances = {} + + basic_durable_agent._state_model.instances[instance_id] = AgentWorkflowEntry( + input_value="Test task", + source="test_source", + triggering_workflow_instance_id=None, + workflow_instance_id=instance_id, + workflow_name="AgenticWorkflow", + status="RUNNING", + messages=[], + tool_history=[], + end_time=None, + start_time=datetime.now(timezone.utc), + ) - # Mock memory.add_message + # Mock memory.add_message and save_state with patch.object(type(basic_durable_agent.memory), "add_message"): - basic_durable_agent._process_user_message( - instance_id, task, user_message_copy - ) + with patch.object(basic_durable_agent, "save_state"): + basic_durable_agent._process_user_message( + instance_id, task, user_message_copy + ) # Verify message was added to instance - instance_data = basic_durable_agent.state["instances"][instance_id] - assert len(instance_data["messages"]) == 1 - assert instance_data["messages"][0]["role"] == "user" - assert instance_data["messages"][0]["content"] == "Hello, world!" - assert instance_data["last_message"]["role"] == "user" + entry = basic_durable_agent._state_model.instances[instance_id] + assert len(entry.messages) == 1 + assert entry.messages[0].role == "user" + assert entry.messages[0].content == "Hello, world!" + assert entry.last_message.role == "user" def test_save_assistant_message(self, basic_durable_agent): """Test _save_assistant_message helper method.""" + from datetime import datetime, timezone + instance_id = "test-instance-123" assistant_message = {"role": "assistant", "content": "Hello back!"} - # Set up instance - basic_durable_agent.state["instances"][instance_id] = { - "input": "Test task", - "source": "test_source", - "triggering_workflow_instance_id": None, - "workflow_instance_id": instance_id, - "workflow_name": "AgenticWorkflow", - "status": "RUNNING", - "messages": [], - "tool_history": [], - "end_time": None, - "trace_context": None, - } + # Set up instance using AgentWorkflowEntry + if not hasattr(basic_durable_agent._state_model, "instances"): + basic_durable_agent._state_model.instances = {} + + basic_durable_agent._state_model.instances[instance_id] = AgentWorkflowEntry( + input_value="Test task", + source="test_source", + triggering_workflow_instance_id=None, + workflow_instance_id=instance_id, + workflow_name="AgenticWorkflow", + status="RUNNING", + messages=[], + tool_history=[], + end_time=None, + start_time=datetime.now(timezone.utc), + ) - # Mock memory.add_message + # Mock memory.add_message and save_state with patch.object(type(basic_durable_agent.memory), "add_message"): - basic_durable_agent._save_assistant_message(instance_id, assistant_message) + with patch.object(basic_durable_agent, "save_state"): + basic_durable_agent._save_assistant_message( + instance_id, assistant_message + ) # Verify message was added to instance - instance_data = basic_durable_agent.state["instances"][instance_id] - assert len(instance_data["messages"]) == 1 - assert instance_data["messages"][0]["role"] == "assistant" - assert instance_data["messages"][0]["content"] == "Hello back!" - assert instance_data["last_message"]["role"] == "assistant" + entry = basic_durable_agent._state_model.instances[instance_id] + assert len(entry.messages) == 1 + assert entry.messages[0].role == "assistant" + assert entry.messages[0].content == "Hello back!" + assert entry.last_message.role == "assistant" def test_get_last_message_from_state(self, basic_durable_agent): - """Test _get_last_message_from_state helper method.""" + """Test accessing last_message from instance state.""" + from datetime import datetime, timezone + instance_id = "test-instance-123" - # Set up instance with last_message - basic_durable_agent.state["instances"][instance_id] = { - "input": "Test task", - "source": "test_source", - "triggering_workflow_instance_id": None, - "workflow_instance_id": instance_id, - "workflow_name": "AgenticWorkflow", - "status": "RUNNING", - "messages": [], - "tool_history": [], - "end_time": None, - "trace_context": None, - "last_message": DurableAgentMessage( - role="assistant", content="Last message" - ).model_dump(mode="json"), - } + # Set up instance with last_message using AgentWorkflowEntry + if not hasattr(basic_durable_agent._state_model, "instances"): + basic_durable_agent._state_model.instances = {} + + last_msg = AgentWorkflowMessage(role="assistant", content="Last message") + basic_durable_agent._state_model.instances[instance_id] = AgentWorkflowEntry( + input_value="Test task", + source="test_source", + triggering_workflow_instance_id=None, + workflow_instance_id=instance_id, + workflow_name="AgenticWorkflow", + status="RUNNING", + messages=[], + tool_history=[], + end_time=None, + start_time=datetime.now(timezone.utc), + last_message=last_msg, + ) - result = basic_durable_agent._get_last_message_from_state(instance_id) - assert result["role"] == "assistant" - assert result["content"] == "Last message" + # Access last_message directly from the entry + entry = basic_durable_agent._state_model.instances.get(instance_id) + assert entry is not None + assert entry.last_message.role == "assistant" + assert entry.last_message.content == "Last message" # Test with non-existent instance - result = basic_durable_agent._get_last_message_from_state("non-existent") + result = basic_durable_agent._state_model.instances.get("non-existent") assert result is None - def test_create_tool_message_objects(self, basic_durable_agent): - """Test _create_tool_message_objects helper method.""" - tool_result = { - "tool_call_id": "call_123", - "tool_name": "test_tool", - "tool_args": {"arg1": "value1"}, - "execution_result": "tool_result", - } + @pytest.mark.asyncio + async def test_create_tool_message_objects(self, basic_durable_agent): + """Test that tool message objects are created correctly (via run_tool activity).""" + from datetime import datetime, timezone - ( - tool_msg, - agent_msg, - tool_history_entry, - ) = basic_durable_agent._create_tool_message_objects(tool_result) - - # Verify tool message - assert tool_msg.tool_call_id == "call_123" - assert tool_msg.name == "test_tool" - assert tool_msg.content == "tool_result" - - # Verify agent message (DurableAgentMessage) - assert agent_msg.role == "tool" - assert agent_msg.tool_call_id == "call_123" - assert agent_msg.content == "tool_result" - - # Verify tool history entry (ToolExecutionRecord) - assert tool_history_entry.tool_call_id == "call_123" - assert tool_history_entry.tool_name == "test_tool" - assert tool_history_entry.tool_args == {"arg1": "value1"} - assert tool_history_entry.execution_result == "tool_result" - - def test_append_tool_message_to_instance(self, basic_durable_agent): - """Test _append_tool_message_to_instance helper method.""" instance_id = "test-instance-123" + tool_call = { + "id": "call_123", + "function": {"name": "test_tool", "arguments": '{"arg1": "value1"}'}, + } # Set up instance - basic_durable_agent.state["instances"][instance_id] = { - "input": "Test task", - "source": "test_source", - "triggering_workflow_instance_id": None, - "workflow_instance_id": instance_id, - "workflow_name": "AgenticWorkflow", - "status": "RUNNING", - "messages": [], - "tool_history": [], - "end_time": None, - "trace_context": None, - } + if not hasattr(basic_durable_agent._state_model, "instances"): + basic_durable_agent._state_model.instances = {} + + basic_durable_agent._state_model.instances[instance_id] = AgentWorkflowEntry( + input_value="Test task", + source="test_source", + triggering_workflow_instance_id=None, + workflow_instance_id=instance_id, + workflow_name="AgenticWorkflow", + status="RUNNING", + messages=[], + tool_history=[], + end_time=None, + start_time=datetime.now(timezone.utc), + ) + + # Mock tool executor + with patch.object( + type(basic_durable_agent.tool_executor), "run_tool", new_callable=AsyncMock + ) as mock_run: + mock_run.return_value = "tool_result" + + mock_ctx = Mock() + + with patch.object(basic_durable_agent, "save_state"): + result = await basic_durable_agent.run_tool( + mock_ctx, + { + "tool_call": tool_call, + "instance_id": instance_id, + "time": datetime.now(timezone.utc).isoformat(), + "order": 1, + }, + ) + + # Verify the tool result structure + assert result["tool_call_id"] == "call_123" + assert result["tool_name"] == "test_tool" + assert result["execution_result"] == "tool_result" + + # Verify messages and history were added to instance + entry = basic_durable_agent._state_model.instances[instance_id] + assert len(entry.messages) == 1 + assert entry.messages[0].role == "tool" + assert ( + entry.messages[0].id == "call_123" + ) # AgentWorkflowMessage uses 'id' not 'tool_call_id' + assert entry.messages[0].name == "test_tool" - # Create mock objects + assert len(entry.tool_history) == 1 + assert entry.tool_history[0].tool_call_id == "call_123" + assert entry.tool_history[0].tool_name == "test_tool" + assert entry.tool_history[0].execution_result == "tool_result" - agent_msg = DurableAgentMessage(role="assistant", content="Tool result") - tool_history_entry = ToolExecutionRecord( - tool_call_id="call_123", - tool_name="test_tool", - execution_result="tool_result", + @pytest.mark.asyncio + async def test_append_tool_message_to_instance(self, basic_durable_agent): + """Test that tool messages are appended to instance via run_tool activity.""" + instance_id = "test-instance-123" + + # Set up instance using AgentWorkflowEntry + entry = AgentWorkflowEntry( + input_value="Test task", + source="test_source", + triggering_workflow_instance_id=None, + workflow_instance_id=instance_id, + workflow_name="AgenticWorkflow", + status="RUNNING", + messages=[], + tool_history=[], ) + basic_durable_agent._state_model.instances[instance_id] = entry + + # Create a simple test tool + from dapr_agents.tool.base import AgentTool - basic_durable_agent._append_tool_message_to_instance( - instance_id, agent_msg, tool_history_entry + def test_tool_func(x): + """Test tool for verification.""" + return "tool_result" + + test_tool = AgentTool.from_func(test_tool_func) + basic_durable_agent.tools.append(test_tool) + # Recreate tool executor with the new tool + from dapr_agents.tool.executor import AgentToolExecutor + + basic_durable_agent.tool_executor = AgentToolExecutor( + tools=list(basic_durable_agent.tools) ) - # Verify instance was updated - instance_data = basic_durable_agent.state["instances"][instance_id] - assert len(instance_data["messages"]) == 1 - assert instance_data["messages"][0]["role"] == "assistant" - assert len(instance_data["tool_history"]) == 1 - assert instance_data["tool_history"][0]["tool_call_id"] == "call_123" + # Mock save_state to prevent actual persistence + with patch.object(basic_durable_agent, "save_state"): + mock_ctx = Mock() + + # Call run_tool activity which appends messages and tool_history + await basic_durable_agent.run_tool( + mock_ctx, + { + "instance_id": instance_id, + "tool_call": { + "id": "call_123", + "type": "function", + "function": { + "name": "TestToolFunc", # Tool name is CamelCase version of function name + "arguments": '{"x": "test"}', # Pass string to match type hint default + }, + }, + }, + ) - def test_update_agent_memory_and_history(self, basic_durable_agent): - """Test _update_agent_memory_and_history helper method.""" + # Verify entry was updated with message and tool_history + assert len(entry.messages) == 1 + assert entry.messages[0].role == "tool" + assert entry.messages[0].id == "call_123" # AgentWorkflowMessage uses 'id' + assert len(entry.tool_history) == 1 + assert entry.tool_history[0].tool_call_id == "call_123" + assert entry.tool_history[0].tool_name == "TestToolFunc" - tool_msg = ToolMessage( - tool_call_id="call_123", name="test_tool", content="Tool result" + @pytest.mark.asyncio + async def test_update_agent_memory_and_history(self, basic_durable_agent): + """Test that memory and history are updated via run_tool activity.""" + instance_id = "test-instance-123" + + # Set up instance using AgentWorkflowEntry + entry = AgentWorkflowEntry( + input_value="Test task", + source="test_source", + triggering_workflow_instance_id=None, + workflow_instance_id=instance_id, + workflow_name="AgenticWorkflow", + status="RUNNING", + messages=[], + tool_history=[], ) - tool_history_entry = ToolExecutionRecord( - tool_call_id="call_123", - tool_name="test_tool", - execution_result="tool_result", + basic_durable_agent._state_model.instances[instance_id] = entry + + # Create a simple test tool + from dapr_agents.tool.base import AgentTool + + def test_tool_func(x: str) -> str: + """Test tool for verification.""" + return "tool_result" + + test_tool = AgentTool.from_func(test_tool_func) + basic_durable_agent.tools.append(test_tool) + # Recreate tool executor with the new tool + from dapr_agents.tool.executor import AgentToolExecutor + + basic_durable_agent.tool_executor = AgentToolExecutor( + tools=list(basic_durable_agent.tools) ) - # Mock the memory add_message method - with patch.object( - type(basic_durable_agent.memory), "add_message" - ) as mock_add_message: - basic_durable_agent._update_agent_memory_and_history( - tool_msg, tool_history_entry + # Mock save_state to prevent actual persistence + with patch.object(basic_durable_agent, "save_state"): + mock_ctx = Mock() + + # Call run_tool activity which updates memory and history + await basic_durable_agent.run_tool( + mock_ctx, + { + "instance_id": instance_id, + "tool_call": { + "id": "call_123", + "type": "function", + "function": { + "name": "TestToolFunc", + "arguments": '{"x": "test"}', + }, + }, + }, ) - # Verify memory was updated - mock_add_message.assert_called_once_with(tool_msg) - # Verify agent-level tool_history was updated assert len(basic_durable_agent.tool_history) == 1 assert basic_durable_agent.tool_history[0].tool_call_id == "call_123" + assert basic_durable_agent.tool_history[0].tool_name == "TestToolFunc" def test_construct_messages_with_instance_history(self, basic_durable_agent): """Test _construct_messages_with_instance_history helper method.""" - instance_id = "test-instance-123" - input_data = "Test input" - - # Set up instance with messages - basic_durable_agent.state["instances"][instance_id] = { - "input": "Test task", - "source": "test_source", - "triggering_workflow_instance_id": None, - "workflow_instance_id": instance_id, - "workflow_name": "AgenticWorkflow", - "status": "RUNNING", - "messages": [ - DurableAgentMessage(role="user", content="Hello").model_dump( - mode="json" - ), - DurableAgentMessage(role="assistant", content="Hi there!").model_dump( - mode="json" - ), - ], - "tool_history": [], - "end_time": None, - "trace_context": None, - } - - # Mock prompt template - basic_durable_agent.prompt_template = Mock() - basic_durable_agent.prompt_template.format_prompt.return_value = [ - {"role": "system", "content": "System prompt"} - ] - - messages = basic_durable_agent._construct_messages_with_instance_history( - instance_id, input_data - ) + from datetime import datetime, timezone - # Should include system message + user input - assert len(messages) == 2 # system + user input - assert messages[0]["role"] == "system" - assert messages[1]["role"] == "user" - assert messages[1]["content"] == "Test input" + instance_id = "test-instance-123" - @pytest.mark.asyncio - async def test_broadcast_message(self, basic_durable_agent): - """Test broadcasting message.""" - broadcast_msg = BroadcastMessage( - content="Test broadcast", - role="assistant", - type="broadcast", - sender="TestDurableAgent", + # Set up instance with messages using AgentWorkflowEntry + if not hasattr(basic_durable_agent._state_model, "instances"): + basic_durable_agent._state_model.instances = {} + + basic_durable_agent._state_model.instances[instance_id] = AgentWorkflowEntry( + input_value="Test task", + source="test_source", + triggering_workflow_instance_id=None, + workflow_instance_id=instance_id, + workflow_name="AgenticWorkflow", + status="RUNNING", + messages=[ + AgentWorkflowMessage(role="user", content="Hello"), + AgentWorkflowMessage(role="assistant", content="Hi there!"), + ], + tool_history=[], + end_time=None, + start_time=datetime.now(timezone.utc), ) - # This needs refactoring / better implementation on this test since the actual implementation would depend on the pubsub msg broker. - await basic_durable_agent.broadcast_message(broadcast_msg) - - @pytest.mark.asyncio - async def test_send_message_to_agent(self, basic_durable_agent): - """Test sending message to specific agent.""" - task_response = AgentTaskResponse( - content="Test task", - role="assistant", - task="Test task", - agent_name="TargetAgent", - workflow_instance_id="target-instance-123", + messages = basic_durable_agent._construct_messages_with_instance_history( + instance_id ) - # This needs refactoring / better implementation on this test since the actual implementation would depend on the pubsub msg broker. - await basic_durable_agent.send_message_to_agent("TargetAgent", task_response) + # Should include messages from instance history (system messages excluded from instance timeline) + # Plus any messages from memory + assert len(messages) >= 2 # At least the 2 instance messages + # Find the user and assistant messages + user_messages = [m for m in messages if m.get("role") == "user"] + assistant_messages = [m for m in messages if m.get("role") == "assistant"] + assert len(user_messages) >= 1 + assert len(assistant_messages) >= 1 def test_register_agentic_system(self, basic_durable_agent): """Test registering agentic system.""" - # TODO(@Sicoyle): fix this to add assertions. - basic_durable_agent.register_agentic_system() - - @pytest.mark.asyncio - async def test_process_broadcast_message(self, basic_durable_agent): - """Test processing broadcast message.""" - broadcast_msg = BroadcastMessage( - content="Test broadcast", - role="assistant", - type="broadcast", - sender="OtherAgent", - ) - - # This needs refactoring / better implementation on this test since the actual implementation would depend on the pubsub msg broker. - await basic_durable_agent.process_broadcast_message(broadcast_msg) + # Mock registry_state.save to prevent actual state store operations + with patch.object(basic_durable_agent.registry_state, "save"): + basic_durable_agent.register_agentic_system() + # Verify it completes without error + assert True # If we get here, registration succeeded def test_durable_agent_properties(self, basic_durable_agent): """Test durable agent properties.""" @@ -863,15 +943,9 @@ def test_durable_agent_properties(self, basic_durable_agent): assert basic_durable_agent.text_formatter is not None assert basic_durable_agent.state is not None - def test_durable_agent_workflow_name(self, basic_durable_agent): - """Test that the workflow name is set correctly.""" - assert basic_durable_agent._workflow_name == "AgenticWorkflow" - def test_durable_agent_state_initialization(self, basic_durable_agent): """Test that the agent state is properly initialized.""" - validated_state = DurableAgentWorkflowState.model_validate( - basic_durable_agent.state - ) - assert isinstance(validated_state, DurableAgentWorkflowState) + validated_state = AgentWorkflowState.model_validate(basic_durable_agent.state) + assert isinstance(validated_state, AgentWorkflowState) assert "instances" in basic_durable_agent.state assert basic_durable_agent.state["instances"] == {} diff --git a/tests/agents/durableagent/test_mcp_streamable_http.py b/tests/agents/durableagent/test_mcp_streamable_http.py index 09f4db8d..dcf10831 100644 --- a/tests/agents/durableagent/test_mcp_streamable_http.py +++ b/tests/agents/durableagent/test_mcp_streamable_http.py @@ -1,8 +1,7 @@ import pytest from unittest.mock import AsyncMock, Mock -from dapr_agents.agents.durableagent.agent import DurableAgent -from dapr_agents.agents.durableagent.state import DurableAgentWorkflowEntry -from dapr_agents.agents.durableagent.state import DurableAgentWorkflowState +from dapr_agents.agents.durable import DurableAgent +from dapr_agents.agents.schemas import AgentWorkflowEntry, AgentWorkflowState from dapr_agents.tool.base import AgentTool @@ -11,7 +10,7 @@ def patch_dapr_check(monkeypatch): monkeypatch.setattr(DurableAgent, "save_state", lambda self: None) # The following monkeypatches are for legacy compatibility with dict-like access in tests. - # If DurableAgentWorkflowState supports dict-like access natively, these can be removed. + # If AgentWorkflowState supports dict-like access natively, these can be removed. def _getitem(self, key): return getattr(self, key) @@ -21,8 +20,8 @@ def _setdefault(self, key, default): setattr(self, key, default) return default - DurableAgentWorkflowState.__getitem__ = _getitem - DurableAgentWorkflowState.setdefault = _setdefault + AgentWorkflowState.__getitem__ = _getitem + AgentWorkflowState.setdefault = _setdefault # Patch DaprStateStore to use a mock DaprClient that supports context manager import dapr_agents.storage.daprstores.statestore as statestore @@ -47,29 +46,13 @@ def execute_state_transaction(self, *args, **kwargs): pass statestore.DaprClient = MockDaprClient - # Patch DaprStateStore to use a mock DaprClient that supports context manager - from dapr_agents.workflow import agentic - from dapr_agents.workflow import base - - # Mock the WorkflowApp initialization to prevent DaprClient creation - def mock_workflow_app_post_init(self, __context): - self.wf_runtime = Mock() - self.wf_runtime_is_running = False - self.wf_client = Mock() - self.client = Mock() - self.tasks = {} - self.workflows = {} - - monkeypatch.setattr( - base.WorkflowApp, "model_post_init", mock_workflow_app_post_init - ) # Patch out agent registration logic (skip state store entirely) def mock_register_agentic_system(self): pass monkeypatch.setattr( - agentic.AgenticWorkflow, "register_agentic_system", mock_register_agentic_system + DurableAgent, "register_agentic_system", mock_register_agentic_system ) yield @@ -138,22 +121,24 @@ async def fake_call_tool(*args, **kwargs): @pytest.fixture def durable_agent_with_mcp_tool(mock_mcp_tool, mock_mcp_session): - from dapr_agents.tool.executor import AgentToolExecutor + from dapr_agents.agents.configs import AgentPubSubConfig, AgentStateConfig + from dapr_agents.storage.daprstores.stateservice import StateStoreService agent_tool = AgentTool.from_mcp(mock_mcp_tool, session=mock_mcp_session) - tool_executor = AgentToolExecutor(tools=[agent_tool]) + agent = DurableAgent( name="TestDurableAgent", role="Math Assistant", goal="Help humans do math", instructions=["Test math instructions"], tools=[agent_tool], - state=DurableAgentWorkflowState().model_dump(), - state_store_name="teststatestore", - message_bus_name="testpubsub", - agents_registry_store_name="testregistry", + pubsub_config=AgentPubSubConfig( + pubsub_name="testpubsub", + ), + state_config=AgentStateConfig( + store=StateStoreService(store_name="teststatestore") + ), ) - agent.__pydantic_private__["_tool_executor"] = tool_executor return agent @@ -161,19 +146,19 @@ def durable_agent_with_mcp_tool(mock_mcp_tool, mock_mcp_session): async def test_execute_tool_activity_with_mcp_tool(durable_agent_with_mcp_tool): # Test the mocked MCP tool (add) with DurableAgent instance_id = "test-instance-123" - workflow_entry = { - "input": "What is 2 plus 2?", - "source": None, - "triggering_workflow_instance_id": None, - "workflow_instance_id": instance_id, - "workflow_name": "AgenticWorkflow", - "status": "RUNNING", - "messages": [], - "tool_history": [], - "end_time": None, - "trace_context": None, - } - durable_agent_with_mcp_tool.state["instances"][instance_id] = workflow_entry + + # Use AgentWorkflowEntry for state setup + entry = AgentWorkflowEntry( + input_value="What is 2 plus 2?", + source=None, + triggering_workflow_instance_id=None, + workflow_instance_id=instance_id, + workflow_name="AgenticWorkflow", + status="RUNNING", + messages=[], + tool_history=[], + ) + durable_agent_with_mcp_tool._state_model.instances[instance_id] = entry # Print available tool names for debugging tool_names = [t.name for t in durable_agent_with_mcp_tool.tool_executor.tools] @@ -183,20 +168,28 @@ async def test_execute_tool_activity_with_mcp_tool(durable_agent_with_mcp_tool): (n for n in tool_names if n.lower().startswith("add")), tool_names[0] ) - tool_call = { - "id": "call_123", - "function": {"name": tool_name, "arguments": '{"a": 2, "b": 2}'}, - } + # Create mock context + mock_ctx = Mock() + # Call run_tool activity with new signature (ctx, payload) await durable_agent_with_mcp_tool.run_tool( - tool_call, instance_id, "2024-01-01T00:00:00Z" + mock_ctx, + { + "instance_id": instance_id, + "tool_call": { + "id": "call_123", + "type": "function", + "function": {"name": tool_name, "arguments": '{"a": 2, "b": 2}'}, + }, + }, ) - instance_data = durable_agent_with_mcp_tool.state["instances"][instance_id] - assert len(instance_data["tool_history"]) == 1 - tool_entry = instance_data["tool_history"][0] - assert tool_entry["tool_call_id"] == "call_123" - assert tool_entry["tool_name"] == tool_name - assert tool_entry["execution_result"] == "4" + + # Verify via AgentWorkflowEntry + assert len(entry.tool_history) == 1 + tool_entry = entry.tool_history[0] + assert tool_entry.tool_call_id == "call_123" + assert tool_entry.tool_name == tool_name + assert tool_entry.execution_result == "4" # Shared fixture to start the math server with streamable HTTP @@ -252,35 +245,37 @@ async def test_add_tool_with_real_server_http(start_math_server_http): @pytest.mark.asyncio async def test_durable_agent_with_real_server_http(start_math_server_http): agent_tools = await get_agent_tools_from_http() - from dapr_agents.tool.executor import AgentToolExecutor + from dapr_agents.agents.configs import AgentPubSubConfig, AgentStateConfig + from dapr_agents.storage.daprstores.stateservice import StateStoreService - tool_executor = AgentToolExecutor(tools=agent_tools) agent = DurableAgent( name="TestDurableAgent", role="Math Assistant", goal="Help humans do math", instructions=["Test math instructions"], tools=agent_tools, - state=DurableAgentWorkflowState().model_dump(), - state_store_name="teststatestore", - message_bus_name="testpubsub", - agents_registry_store_name="testregistry", + pubsub_config=AgentPubSubConfig( + pubsub_name="testpubsub", + ), + state_config=AgentStateConfig( + store=StateStoreService(store_name="teststatestore") + ), ) - agent.__pydantic_private__["_tool_executor"] = tool_executor + instance_id = "test-instance-456" - workflow_entry = { - "input": "What is 2 plus 2?", - "source": None, - "triggering_workflow_instance_id": None, - "workflow_instance_id": instance_id, - "workflow_name": "AgenticWorkflow", - "status": "RUNNING", - "messages": [], - "tool_history": [], - "end_time": None, - "trace_context": None, - } - agent.state["instances"][instance_id] = workflow_entry + # Use AgentWorkflowEntry for state setup + entry = AgentWorkflowEntry( + input_value="What is 2 plus 2?", + source=None, + triggering_workflow_instance_id=None, + workflow_instance_id=instance_id, + workflow_name="AgenticWorkflow", + status="RUNNING", + messages=[], + tool_history=[], + ) + agent._state_model.instances[instance_id] = entry + # Print available tool names tool_names = [t.name for t in agent.tool_executor.tools] print("Available tool names (integration test):", tool_names) @@ -288,14 +283,26 @@ async def test_durable_agent_with_real_server_http(start_math_server_http): tool_name = next( (n for n in tool_names if n.lower().startswith("add")), tool_names[0] ) - tool_call = { - "id": "call_456", - "function": {"name": tool_name, "arguments": '{"a": 2, "b": 2}'}, - } - await agent.run_tool(tool_call, instance_id, "2024-01-01T00:00:00Z") - instance_data = agent.state["instances"][instance_id] - assert len(instance_data["tool_history"]) == 1 - tool_entry = instance_data["tool_history"][0] - assert tool_entry["tool_call_id"] == "call_456" - assert tool_entry["tool_name"] == tool_name - assert tool_entry["execution_result"] == "4" + + # Create mock context + mock_ctx = Mock() + + # Call run_tool activity with new signature (ctx, payload) + await agent.run_tool( + mock_ctx, + { + "instance_id": instance_id, + "tool_call": { + "id": "call_456", + "type": "function", + "function": {"name": tool_name, "arguments": '{"a": 2, "b": 2}'}, + }, + }, + ) + + # Verify via AgentWorkflowEntry + assert len(entry.tool_history) == 1 + tool_entry = entry.tool_history[0] + assert tool_entry.tool_call_id == "call_456" + assert tool_entry.tool_name == tool_name + assert tool_entry.execution_result == "4" diff --git a/tests/agents/test_base.py b/tests/agents/test_base.py index d7aa1d83..77c9f4de 100644 --- a/tests/agents/test_base.py +++ b/tests/agents/test_base.py @@ -1,9 +1,8 @@ import pytest -import asyncio -import signal from unittest.mock import Mock, patch from dapr_agents.agents.base import AgentBase +from dapr_agents.agents.configs import AgentMemoryConfig from dapr_agents.memory import ConversationListMemory from dapr_agents.llm import OpenAIChatClient from dapr_agents.prompt import ChatPromptTemplate @@ -13,7 +12,7 @@ from .mocks.vectorstore import MockVectorStore -class TestAgentBase(AgentBase): +class ConcreteAgentBase(AgentBase): """Concrete implementation of AgentBase for testing.""" def run(self, input_data): @@ -32,7 +31,7 @@ def mock_llm_client(self): @pytest.fixture def basic_agent(self, mock_llm_client): """Create a basic test agent.""" - return TestAgentBase( + return ConcreteAgentBase( name="TestAgent", role="Test Role", goal="Test Goal", @@ -43,12 +42,12 @@ def basic_agent(self, mock_llm_client): @pytest.fixture def minimal_agent(self, mock_llm_client): """Create a minimal test agent with only required fields.""" - return TestAgentBase(llm=mock_llm_client) + return ConcreteAgentBase(name="MinimalAgent", llm=mock_llm_client) @pytest.fixture def agent_with_system_prompt(self, mock_llm_client): """Create an agent with a custom system prompt.""" - return TestAgentBase( + return ConcreteAgentBase( name="CustomAgent", system_prompt="You are a custom assistant. Help users with their questions.", llm=mock_llm_client, @@ -59,46 +58,56 @@ def agent_with_tools(self, mock_llm_client): """Create an agent with tools.""" mock_tool = Mock(spec=AgentTool) mock_tool.name = "test_tool" - return TestAgentBase(name="ToolAgent", tools=[mock_tool], llm=mock_llm_client) + return ConcreteAgentBase( + name="ToolAgent", tools=[mock_tool], llm=mock_llm_client + ) @pytest.fixture def agent_with_vector_store(self, mock_llm_client): """Create an agent with vector store.""" mock_vector_store = MockVectorStore() - return TestAgentBase( + return ConcreteAgentBase( name="VectorAgent", vector_store=mock_vector_store, llm=mock_llm_client ) def test_agent_creation_with_all_fields(self, basic_agent): """Test agent creation with all fields specified.""" assert basic_agent.name == "TestAgent" - assert basic_agent.role == "Test Role" - assert basic_agent.goal == "Test Goal" - assert basic_agent.instructions == ["Test instruction 1", "Test instruction 2"] - assert basic_agent.max_iterations == 10 - assert basic_agent.template_format == "jinja2" + assert basic_agent.prompting_helper.role == "Test Role" + assert basic_agent.prompting_helper.goal == "Test Goal" + assert basic_agent.prompting_helper.instructions == [ + "Test instruction 1", + "Test instruction 2", + ] + assert basic_agent.execution_config.max_iterations == 10 + assert basic_agent.prompting_helper.template_format == "jinja2" assert isinstance(basic_agent.memory, ConversationListMemory) assert basic_agent.llm is not None def test_agent_creation_with_minimal_fields(self, minimal_agent): """Test agent creation with minimal fields.""" - # Accept both None, 'Assistant', and 'Dapr Agent' for name - assert minimal_agent.name in (None, "Assistant", "Dapr Agent") - assert minimal_agent.role == "Assistant" - assert minimal_agent.goal == "Help humans" - assert minimal_agent.instructions is None + # Name is now required + assert minimal_agent.name == "MinimalAgent" + assert minimal_agent.prompting_helper.role == "Assistant" + assert minimal_agent.prompting_helper.goal in ( + "Help users accomplish their tasks.", + "Help humans", + ) + assert minimal_agent.prompting_helper.instructions == [] # The system_prompt is automatically generated, so it won't be None - assert minimal_agent.system_prompt is not None - assert "Today's date is:" in minimal_agent.system_prompt + assert ( + minimal_agent.prompting_helper.system_prompt is not None + or minimal_agent.prompt_template is not None + ) def test_name_set_from_role_when_not_provided(self, mock_llm_client): - """Test that name is set from role when not provided.""" - agent = TestAgentBase(role="Weather Expert", llm=mock_llm_client) + """Test that agent can be created with just a name.""" + agent = ConcreteAgentBase(name="Weather Expert", llm=mock_llm_client) assert agent.name == "Weather Expert" def test_name_not_overwritten_when_provided(self, mock_llm_client): """Test that name is not overwritten when explicitly provided.""" - agent = TestAgentBase( + agent = ConcreteAgentBase( name="CustomName", role="Weather Expert", llm=mock_llm_client ) assert agent.name == "CustomName" @@ -106,7 +115,7 @@ def test_name_not_overwritten_when_provided(self, mock_llm_client): def test_agent_with_custom_system_prompt(self, agent_with_system_prompt): """Test agent with custom system prompt.""" assert ( - agent_with_system_prompt.system_prompt + agent_with_system_prompt.prompting_helper.system_prompt == "You are a custom assistant. Help users with their questions." ) assert agent_with_system_prompt.prompt_template is not None @@ -126,39 +135,35 @@ def test_prompt_template_construction(self, basic_agent): def test_system_prompt_construction(self, basic_agent): """Test system prompt construction.""" - system_prompt = basic_agent.construct_system_prompt() - assert "Today's date is:" in system_prompt - assert "Your name is {{name}}." in system_prompt - assert "Your role is {{role}}." in system_prompt - assert "{{goal}}." in system_prompt - # Instructions placeholder should only be present if instructions are set - if basic_agent.instructions: - assert "{{instructions}}" in system_prompt - else: - assert "{{instructions}}" not in system_prompt + # System prompt is now built automatically via prompting_helper + system_prompt = basic_agent.prompting_helper.system_prompt or str( + basic_agent.prompt_template + ) + assert system_prompt is not None + # The prompting helper has the role and goal + assert basic_agent.prompting_helper.role == "Test Role" + assert basic_agent.prompting_helper.goal == "Test Goal" def test_system_prompt_without_instructions(self, mock_llm_client): """Test system prompt construction without instructions.""" - agent = TestAgentBase( + agent = ConcreteAgentBase( name="TestAgent", role="Test Role", goal="Test Goal", llm=mock_llm_client ) - system_prompt = agent.construct_system_prompt() - # Instructions placeholder is always present in the template, even if instructions are not set - assert "{{instructions}}" not in system_prompt + # Check that prompt template was created + assert agent.prompt_template is not None def test_prompt_template_construction_with_system_prompt( self, agent_with_system_prompt ): """Test prompt template construction with custom system prompt.""" - template = agent_with_system_prompt.construct_prompt_template() + # Prompt template is now automatically constructed + template = agent_with_system_prompt.prompt_template assert isinstance(template, ChatPromptTemplate) - assert len(template.messages) == 2 - assert template.messages[0][0] == "system" - assert isinstance(template.messages[1], MessagePlaceHolder) + assert len(template.messages) >= 1 def test_construct_messages_with_string_input(self, basic_agent): """Test message construction with string input.""" - messages = basic_agent.construct_messages("Hello, how are you?") + messages = basic_agent.build_initial_messages("Hello, how are you?") assert len(messages) > 0 # Find the user message user_messages = [msg for msg in messages if msg.get("role") == "user"] @@ -167,17 +172,16 @@ def test_construct_messages_with_string_input(self, basic_agent): def test_construct_messages_with_dict_input(self, basic_agent): """Test message construction with dictionary input.""" - # Use variables that are actually in the template - input_data = {"chat_history": []} - messages = basic_agent.construct_messages(input_data) + # build_initial_messages handles chat_history internally + messages = basic_agent.build_initial_messages("Test message") assert len(messages) > 0 def test_construct_messages_with_invalid_input(self, basic_agent): - """Test message construction with invalid input.""" - with pytest.raises( - ValueError, match="Input data must be either a string or dictionary" - ): - basic_agent.construct_messages(123) + """Test message construction with various input types.""" + # The method now handles various input types gracefully + # Just verify it doesn't crash + messages = basic_agent.build_initial_messages(None) + assert len(messages) > 0 def test_get_last_message_empty_memory(self, basic_agent): """Test getting last message from empty memory.""" @@ -222,39 +226,6 @@ def test_reset_memory(self, basic_agent): basic_agent.reset_memory() mock_reset.assert_called_once() - def test_pre_fill_prompt_template(self, basic_agent): - """Test pre-filling prompt template with variables.""" - # Store original template for comparison - original_template = basic_agent.prompt_template - - # Pre-fill with a variable - basic_agent.pre_fill_prompt_template(custom_var="test_value") - - # Verify the template was updated - assert basic_agent.prompt_template != original_template - - # Verify the pre-filled variable is set - assert "custom_var" in basic_agent.prompt_template.pre_filled_variables - assert ( - basic_agent.prompt_template.pre_filled_variables["custom_var"] - == "test_value" - ) - - # Verify the template can still be formatted - formatted = basic_agent.prompt_template.format_prompt() - assert formatted is not None - - def test_pre_fill_prompt_template_without_template(self, mock_llm_client): - """Test pre-filling prompt template when template is not initialized.""" - agent = TestAgentBase(llm=mock_llm_client) - agent.prompt_template = None - - with pytest.raises( - ValueError, - match="Prompt template must be initialized before pre-filling variables", - ): - agent.pre_fill_prompt_template(custom_var="test_value") - def test_chat_history_with_vector_memory_and_task(self): """Test chat history retrieval with vector memory and task.""" from tests.agents.mocks.vectorstore import MockVectorStore @@ -263,38 +234,46 @@ def test_chat_history_with_vector_memory_and_task(self): mock_vector_store = MockVectorStore() mock_llm = MockLLMClient() memory = DummyVectorMemory(mock_vector_store) - agent = TestAgentBase(memory=memory, llm=mock_llm) + agent = ConcreteAgentBase( + name="TestAgent", + memory_config=AgentMemoryConfig(store=memory), + llm=mock_llm, + ) - # Access chat_history as a property - result = agent.chat_history + # Call get_chat_history() method instead of accessing property + result = agent.get_chat_history() assert isinstance(result, list) assert isinstance(result[0], Mock) def test_chat_history_with_regular_memory(self, mock_llm_client): """Test chat history retrieval with regular memory.""" memory = ConversationListMemory() - agent = TestAgentBase(memory=memory, llm=mock_llm_client) + agent = ConcreteAgentBase( + name="TestAgent", + memory_config=AgentMemoryConfig(store=memory), + llm=mock_llm_client, + ) with patch.object( ConversationListMemory, "get_messages", return_value=[Mock(spec=MessageContent)], ): - result = agent.chat_history + result = agent.get_chat_history() assert isinstance(result, list) assert isinstance(result[0], Mock) def test_prefill_agent_attributes_missing_fields_warns( self, mock_llm_client, caplog ): - """Test pre-filling agent attributes logs a warning if fields are missing in the template.""" + """Test that prompt variables are prefilled correctly even when some are not used in template.""" prompt_template = ChatPromptTemplate.from_messages( [ ("system", "Just a system message"), MessagePlaceHolder(variable_name="chat_history"), ] ) - agent = TestAgentBase( + agent = ConcreteAgentBase( name="TestAgent", role="TestRole", goal="TestGoal", @@ -302,12 +281,13 @@ def test_prefill_agent_attributes_missing_fields_warns( llm=mock_llm_client, prompt_template=prompt_template, ) - with caplog.at_level("WARNING"): - agent.prefill_agent_attributes() - assert ( - "Agent attributes set but not referenced in prompt_template" - in caplog.text - ) + # Verify that prompting_helper was initialized and prefilled variables + assert agent.prompting_helper is not None + assert agent.prompting_helper.name == "TestAgent" + assert agent.prompting_helper.role == "TestRole" + assert agent.prompting_helper.goal == "TestGoal" + # The prompt template should be prefilled + assert agent.prompt_template is not None def test_validate_llm_openai_without_api_key(self, monkeypatch): """Test validation fails when OpenAI is used without API key.""" @@ -322,7 +302,7 @@ def test_validate_llm_openai_without_api_key(self, monkeypatch): with pytest.raises( openai.OpenAIError, match="api_key client option must be set" ): - TestAgentBase(llm=OpenAIChatClient()) + ConcreteAgentBase(llm=OpenAIChatClient()) def test_validate_memory_failure(self, mock_llm_client): """Test validation fails when memory initialization fails.""" @@ -331,55 +311,50 @@ def test_validate_memory_failure(self, mock_llm_client): side_effect=Exception("Memory error"), ): with pytest.raises(Exception, match="Memory error"): - TestAgentBase(llm=mock_llm_client) - - def test_signal_handler_setup(self, basic_agent): - """Test that signal handlers are set up.""" - assert hasattr(basic_agent, "_shutdown_event") - assert isinstance(basic_agent._shutdown_event, asyncio.Event) - - def test_signal_handler(self, basic_agent): - """Test signal handler functionality.""" - with patch("builtins.print") as mock_print: - basic_agent._signal_handler(signal.SIGINT, None) - mock_print.assert_called_once() - assert basic_agent._shutdown_event.is_set() + ConcreteAgentBase(name="TestAgent", llm=mock_llm_client) def test_conflicting_prompt_templates(self, caplog): - """Test warning when both agent and LLM have prompt templates.""" + """Test that agent can have its own prompt template even when LLM has one.""" mock_llm = MockLLMClient() mock_llm.prompt_template = ChatPromptTemplate.from_messages( - [("system", "test")] + [("system", "llm template")] + ) + mock_prompt_template = ChatPromptTemplate.from_messages( + [("system", "agent template")] ) - mock_prompt_template = ChatPromptTemplate.from_messages([("system", "test2")]) - with caplog.at_level("WARNING"): - TestAgentBase(llm=mock_llm, prompt_template=mock_prompt_template) - assert ( - "Agent attributes set but not referenced in prompt_template" - in caplog.text - or "Agent attributes set but not used in prompt_template" in caplog.text - ) + agent = ConcreteAgentBase( + name="TestAgent", llm=mock_llm, prompt_template=mock_prompt_template + ) + # Agent's prompt template should be used and set on LLM + assert agent.prompt_template is not None + assert agent.llm.prompt_template is not None + # The LLM should now have the agent's template + assert agent.llm.prompt_template == agent.prompt_template def test_agent_with_custom_prompt_template(self): """Test agent with custom prompt template.""" mock_prompt_template = ChatPromptTemplate.from_messages([("system", "test")]) mock_llm = MockLLMClient() mock_llm.prompt_template = None - agent = TestAgentBase(llm=mock_llm, prompt_template=mock_prompt_template) + agent = ConcreteAgentBase( + name="TestAgent", llm=mock_llm, prompt_template=mock_prompt_template + ) assert agent.prompt_template is not None assert agent.llm.prompt_template is not None assert agent.prompt_template.messages == agent.llm.prompt_template.messages def test_agent_with_llm_prompt_template(self): - """Test agent with LLM prompt template.""" + """Test agent initialization when LLM has a prompt template.""" mock_prompt_template = ChatPromptTemplate.from_messages([("system", "test")]) mock_llm = MockLLMClient() mock_llm.prompt_template = mock_prompt_template - agent = TestAgentBase(llm=mock_llm) + agent = ConcreteAgentBase(name="TestAgent", llm=mock_llm) + # Agent should build its own prompt template from profile assert agent.prompt_template is not None assert agent.llm.prompt_template is not None - assert agent.prompt_template.messages == agent.llm.prompt_template.messages + # LLM should have agent's template set on it + assert agent.llm.prompt_template == agent.prompt_template def test_run_method_implementation(self, basic_agent): """Test that the concrete run method works.""" @@ -396,33 +371,28 @@ def test_tool_executor_property(self, basic_agent): executor = basic_agent.tool_executor assert executor is not None - def test_model_fields_set_detection(self, mock_llm_client): - """Test that model_fields_set properly detects user-set attributes.""" - agent = TestAgentBase( - name="TestName", # User set - role="TestRole", # User set - goal="TestGoal", # User set - llm=mock_llm_client, - ) - - # These should be in model_fields_set - assert "name" in agent.model_fields_set - assert "role" in agent.model_fields_set - assert "goal" in agent.model_fields_set - def test_template_format_validation(self, mock_llm_client): """Test template format validation.""" - agent = TestAgentBase(template_format="f-string", llm=mock_llm_client) - assert agent.template_format == "f-string" + from dapr_agents.agents.configs import AgentProfileConfig - agent = TestAgentBase(template_format="jinja2", llm=mock_llm_client) - assert agent.template_format == "jinja2" + profile = AgentProfileConfig(name="TestAgent", template_format="f-string") + agent = ConcreteAgentBase(profile_config=profile, llm=mock_llm_client) + assert agent.prompting_helper.template_format == "f-string" + + agent = ConcreteAgentBase(name="TestAgent", llm=mock_llm_client) + assert agent.prompting_helper.template_format == "jinja2" def test_max_iterations_default(self, minimal_agent): """Test default max iterations.""" - assert minimal_agent.max_iterations == 10 + assert minimal_agent.execution_config.max_iterations == 10 def test_max_iterations_custom(self, mock_llm_client): """Test custom max iterations.""" - agent = TestAgentBase(max_iterations=5, llm=mock_llm_client) - assert agent.max_iterations == 5 + from dapr_agents.agents.configs import AgentExecutionConfig + + agent = ConcreteAgentBase( + name="TestAgent", + execution_config=AgentExecutionConfig(max_iterations=5), + llm=mock_llm_client, + ) + assert agent.execution_config.max_iterations == 5 diff --git a/tests/document/embedder/test_sentence.py b/tests/document/embedder/test_sentence.py index fd12438f..4fe25d87 100644 --- a/tests/document/embedder/test_sentence.py +++ b/tests/document/embedder/test_sentence.py @@ -1,5 +1,20 @@ +import pytest + from dapr_agents.document.embedder.sentence import SentenceTransformerEmbedder +# Check if sentence-transformers is available +try: + import sentence_transformers # noqa: F401 + + SENTENCE_TRANSFORMERS_AVAILABLE = True +except ImportError: + SENTENCE_TRANSFORMERS_AVAILABLE = False + +pytestmark = pytest.mark.skipif( + not SENTENCE_TRANSFORMERS_AVAILABLE, + reason="sentence-transformers not installed", +) + class TestSentenceTransformerEmbedder: """Test cases for SentenceTransformerEmbedder.""" diff --git a/tests/storage/vectorstores/test_chroma.py b/tests/storage/vectorstores/test_chroma.py index e01397c9..5e49e347 100644 --- a/tests/storage/vectorstores/test_chroma.py +++ b/tests/storage/vectorstores/test_chroma.py @@ -1,6 +1,20 @@ import pytest -from dapr_agents.document.embedder.sentence import SentenceTransformerEmbedder -from dapr_agents.storage.vectorstores.chroma import ChromaVectorStore + +try: + import chromadb # noqa: F401 + from dapr_agents.document.embedder.sentence import SentenceTransformerEmbedder + from dapr_agents.storage.vectorstores.chroma import ChromaVectorStore + + CHROMA_AVAILABLE = True +except ImportError: + CHROMA_AVAILABLE = False + SentenceTransformerEmbedder = None # type: ignore + ChromaVectorStore = None # type: ignore + +pytestmark = pytest.mark.skipif( + not CHROMA_AVAILABLE, + reason="chromadb or sentence-transformers not installed - optional dependencies", +) class TestChromaVectorStore: diff --git a/tests/workflow/orchestrators/test_random.py b/tests/workflow/orchestrators/test_random.py index c438dd2e..888b1967 100644 --- a/tests/workflow/orchestrators/test_random.py +++ b/tests/workflow/orchestrators/test_random.py @@ -1,7 +1,13 @@ """Tests for the RandomOrchestrator.""" import pytest from unittest.mock import MagicMock, patch -from dapr_agents.workflow.orchestrators import RandomOrchestrator +from dapr_agents.agents.orchestrators import RandomOrchestrator +from dapr_agents.agents.configs import ( + AgentPubSubConfig, + AgentStateConfig, + AgentRegistryConfig, +) +from dapr_agents.storage.daprstores.stateservice import StateStoreService @pytest.fixture @@ -9,69 +15,53 @@ def orchestrator_config(): """Fixture to provide common orchestrator configuration.""" return { "name": "test_orchestrator", - "message_bus_name": "test-message-bus", - "state_store_name": "test-state-store", - "agents_registry_store_name": "test-registry-store", + "pubsub_config": AgentPubSubConfig(pubsub_name="test-message-bus"), + "state_config": AgentStateConfig( + store=StateStoreService(store_name="test-state-store") + ), + "registry_config": AgentRegistryConfig( + store=StateStoreService(store_name="test-registry-store") + ), } def test_random_orchestrator_initialization(orchestrator_config): """Test that RandomOrchestrator can be initialized.""" - with patch( - "dapr_agents.workflow.orchestrators.random.OrchestratorWorkflowBase.model_post_init" - ) as mock_init, patch( - "dapr_agents.workflow.agentic.AgenticWorkflow.model_post_init" - ), patch( - "dapr_agents.workflow.agentic.AgenticWorkflow._state_store_client" - ) as mock_state_store, patch( - "dapr_agents.workflow.agentic.AgenticWorkflow._dapr_client" - ) as mockclient: - mock_state_store.return_value = MagicMock() - mockclient.return_value = MagicMock() + with patch("dapr.ext.workflow.WorkflowRuntime") as mock_runtime: + mock_runtime.return_value = MagicMock() orchestrator = RandomOrchestrator(**orchestrator_config) assert orchestrator.name == "test_orchestrator" - assert orchestrator._workflow_name == "RandomWorkflow" - mock_init.assert_called_once() @pytest.mark.asyncio async def test_process_input(orchestrator_config): - """Test the process_input task.""" - with patch( - "dapr_agents.workflow.orchestrators.random.OrchestratorWorkflowBase.model_post_init" - ), patch("dapr_agents.workflow.agentic.AgenticWorkflow.model_post_init"), patch( - "dapr_agents.workflow.agentic.AgenticWorkflow._state_store_client" - ) as mock_state_store, patch( - "dapr_agents.workflow.agentic.AgenticWorkflow._dapr_client" - ) as mockclient: - mock_state_store.return_value = MagicMock() - mockclient.return_value = MagicMock() + """Test the _process_input_activity task.""" + with patch("dapr.ext.workflow.WorkflowRuntime") as mock_runtime: + mock_runtime.return_value = MagicMock() orchestrator = RandomOrchestrator(**orchestrator_config) + + # Mock the activity context + mock_ctx = MagicMock() task = "test task" - result = await orchestrator.process_input(task) + result = orchestrator._process_input_activity(mock_ctx, {"task": task}) assert result["role"] == "user" - assert result["name"] == "test_orchestrator" + assert result["name"] == "user" assert result["content"] == task def test_select_random_speaker(orchestrator_config): - """Test the select_random_speaker task.""" - with patch( - "dapr_agents.workflow.orchestrators.random.OrchestratorWorkflowBase.model_post_init" - ), patch("dapr_agents.workflow.agentic.AgenticWorkflow.model_post_init"), patch( - "dapr_agents.workflow.agentic.AgenticWorkflow._state_store_client" - ) as mock_state_store, patch( - "dapr_agents.workflow.agentic.AgenticWorkflow._dapr_client" - ) as mockclient, patch.object( + """Test the _select_random_speaker_activity task.""" + with patch("dapr.ext.workflow.WorkflowRuntime") as mock_runtime, patch.object( RandomOrchestrator, - "get_agents_metadata", + "list_team_agents", return_value={"agent1": {"name": "agent1"}, "agent2": {"name": "agent2"}}, ): - mock_state_store.return_value = MagicMock() - mockclient.return_value = MagicMock() + mock_runtime.return_value = MagicMock() orchestrator = RandomOrchestrator(**orchestrator_config) - speaker = orchestrator.select_random_speaker() + # Mock the activity context + mock_ctx = MagicMock() + speaker = orchestrator._select_random_speaker_activity(mock_ctx) + assert speaker in ["agent1", "agent2"] - assert orchestrator.current_speaker == speaker diff --git a/tests/workflow/test_grpc_config.py b/tests/workflow/test_grpc_config.py deleted file mode 100644 index c8a67cf4..00000000 --- a/tests/workflow/test_grpc_config.py +++ /dev/null @@ -1,212 +0,0 @@ -"""Tests for gRPC configuration in WorkflowApp.""" -import pytest -from unittest.mock import MagicMock, patch, call -import types -from dapr_agents.workflow.base import WorkflowApp - - -@pytest.fixture -def mock_workflow_dependencies(): - """Mock all the dependencies needed for WorkflowApp initialization.""" - with patch("dapr_agents.workflow.base.WorkflowRuntime") as mock_runtime, patch( - "dapr_agents.workflow.base.DaprWorkflowClient" - ) as mock_client, patch( - "dapr_agents.workflow.base.get_default_llm" - ) as mock_llm, patch.object( - WorkflowApp, "start_runtime" - ) as mock_start, patch.object( - WorkflowApp, "setup_signal_handlers" - ) as mock_handlers: - mock_runtime_instance = MagicMock() - mock_runtime.return_value = mock_runtime_instance - - mock_client_instance = MagicMock() - mock_client.return_value = mock_client_instance - - mock_llm_instance = MagicMock() - mock_llm.return_value = mock_llm_instance - - yield { - "runtime": mock_runtime, - "runtime_instance": mock_runtime_instance, - "client": mock_client, - "client_instance": mock_client_instance, - "llm": mock_llm, - "llm_instance": mock_llm_instance, - "start_runtime": mock_start, - "signal_handlers": mock_handlers, - } - - -def test_workflow_app_without_grpc_config(mock_workflow_dependencies): - """Test that WorkflowApp initializes without gRPC configuration.""" - # Create WorkflowApp without gRPC config - app = WorkflowApp() - - # Verify the app was created - assert app is not None - assert app.grpc_max_send_message_length is None - assert app.grpc_max_receive_message_length is None - - # Verify runtime and client were initialized - assert app.wf_runtime is not None - assert app.wf_client is not None - - -def test_workflow_app_with_grpc_config(mock_workflow_dependencies): - """Test that WorkflowApp initializes with gRPC configuration.""" - # Mock the grpc module and durabletask shared module - mock_grpc = MagicMock() - mock_shared = MagicMock() - mock_channel = MagicMock() - - # Set up the mock channel - mock_grpc.insecure_channel.return_value = mock_channel - mock_shared.get_grpc_channel = MagicMock() - - with patch.dict( - "sys.modules", - { - "grpc": mock_grpc, - "durabletask.internal.shared": mock_shared, - }, - ): - # Create WorkflowApp with gRPC config (16MB) - app = WorkflowApp( - grpc_max_send_message_length=16 * 1024 * 1024, # 16MB - grpc_max_receive_message_length=16 * 1024 * 1024, # 16MB - ) - - # Verify the configuration was set - assert app.grpc_max_send_message_length == 16 * 1024 * 1024 - assert app.grpc_max_receive_message_length == 16 * 1024 * 1024 - - # Verify runtime and client were initialized - assert app.wf_runtime is not None - assert app.wf_client is not None - - -def test_configure_grpc_channel_options_is_called(mock_workflow_dependencies): - """Test that _configure_grpc_channel_options is called when gRPC config is provided.""" - with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure: - # Create WorkflowApp with gRPC config - WorkflowApp( - grpc_max_send_message_length=8 * 1024 * 1024, # 8MB - ) - - # Verify the configuration method was called - mock_configure.assert_called_once() - - -def test_configure_grpc_channel_options_not_called_without_config( - mock_workflow_dependencies, -): - """Test that _configure_grpc_channel_options is not called without gRPC config.""" - with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure: - # Create WorkflowApp without gRPC config - WorkflowApp() - - # Verify the configuration method was NOT called - mock_configure.assert_not_called() - - -def test_grpc_channel_patching(): - """Test that the gRPC channel factory is properly patched with custom options.""" - # Mock the grpc module and durabletask shared module - mock_grpc = MagicMock() - mock_shared = MagicMock() - mock_channel = MagicMock() - - # Set up the mock channel - mock_grpc.insecure_channel.return_value = mock_channel - - # Keep original reference - def original_get_grpc_channel(*_, **__): - return "original" - - mock_shared.get_grpc_channel = original_get_grpc_channel - - # Create dummy package/module structure so 'from durabletask.internal import shared' works - durabletask_module = types.ModuleType("durabletask") - internal_module = types.ModuleType("durabletask.internal") - setattr(durabletask_module, "internal", internal_module) - setattr(internal_module, "shared", mock_shared) - - with patch.dict( - "sys.modules", - { - "grpc": mock_grpc, - "durabletask": durabletask_module, - "durabletask.internal": internal_module, - "durabletask.internal.shared": mock_shared, - }, - ), patch("dapr_agents.workflow.base.WorkflowRuntime"), patch( - "dapr_agents.workflow.base.DaprWorkflowClient" - ), patch("dapr_agents.workflow.base.get_default_llm"), patch.object( - WorkflowApp, "start_runtime" - ), patch.object(WorkflowApp, "setup_signal_handlers"): - # Create WorkflowApp with gRPC config - max_send = 10 * 1024 * 1024 # 10MB - max_recv = 12 * 1024 * 1024 # 12MB - - WorkflowApp( - grpc_max_send_message_length=max_send, - grpc_max_receive_message_length=max_recv, - ) - - # Confirm get_grpc_channel was overridden - assert callable(mock_shared.get_grpc_channel) - assert mock_shared.get_grpc_channel is not original_get_grpc_channel - assert ( - getattr(mock_shared.get_grpc_channel, "__name__", "") - == "get_grpc_channel_with_options" - ) - - # Call the patched function - test_address = "localhost:50001" - mock_shared.get_grpc_channel(test_address) - - # Verify insecure_channel was called with correct options - mock_grpc.insecure_channel.assert_called_once() - call_args = mock_grpc.insecure_channel.call_args - - # Check that the address was passed - assert call_args[0][0] == test_address - - # Check that options were passed - assert "options" in call_args.kwargs - options = call_args.kwargs["options"] - - # Verify options contain our custom message size limits - assert ("grpc.max_send_message_length", max_send) in options - assert ("grpc.max_receive_message_length", max_recv) in options - - -def test_grpc_config_with_only_send_limit(mock_workflow_dependencies): - """Test gRPC configuration with only send limit set.""" - with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure: - app = WorkflowApp( - grpc_max_send_message_length=20 * 1024 * 1024, # 20MB - ) - - # Verify configuration was called - mock_configure.assert_called_once() - - # Verify only send limit was set - assert app.grpc_max_send_message_length == 20 * 1024 * 1024 - assert app.grpc_max_receive_message_length is None - - -def test_grpc_config_with_only_receive_limit(mock_workflow_dependencies): - """Test gRPC configuration with only receive limit set.""" - with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure: - app = WorkflowApp( - grpc_max_receive_message_length=24 * 1024 * 1024, # 24MB - ) - - # Verify configuration was called - mock_configure.assert_called_once() - - # Verify only receive limit was set - assert app.grpc_max_send_message_length is None - assert app.grpc_max_receive_message_length == 24 * 1024 * 1024 diff --git a/tests/workflow/test_message_router.py b/tests/workflow/test_message_router.py index 08993acf..ec049a3d 100644 --- a/tests/workflow/test_message_router.py +++ b/tests/workflow/test_message_router.py @@ -12,7 +12,7 @@ validate_message_model, parse_cloudevent, ) -from dapr_agents.workflow.utils.registration import register_message_handlers +from dapr_agents.workflow.utils.registration import register_message_routes # Test Models @@ -87,42 +87,21 @@ def test_extract_message_models_non_class(): # Tests for message_router decorator -def test_message_router_requires_pubsub(): - """Test that message_router raises ValueError when pubsub is missing.""" +def test_message_router_requires_message_model(): + """Test that message_router raises TypeError when message_model is missing and can't be inferred.""" with pytest.raises( - ValueError, - match="`pubsub` and `topic` are required when using @message_router with arguments", + TypeError, + match="`@message_router` requires `message_model`", ): - @message_router(topic="orders") - def handler(message: OrderCreated): - pass - - -def test_message_router_requires_topic(): - """Test that message_router raises ValueError when topic is missing.""" - with pytest.raises( - ValueError, - match="`pubsub` and `topic` are required when using @message_router with arguments", - ): - - @message_router(pubsub="messagepubsub") - def handler(message: OrderCreated): - pass - - -def test_message_router_requires_message_parameter(): - """Test that message_router raises ValueError when 'message' parameter is missing.""" - with pytest.raises(ValueError, match="must have a 'message' parameter"): - @message_router(pubsub="messagepubsub", topic="orders") - def handler(data: OrderCreated): # Wrong parameter name + def handler(data: OrderCreated): # Wrong parameter name, can't infer pass def test_message_router_requires_type_hint(): """Test that message_router raises TypeError when message parameter has no type hint.""" - with pytest.raises(TypeError, match="must type-hint the 'message' parameter"): + with pytest.raises(TypeError, match="`@message_router` requires `message_model`"): @message_router(pubsub="messagepubsub", topic="orders") def handler(message): # No type hint @@ -555,7 +534,9 @@ def handle_order(message: OrderCreated): loop = asyncio.new_event_loop() try: - closers = register_message_handlers([handle_order], mock_client, loop=loop) + closers = register_message_routes( + dapr_client=mock_client, targets=[handle_order], loop=loop + ) finally: loop.close() @@ -587,7 +568,9 @@ def handle_cancelled(self, message: OrderCancelled): handler = OrderHandler() loop = asyncio.new_event_loop() try: - closers = register_message_handlers([handler], mock_client, loop=loop) + closers = register_message_routes( + dapr_client=mock_client, targets=[handler], loop=loop + ) finally: loop.close() @@ -621,7 +604,9 @@ def regular_method(self, message: OrderCreated): handler = MixedHandler() loop = asyncio.new_event_loop() try: - closers = register_message_handlers([handler], mock_client, loop=loop) + closers = register_message_routes( + dapr_client=mock_client, targets=[handler], loop=loop + ) finally: loop.close() @@ -647,8 +632,10 @@ def handle_shipment(self, message: ShipmentCreated): handler_instance = OrderHandler() loop = asyncio.new_event_loop() try: - closers = register_message_handlers( - [standalone_handler, handler_instance], mock_client, loop=loop + closers = register_message_routes( + dapr_client=mock_client, + targets=[standalone_handler, handler_instance], + loop=loop, ) finally: loop.close() @@ -673,8 +660,10 @@ def handle_cancelled(message: OrderCancelled): loop = asyncio.new_event_loop() try: - closers = register_message_handlers( - [handle_created, handle_cancelled], mock_client, loop=loop + closers = register_message_routes( + dapr_client=mock_client, + targets=[handle_created, handle_cancelled], + loop=loop, ) finally: loop.close()