88from datetime import datetime , timezone
99from typing import Any , Callable , Dict , List , Optional , TypeVar , Union
1010
11+ from pydantic import BaseModel
1112from dapr .ext .workflow import (
1213 DaprWorkflowClient ,
1314 WorkflowActivityContext ,
1415 WorkflowRuntime ,
1516)
1617from dapr .ext .workflow .workflow_state import WorkflowState
1718from durabletask import task as dtask
18- from pydantic import BaseModel , ConfigDict , Field
19- from typing import ClassVar
19+ from pydantic import ConfigDict , Field
2020
2121from dapr_agents .agents .base import ChatClientBase
2222from dapr_agents .types .workflow import DaprWorkflowStatus
@@ -83,12 +83,6 @@ def model_post_init(self, __context: Any) -> None:
8383
8484 self .start_runtime ()
8585
86- # Discover and register tasks and workflows
87- discovered_tasks = self ._discover_tasks ()
88- self ._register_tasks (discovered_tasks )
89- discovered_wfs = self ._discover_workflows ()
90- self ._register_workflows (discovered_wfs )
91-
9286 # Set up automatic signal handlers for graceful shutdown
9387 try :
9488 self .setup_signal_handlers ()
@@ -365,6 +359,10 @@ def register_tasks_from_package(self, package_name: str) -> None:
365359 def _register_tasks (self , tasks : Dict [str , Callable ]) -> None :
366360 """Register each discovered task with the Dapr runtime using direct registration."""
367361 for task_name , method in tasks .items ():
362+ # Don't reregister tasks that are already registered
363+ if task_name in self .tasks :
364+ continue
365+
368366 llm = self ._choose_llm_for (method )
369367 logger .debug (
370368 f"Registering task '{ task_name } ' with llm={ getattr (llm , '__class__' , None )} "
@@ -451,6 +449,10 @@ def _discover_workflows(self) -> Dict[str, Callable]:
451449 def _register_workflows (self , wfs : Dict [str , Callable ]) -> None :
452450 """Register each discovered workflow with the Dapr runtime."""
453451 for wf_name , method in wfs .items ():
452+ # Don't reregister workflows that are already registered
453+ if wf_name in self .workflows :
454+ continue
455+
454456 # Use a closure helper to avoid late-binding capture issues.
455457 def make_wrapped (meth : Callable ) -> Callable :
456458 @functools .wraps (meth )
@@ -538,6 +540,17 @@ def start_runtime(self):
538540 else :
539541 logger .debug ("Workflow runtime already running; skipping." )
540542
543+ self ._ensure_activities_registered ()
544+
545+ def _ensure_activities_registered (self ):
546+ """Ensure all workflow activities are registered with the Dapr runtime."""
547+ # Discover and register tasks and workflows
548+ discovered_tasks = self ._discover_tasks ()
549+ self ._register_tasks (discovered_tasks )
550+ discovered_wfs = self ._discover_workflows ()
551+ self ._register_workflows (discovered_wfs )
552+ logger .debug ("Workflow activities registration completed." )
553+
541554 def _sync_workflow_state_after_startup (self ):
542555 """
543556 Sync database workflow state with actual Dapr workflow status after runtime startup.
@@ -555,25 +568,17 @@ def _sync_workflow_state_after_startup(self):
555568 )
556569 return
557570
558- logger .debug ("Syncing workflow state with Dapr after runtime startup..." )
559571 self .load_state ()
560-
561- # Check if we have instances to sync
562- instances = (
563- getattr (self .state , "instances" , {})
564- if hasattr (self .state , "instances" )
565- else self .state .get ("instances" , {})
566- )
567- if not instances :
568- return
572+ instances = self .state .get ("instances" , {})
569573
570574 logger .debug (f"Found { len (instances )} workflow instances to sync" )
571575
572576 # Sync each instance with Dapr's actual status
573577 for instance_id , instance_data in instances .items ():
574578 try :
575579 # Skip if already completed
576- if instance_data .get ("end_time" ) is not None :
580+ end_time = instance_data .get ("end_time" )
581+ if end_time is not None :
577582 continue
578583
579584 # Get actual status from Dapr
@@ -595,6 +600,7 @@ def _sync_workflow_state_after_startup(self):
595600 timezone .utc
596601 ).isoformat ()
597602 instance_data ["status" ] = runtime_status .lower ()
603+
598604 logger .debug (
599605 f"Marked workflow { instance_id } as { runtime_status .lower ()} in database"
600606 )
@@ -615,6 +621,7 @@ def _sync_workflow_state_after_startup(self):
615621 timezone .utc
616622 ).isoformat ()
617623 instance_data ["status" ] = DaprWorkflowStatus .COMPLETED .value
624+
618625 logger .debug (
619626 f"Workflow { instance_id } no longer in Dapr, marked as completed"
620627 )
@@ -819,9 +826,7 @@ async def _create_agent_span_for_resumed_workflow(
819826 # Get tracer and create AGENT span as child of the original trace
820827 tracer = trace .get_tracer (__name__ )
821828 agent_name = getattr (self , "name" , "DurableAgent" )
822- workflow_name = instance_data .get (
823- "workflow_name" , "ToolCallingWorkflow"
824- )
829+ workflow_name = instance_data .get ("workflow_name" , "AgenticWorkflow" )
825830 span_name = f"{ agent_name } .{ workflow_name } "
826831
827832 # Create the AGENT span that will show up in the trace
0 commit comments