Skip to content

Commit d4b46a6

Browse files
authored
feat: fix llm orchestrator + tracing + sessions for long-term memory (#213)
* feat: wip on orchestrator state fixing + tracing Signed-off-by: Samantha Coyle <[email protected]> * fix: separate ex/in-ternal triggers + wip fix orchestrators Signed-off-by: Samantha Coyle <[email protected]> * fix: ensure progress on substeps/steps Signed-off-by: Samantha Coyle <[email protected]> * fix: give orchestrators ability to pick up where they left off using same session id Signed-off-by: Samantha Coyle <[email protected]> * style: make linter happy Signed-off-by: Samantha Coyle <[email protected]> * fix: rm extra edge check since captured elsewhere Signed-off-by: Samantha Coyle <[email protected]> * style: update new wf name Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e ruff Signed-off-by: Samantha Coyle <[email protected]> * fix: make flake8 happy Signed-off-by: Samantha Coyle <[email protected]> * fix: update requirements file Signed-off-by: Samantha Coyle <[email protected]> * docs: add comment per feedback Signed-off-by: Samantha Coyle <[email protected]> * fix: address feedback + plus update docs on 05 quickstart to use tracing Signed-off-by: Samantha Coyle <[email protected]> * docs: add comment for todo Signed-off-by: Samantha Coyle <[email protected]> * style: rename func Signed-off-by: Samantha Coyle <[email protected]> * style: rm debug logging Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e ruff Signed-off-by: Samantha Coyle <[email protected]> --------- Signed-off-by: Samantha Coyle <[email protected]>
1 parent 1d3a6a1 commit d4b46a6

File tree

24 files changed

+1895
-515
lines changed

24 files changed

+1895
-515
lines changed

dapr_agents/agents/durableagent/agent.py

Lines changed: 90 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .schemas import (
2525
AgentTaskResponse,
2626
BroadcastMessage,
27+
InternalTriggerAction,
2728
TriggerAction,
2829
)
2930
from .state import (
@@ -321,6 +322,30 @@ def tool_calling_workflow(self, ctx: DaprWorkflowContext, message: TriggerAction
321322
# Return the final response message
322323
return final_msg
323324

325+
@message_router
326+
@workflow(name="AgenticWorkflow")
327+
def internal_trigger_workflow(
328+
self, ctx: DaprWorkflowContext, message: InternalTriggerAction
329+
):
330+
"""
331+
Handles InternalTriggerAction messages by treating them the same as TriggerAction.
332+
This prevents self-triggering loops while allowing orchestrators to trigger agents.
333+
334+
Args:
335+
ctx (DaprWorkflowContext): The workflow context for the current execution.
336+
message (InternalTriggerAction): The internal trigger message from an orchestrator.
337+
338+
Returns:
339+
Dict[str, Any]: The final response message when the workflow completes.
340+
"""
341+
# Convert InternalTriggerAction to TriggerAction format and delegate to the main workflow
342+
trigger_message = TriggerAction(
343+
task=message.task,
344+
workflow_instance_id=message.workflow_instance_id,
345+
source="orchestrator", # Default source for internal triggers
346+
)
347+
return self.tool_calling_workflow(ctx, trigger_message)
348+
324349
def get_source_or_default(self, source: str):
325350
# Set default source if not provided (for direct run() calls)
326351
if not source:
@@ -771,14 +796,14 @@ def finalize_workflow(
771796
@message_router(broadcast=True)
772797
async def process_broadcast_message(self, message: BroadcastMessage):
773798
"""
774-
Processes a broadcast message, filtering out messages sent by the same agent
775-
and updating local memory with valid messages.
799+
Processes a broadcast message by filtering out messages from the same agent,
800+
storing valid messages in memory, and triggering the agent's workflow if needed.
776801
777802
Args:
778803
message (BroadcastMessage): The received broadcast message.
779804
780805
Returns:
781-
None: The function updates the agent's memory and ignores unwanted messages.
806+
None: The function updates the agent's memory and triggers a workflow.
782807
"""
783808
try:
784809
# Extract metadata safely from message["_message_metadata"]
@@ -819,9 +844,29 @@ async def process_broadcast_message(self, message: BroadcastMessage):
819844
# Save the state after processing the broadcast message
820845
self.save_state()
821846

847+
# Trigger agent workflow to respond to the broadcast message
848+
workflow_instance_id = metadata.get("workflow_instance_id")
849+
if workflow_instance_id:
850+
# Create a TriggerAction to start the agent's workflow
851+
trigger_message = TriggerAction(
852+
task=message.content, workflow_instance_id=workflow_instance_id
853+
)
854+
trigger_message._message_metadata = {
855+
"source": metadata.get("source", "unknown"),
856+
"type": "BroadcastMessage",
857+
"workflow_instance_id": workflow_instance_id,
858+
}
859+
860+
# Start the agent's workflow
861+
await self.run_and_monitor_workflow_async(
862+
workflow="ToolCallingWorkflow", input=trigger_message
863+
)
864+
822865
except Exception as e:
823866
logger.error(f"Error processing broadcast message: {e}", exc_info=True)
824867

868+
# TODO: we need to better design context history management. Context engineering is important,
869+
# and too much context can derail the agent.
825870
def _construct_messages_with_instance_history(
826871
self, instance_id: str, input_data: Union[str, Dict[str, Any]]
827872
) -> List[Dict[str, Any]]:
@@ -843,14 +888,52 @@ def _construct_messages_with_instance_history(
843888
)
844889

845890
# Get instance-specific chat history instead of global memory
891+
if self.state is None:
892+
logger.warning(
893+
f"Agent state is None for instance {instance_id}, initializing empty state"
894+
)
895+
self.state = {}
896+
846897
instance_data = self.state.get("instances", {}).get(instance_id)
847898
if instance_data is not None:
848899
instance_messages = instance_data.get("messages", [])
849900
else:
850901
instance_messages = []
851902

852-
# Convert instance messages to the format expected by prompt template
903+
# Always include long-term memory (chat_history) for context
904+
# This ensures agents have access to broadcast messages and persistent context
905+
long_term_memory_data = self.state.get("chat_history", [])
906+
907+
# Convert long-term memory to dict format for LLM consumption
908+
long_term_memory_messages = []
909+
for msg in long_term_memory_data:
910+
if isinstance(msg, dict):
911+
long_term_memory_messages.append(msg)
912+
elif hasattr(msg, "model_dump"):
913+
long_term_memory_messages.append(msg.model_dump())
914+
915+
# For broadcast-triggered workflows, also include additional context memory
916+
source = instance_data.get("source") if instance_data else None
917+
additional_context_messages = []
918+
if source and source != "direct":
919+
# Include additional context memory for broadcast-triggered workflows
920+
context_memory_data = self.memory.get_messages()
921+
for msg in context_memory_data:
922+
if isinstance(msg, dict):
923+
additional_context_messages.append(msg)
924+
elif hasattr(msg, "model_dump"):
925+
additional_context_messages.append(msg.model_dump())
926+
927+
# Build chat history with:
928+
# 1. Long-term memory (persistent context, broadcast messages)
929+
# 2. Short-term instance messages (current workflow specific)
930+
# 3. Additional context memory (for broadcast-triggered workflows)
853931
chat_history = []
932+
933+
# Add long-term memory first (broadcast messages, persistent context)
934+
chat_history.extend(long_term_memory_messages)
935+
936+
# Add short-term instance messages (current workflow)
854937
for msg in instance_messages:
855938
if isinstance(msg, dict):
856939
chat_history.append(msg)
@@ -860,6 +943,9 @@ def _construct_messages_with_instance_history(
860943
msg.model_dump() if hasattr(msg, "model_dump") else dict(msg)
861944
)
862945

946+
# Add additional context memory last (for broadcast-triggered workflows)
947+
chat_history.extend(additional_context_messages)
948+
863949
if isinstance(input_data, str):
864950
formatted_messages = self.prompt_template.format_prompt(
865951
chat_history=chat_history

dapr_agents/agents/durableagent/schemas.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,18 @@ class TriggerAction(BaseModel):
3131
workflow_instance_id: Optional[str] = Field(
3232
default=None, description="Dapr workflow instance id from source if available"
3333
)
34+
35+
36+
class InternalTriggerAction(BaseModel):
37+
"""
38+
Represents an internal message used by orchestrators to trigger agents.
39+
This prevents self-triggering loops.
40+
"""
41+
42+
task: Optional[str] = Field(
43+
None,
44+
description="The specific task to execute. If not provided, the agent will act based on its memory or predefined behavior.",
45+
)
46+
workflow_instance_id: Optional[str] = Field(
47+
default=None, description="Dapr workflow instance id from source if available"
48+
)

dapr_agents/llm/dapr/chat.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,8 @@ class DaprChatClient(DaprInferenceClientBase, ChatClientBase):
7979

8080
component_name: Optional[str] = None
8181

82-
component_name: Optional[str] = None
83-
84-
# Only function_call–style structured output is supported
85-
SUPPORTED_STRUCTURED_MODES: ClassVar[set[str]] = {"function_call"}
82+
# Support both function_call and json structured output modes
83+
SUPPORTED_STRUCTURED_MODES: ClassVar[set[str]] = {"function_call", "json"}
8684

8785
def model_post_init(self, __context: Any) -> None:
8886
"""
@@ -234,7 +232,7 @@ def generate(
234232
llm_component: Optional[str] = None,
235233
tools: Optional[List[Union[AgentTool, Dict[str, Any]]]] = None,
236234
response_format: Optional[Type[BaseModel]] = None,
237-
structured_mode: Literal["function_call"] = "function_call",
235+
structured_mode: Literal["function_call", "json"] = "function_call",
238236
scrubPII: bool = False,
239237
temperature: Optional[float] = None,
240238
**kwargs: Any,
@@ -256,7 +254,7 @@ def generate(
256254
llm_component: Dapr component name (defaults from env).
257255
tools: AgentTool or dict specifications.
258256
response_format: Pydantic model for structured output.
259-
structured_mode: Must be "function_call".
257+
structured_mode: Must be "function_call" or "json".
260258
scrubPII: Obfuscate sensitive output if True.
261259
temperature: Sampling temperature.
262260
**kwargs: Other Dapr API parameters.
@@ -273,9 +271,8 @@ def generate(
273271
raise ValueError(
274272
f"structured_mode must be one of {self.SUPPORTED_STRUCTURED_MODES}"
275273
)
276-
# 2) Disallow response_format + streaming
277-
if response_format is not None:
278-
raise ValueError("`response_format` is not supported by DaprChatClient.")
274+
# 2) Disallow streaming
275+
# Note: response_format is now supported for structured output
279276
if kwargs.get("stream"):
280277
raise ValueError("Streaming is not supported by DaprChatClient.")
281278

@@ -306,10 +303,15 @@ def generate(
306303
structured_mode=structured_mode,
307304
)
308305

306+
logger.debug(f"Processed parameters for Dapr: {params}")
307+
if response_format:
308+
logger.debug(f"Response format: {response_format}")
309+
logger.debug(f"Structured mode: {structured_mode}")
310+
309311
# 6) Convert to Dapr inputs & call
310312
conv_inputs = self.convert_to_conversation_inputs(params["inputs"])
311313
try:
312-
logger.info("Invoking the Dapr Conversation API.")
314+
logger.debug("Invoking the Dapr Conversation API.")
313315
# Log tools/tool_choice/parameters for debugging
314316
if params.get("tools"):
315317
try:
@@ -346,7 +348,8 @@ def generate(
346348
normalized = self.translate_response(
347349
raw, llm_component or self._llm_component
348350
)
349-
logger.info("Chat completion retrieved successfully.")
351+
logger.debug(f"Dapr Conversation API response: {raw}")
352+
logger.debug(f"Normalized response: {normalized}")
350353
except Exception as e:
351354
logger.error(
352355
f"An error occurred during the Dapr Conversation API call: {e}"
@@ -369,11 +372,9 @@ def _check_dapr_runtime_support(metadata: "GetMetadataResponse"): # noqa: F821
369372
dapr_runtime_version = extended_metadata.get("daprRuntimeVersion", None)
370373
if dapr_runtime_version is not None:
371374
# Allow only versions >=1.16.0, edge, and <2.0.0 for Alpha2 Chat Client
372-
if not is_version_supported(
373-
str(dapr_runtime_version), ">=1.16.0, edge, <2.0.0"
374-
):
375+
if not is_version_supported(str(dapr_runtime_version), ">=1.16.0, <2.0.0"):
375376
raise DaprRuntimeVersionNotSupportedError(
376-
f"!!!!! Dapr Runtime Version {dapr_runtime_version} is not supported with Alpha2 Dapr Chat Client. Only Dapr runtime versions >=1.16.0, edge, and <2.0.0 are supported."
377+
f"!!!!! Dapr Runtime Version {dapr_runtime_version} is not supported with Alpha2 Dapr Chat Client. Only Dapr runtime versions >=1.16.0 and <2.0.0 are supported."
377378
)
378379

379380

dapr_agents/llm/utils/response.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,5 +124,5 @@ def process_response(
124124
validated = StructureHandler.validate_response(raw, fmt)
125125
logger.info("Structured output successfully validated.")
126126

127-
# 3e) If it’s our auto‑wrapped iterable model, return its `.objects` list
128-
return getattr(validated, "objects", validated)
127+
# 3e) Return the validated model (don't unwrap iterable models)
128+
return validated

dapr_agents/llm/utils/structure.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def extract_structured_response(
233233
"""
234234
try:
235235
logger.debug(f"Processing structured response for mode: {structured_mode}")
236-
if llm_provider in ("openai", "nvidia", "huggingface"):
236+
if llm_provider in ("openai", "nvidia", "huggingface", "dapr"):
237237
if structured_mode == "function_call":
238238
tool_calls = getattr(message, "tool_calls", None)
239239
if tool_calls:
@@ -575,18 +575,26 @@ def validate_against_signature(result: Any, expected_type: Any) -> Any:
575575

576576
# Handle one or more BaseModels
577577
models = StructureHandler.resolve_all_pydantic_models(expected_type)
578-
for model_cls in models:
579-
try:
580-
if isinstance(result, list):
581-
return [
582-
StructureHandler.validate_response(item, model_cls).model_dump()
583-
for item in result
584-
]
585-
else:
578+
if models:
579+
validation_errors = {}
580+
for model_cls in models:
581+
try:
582+
# Always validate the entire result against the model class
583+
# Don't try to validate individual list items
586584
validated = StructureHandler.validate_response(result, model_cls)
587585
return validated.model_dump()
588-
except ValidationError:
589-
continue
586+
except ValidationError as e:
587+
validation_errors[model_cls.__name__] = e
588+
continue
589+
590+
# If we get here, all models failed validation
591+
if validation_errors:
592+
error_details = "\n".join(
593+
f"{model}: {error}" for model, error in validation_errors.items()
594+
)
595+
raise TypeError(
596+
f"Validation failed for all possible models:\n{error_details}"
597+
)
590598

591599
# Handle Union[str, dict, etc.]
592600
if origin is Union:

dapr_agents/observability/instrumentor.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
ProcessIterationsWrapper,
6161
RunToolWrapper,
6262
WorkflowMonitorWrapper,
63+
WorkflowRegistrationWrapper,
6364
WorkflowRunWrapper,
6465
WorkflowTaskWrapper,
6566
)
@@ -262,12 +263,30 @@ async def context_wrapped_coro():
262263
return loop.run_until_complete(context_wrapped_coro())
263264
except RuntimeError:
264265
# No running loop - create new one
265-
loop = asyncio.new_event_loop()
266-
asyncio.set_event_loop(loop)
266+
# TODO: eventually clean this up by using the tracing setup from dapr upstream
267+
# when we have trace propagation in the SDKs for workflows.
267268
try:
269+
loop = asyncio.new_event_loop()
270+
asyncio.set_event_loop(loop)
268271
return loop.run_until_complete(context_wrapped_coro())
272+
except Exception as e:
273+
logger.warning(
274+
f"Failed to run coroutine with new event loop: {e}"
275+
)
276+
# Fallback: run in thread pool to avoid blocking
277+
import concurrent.futures
278+
279+
with concurrent.futures.ThreadPoolExecutor() as executor:
280+
future = executor.submit(
281+
lambda: asyncio.run(context_wrapped_coro())
282+
)
283+
return future.result()
269284
finally:
270-
loop.close()
285+
try:
286+
if "loop" in locals() and not loop.is_closed():
287+
loop.close()
288+
except Exception as e:
289+
logger.debug(f"Error closing event loop: {e}")
271290

272291
def make_context_aware_task_wrapper(
273292
self, task_name: str, method, task_instance
@@ -456,6 +475,13 @@ def _apply_workflow_wrappers(self) -> None:
456475
# run_and_monitor_workflow_async internally calls run_workflow
457476
# So wrapping both causes duplicate instances
458477

478+
# Instrument workflow registration to add AGENT spans for orchestrator workflows
479+
wrap_function_wrapper(
480+
module="dapr_agents.workflow.base",
481+
name="WorkflowApp._register_workflows",
482+
wrapper=WorkflowRegistrationWrapper(self._tracer),
483+
)
484+
459485
wrap_function_wrapper(
460486
module="dapr_agents.workflow.task",
461487
name="WorkflowTask.__call__",

dapr_agents/observability/wrappers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .tool import ExecuteToolsWrapper, RunToolWrapper
77
from .workflow import (
88
WorkflowMonitorWrapper,
9+
WorkflowRegistrationWrapper,
910
WorkflowRunWrapper,
1011
)
1112
from .workflow_task import WorkflowTaskWrapper
@@ -17,6 +18,7 @@
1718
"RunToolWrapper",
1819
"ProcessIterationsWrapper",
1920
"WorkflowMonitorWrapper",
21+
"WorkflowRegistrationWrapper",
2022
"WorkflowRunWrapper",
2123
"WorkflowTaskWrapper",
2224
]

0 commit comments

Comments (0)