Skip to content

Commit 61c3dc4

Browse files
authored
feat: simplify durable agent wf def (#204)
* feat: simplify durable agent wf def Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e ruff Signed-off-by: Samantha Coyle <[email protected]> * fix: updates for tests Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e ruff again Signed-off-by: Samantha Coyle <[email protected]> * feat: update to add tests Signed-off-by: Samantha Coyle <[email protected]> * fix: updates for tests and use concrete types not dictionaries Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e ruff Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e flake* Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e ruff again Signed-off-by: Samantha Coyle <[email protected]> * fix: updates for tests Signed-off-by: Samantha Coyle <[email protected]> * style: renaming for bilgin Signed-off-by: Samantha Coyle <[email protected]> * fix: use dict instead of pydantic due to storage layer incompatibility Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e ruff Signed-off-by: Samantha Coyle <[email protected]> --------- Signed-off-by: Samantha Coyle <[email protected]>
1 parent ca43719 commit 61c3dc4

File tree

13 files changed

+911
-491
lines changed

13 files changed

+911
-491
lines changed

dapr_agents/agents/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,20 +134,20 @@ def set_name_from_role(cls, values: dict):
134134
return values
135135

136136
@model_validator(mode="after")
137-
def validate_llm(cls, values):
137+
def validate_llm(self):
138138
"""Validate that LLM is properly configured."""
139-
if hasattr(values, "llm"):
140-
if values.llm is None:
139+
if hasattr(self, "llm"):
140+
if self.llm is None:
141141
logger.warning("LLM client is None, some functionality may be limited.")
142142
else:
143143
try:
144144
# Validate LLM is properly configured by accessing it as this is required to be set.
145-
_ = values.llm
145+
_ = self.llm
146146
except Exception as e:
147147
logger.error(f"Failed to initialize LLM: {e}")
148-
values.llm = None
148+
self.llm = None
149149

150-
return values
150+
return self
151151

152152
def model_post_init(self, __context: Any) -> None:
153153
"""

dapr_agents/agents/durableagent/agent.py

Lines changed: 371 additions & 336 deletions
Large diffs are not rendered by default.

dapr_agents/agents/durableagent/state.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from pydantic import BaseModel, Field
22
from typing import List, Optional, Dict, Any
33
from dapr_agents.types import MessageContent, ToolExecutionRecord
4+
from dapr_agents.types.workflow import DaprWorkflowStatus
45
from datetime import datetime
56
import uuid
67

@@ -60,6 +61,10 @@ class DurableAgentWorkflowEntry(BaseModel):
6061
default=None,
6162
description="OpenTelemetry trace context for workflow resumption.",
6263
)
64+
status: str = Field(
65+
default=DaprWorkflowStatus.RUNNING.value,
66+
description="Current status of the workflow.",
67+
)
6368

6469

6570
class DurableAgentWorkflowState(BaseModel):

dapr_agents/observability/context_storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def create_resumed_workflow_context(
125125

126126
# Create AGENT span with proper agent name for resumed workflow
127127
agent_display_name = agent_name or "DurableAgent"
128-
span_name = f"{agent_display_name}.ToolCallingWorkflow"
128+
span_name = f"{agent_display_name}.AgenticWorkflow"
129129
with tracer.start_as_current_span(span_name) as span:
130130
# Set AGENT span attributes
131131
from .constants import OPENINFERENCE_SPAN_KIND

dapr_agents/observability/wrappers/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def _extract_workflow_name(self, args: Any, kwargs: Any) -> str:
329329
workflow_name = (
330330
workflow
331331
if isinstance(workflow, str)
332-
else getattr(workflow, "__name__", "ToolCallingWorkflow")
332+
else getattr(workflow, "__name__", "AgenticWorkflow")
333333
)
334334
return workflow_name
335335

dapr_agents/observability/wrappers/workflow_task.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,7 @@ def _build_task_attributes(
192192

193193
# Add resource-level attributes for Phoenix UI grouping
194194
attributes["resource.workflow.instance_id"] = instance_id
195-
attributes[
196-
"resource.workflow.name"
197-
] = "ToolCallingWorkflow" # Could be dynamic
195+
attributes["resource.workflow.name"] = "AgenticWorkflow" # Could be dynamic
198196

199197
# Log the trace context for debugging (expected to be disconnected for Dapr Workflows)
200198
from opentelemetry import trace
@@ -323,7 +321,7 @@ def _create_context_for_resumed_workflow(
323321
logger.debug(f"Creating AGENT span for resumed workflow {instance_id}")
324322

325323
agent_name = getattr(instance, "name", "DurableAgent")
326-
workflow_name = instance_data.get("workflow_name", "ToolCallingWorkflow")
324+
workflow_name = instance_data.get("workflow_name", "AgenticWorkflow")
327325
span_name = f"{agent_name}.{workflow_name}"
328326
attributes = {
329327
"openinference.span.kind": "AGENT",
@@ -675,15 +673,13 @@ def _categorize_workflow_task(self, task_name: str) -> str:
675673
Returns:
676674
str: Semantic category for the task type
677675
"""
678-
if task_name in ["record_initial_entry", "get_workflow_entry_info"]:
676+
if task_name in ["record_initial_entry"]:
679677
return "initialization"
680-
elif task_name in ["append_assistant_message", "append_tool_message"]:
681-
return "state_management"
682678
elif task_name in ["finalize_workflow", "finish_workflow"]:
683679
return "finalization"
684680
elif task_name in ["broadcast_message_to_agents", "send_response_back"]:
685681
return "communication"
686-
elif task_name == "generate_response":
682+
elif task_name == "call_llm":
687683
return "llm_generation"
688684
elif task_name == "run_tool":
689685
return "tool_execution"

dapr_agents/tool/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ async def executor(**kwargs: Any) -> Any:
141141
tool_args_model = create_pydantic_model_from_schema(
142142
mcp_tool.inputSchema, f"{tool_name}Args"
143143
)
144-
except Exception:
144+
except Exception as e:
145+
logger.warning(f"Failed to create schema for tool '{tool_name}': {e}")
145146
pass
146147

147148
return cls(

dapr_agents/workflow/base.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
from datetime import datetime, timezone
99
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
1010

11+
from pydantic import BaseModel
1112
from dapr.ext.workflow import (
1213
DaprWorkflowClient,
1314
WorkflowActivityContext,
1415
WorkflowRuntime,
1516
)
1617
from dapr.ext.workflow.workflow_state import WorkflowState
1718
from durabletask import task as dtask
18-
from pydantic import BaseModel, ConfigDict, Field
19-
from typing import ClassVar
19+
from pydantic import ConfigDict, Field
2020

2121
from dapr_agents.agents.base import ChatClientBase
2222
from dapr_agents.types.workflow import DaprWorkflowStatus
@@ -83,12 +83,6 @@ def model_post_init(self, __context: Any) -> None:
8383

8484
self.start_runtime()
8585

86-
# Discover and register tasks and workflows
87-
discovered_tasks = self._discover_tasks()
88-
self._register_tasks(discovered_tasks)
89-
discovered_wfs = self._discover_workflows()
90-
self._register_workflows(discovered_wfs)
91-
9286
# Set up automatic signal handlers for graceful shutdown
9387
try:
9488
self.setup_signal_handlers()
@@ -365,6 +359,10 @@ def register_tasks_from_package(self, package_name: str) -> None:
365359
def _register_tasks(self, tasks: Dict[str, Callable]) -> None:
366360
"""Register each discovered task with the Dapr runtime using direct registration."""
367361
for task_name, method in tasks.items():
362+
# Don't reregister tasks that are already registered
363+
if task_name in self.tasks:
364+
continue
365+
368366
llm = self._choose_llm_for(method)
369367
logger.debug(
370368
f"Registering task '{task_name}' with llm={getattr(llm, '__class__', None)}"
@@ -451,6 +449,10 @@ def _discover_workflows(self) -> Dict[str, Callable]:
451449
def _register_workflows(self, wfs: Dict[str, Callable]) -> None:
452450
"""Register each discovered workflow with the Dapr runtime."""
453451
for wf_name, method in wfs.items():
452+
# Don't reregister workflows that are already registered
453+
if wf_name in self.workflows:
454+
continue
455+
454456
# Use a closure helper to avoid late-binding capture issues.
455457
def make_wrapped(meth: Callable) -> Callable:
456458
@functools.wraps(meth)
@@ -538,6 +540,17 @@ def start_runtime(self):
538540
else:
539541
logger.debug("Workflow runtime already running; skipping.")
540542

543+
self._ensure_activities_registered()
544+
545+
def _ensure_activities_registered(self):
546+
"""Ensure all workflow activities are registered with the Dapr runtime."""
547+
# Discover and register tasks and workflows
548+
discovered_tasks = self._discover_tasks()
549+
self._register_tasks(discovered_tasks)
550+
discovered_wfs = self._discover_workflows()
551+
self._register_workflows(discovered_wfs)
552+
logger.debug("Workflow activities registration completed.")
553+
541554
def _sync_workflow_state_after_startup(self):
542555
"""
543556
Sync database workflow state with actual Dapr workflow status after runtime startup.
@@ -555,25 +568,17 @@ def _sync_workflow_state_after_startup(self):
555568
)
556569
return
557570

558-
logger.debug("Syncing workflow state with Dapr after runtime startup...")
559571
self.load_state()
560-
561-
# Check if we have instances to sync
562-
instances = (
563-
getattr(self.state, "instances", {})
564-
if hasattr(self.state, "instances")
565-
else self.state.get("instances", {})
566-
)
567-
if not instances:
568-
return
572+
instances = self.state.get("instances", {})
569573

570574
logger.debug(f"Found {len(instances)} workflow instances to sync")
571575

572576
# Sync each instance with Dapr's actual status
573577
for instance_id, instance_data in instances.items():
574578
try:
575579
# Skip if already completed
576-
if instance_data.get("end_time") is not None:
580+
end_time = instance_data.get("end_time")
581+
if end_time is not None:
577582
continue
578583

579584
# Get actual status from Dapr
@@ -595,6 +600,7 @@ def _sync_workflow_state_after_startup(self):
595600
timezone.utc
596601
).isoformat()
597602
instance_data["status"] = runtime_status.lower()
603+
598604
logger.debug(
599605
f"Marked workflow {instance_id} as {runtime_status.lower()} in database"
600606
)
@@ -615,6 +621,7 @@ def _sync_workflow_state_after_startup(self):
615621
timezone.utc
616622
).isoformat()
617623
instance_data["status"] = DaprWorkflowStatus.COMPLETED.value
624+
618625
logger.debug(
619626
f"Workflow {instance_id} no longer in Dapr, marked as completed"
620627
)
@@ -819,9 +826,7 @@ async def _create_agent_span_for_resumed_workflow(
819826
# Get tracer and create AGENT span as child of the original trace
820827
tracer = trace.get_tracer(__name__)
821828
agent_name = getattr(self, "name", "DurableAgent")
822-
workflow_name = instance_data.get(
823-
"workflow_name", "ToolCallingWorkflow"
824-
)
829+
workflow_name = instance_data.get("workflow_name", "AgenticWorkflow")
825830
span_name = f"{agent_name}.{workflow_name}"
826831

827832
# Create the AGENT span that will show up in the trace

dapr_agents/workflow/mixins/state.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def initialize_state(self) -> None:
3737

3838
if not isinstance(self.state, dict):
3939
raise TypeError(
40-
f"Invalid state type: {type(self.state)}. Expected dict or Pydantic model."
40+
f"Invalid state type: {type(self.state)}. Expected dict."
4141
)
4242

4343
logger.debug(f"Workflow state initialized with {len(self.state)} key(s).")

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dependencies = [
3535
"cloudevents>=1.11.0,<2.0.0",
3636
"numpy>=2.2.2,<3.0.0",
3737
"mcp>=1.7.1,<2.0.0",
38+
"websockets>=15.0.0,<16.0.0",
3839
"python-dotenv>=1.1.1,<2.0.0",
3940
"posthog<6.0.0",
4041
"nltk>=3.8.0,<4.0.0",

0 commit comments

Comments
 (0)