Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
load_agent_response,
)
from azure.durable_functions.models import TaskBase
from azure.durable_functions.models.actions.NoOpAction import NoOpAction
from azure.durable_functions.models.Task import CompoundTask, TaskState
from pydantic import BaseModel

Expand All @@ -42,6 +43,25 @@ def __init__(
_TypedCompoundTask = CompoundTask


class PreCompletedTask(TaskBase):
"""A simple task that is already completed with a result.

Used for fire-and-forget mode where we want to return immediately
with an acceptance response without waiting for entity processing.
"""

def __init__(self, result: Any):
"""Initialize with a completed result.

Args:
result: The result value for this completed task
"""
# Initialize with a NoOp action since we don't need actual orchestration actions
super().__init__(-1, NoOpAction())
# Immediately mark as completed with the result
self.set_value(is_error=False, value=result)


class AgentTask(_TypedCompoundTask):
"""A custom Task that wraps entity calls and provides typed AgentRunResponse results.

Expand All @@ -62,10 +82,13 @@ def __init__(
response_format: Optional Pydantic model for response parsing
correlation_id: Correlation ID for logging
"""
super().__init__([entity_task])
# Set instance variables BEFORE calling super().__init__
# because super().__init__ may trigger try_set_value for pre-completed tasks
self._response_format = response_format
self._correlation_id = correlation_id

super().__init__([entity_task])

# Override action_repr to expose the inner task's action directly
# This ensures compatibility with ReplaySchema V3 which expects Action objects.
self.action_repr = entity_task.action_repr
Expand Down Expand Up @@ -130,16 +153,27 @@ def get_run_request(
message: str,
response_format: type[BaseModel] | None,
enable_tool_calls: bool,
wait_for_response: bool = True,
) -> RunRequest:
"""Get the current run request from the orchestration context.

Args:
message: The message to send to the agent
response_format: Optional Pydantic model for response parsing
enable_tool_calls: Whether to enable tool calls
wait_for_response: Must be True for orchestration contexts

Returns:
RunRequest: The current run request

Raises:
ValueError: If wait_for_response=False (not supported in orchestrations)
"""
request = super().get_run_request(
message,
response_format,
enable_tool_calls,
wait_for_response,
)
request.orchestration_id = self.context.instance_id
return request
Expand All @@ -166,7 +200,24 @@ def run_durable_agent(
session_id,
)

entity_task = self.context.call_entity(entity_id, "run", run_request.to_dict())
# Branch based on wait_for_response
if not run_request.wait_for_response:
# Fire-and-forget mode: signal entity and return pre-completed task
logger.info(
"[AzureFunctionsAgentExecutor] Fire-and-forget mode: signaling entity (correlation: %s)",
run_request.correlation_id,
)
self.context.signal_entity(entity_id, "run", run_request.to_dict())

# Create acceptance response using base class helper
acceptance_response = self._create_acceptance_response(run_request.correlation_id)

# Create a pre-completed task with the acceptance response
entity_task = PreCompletedTask(acceptance_response)
else:
# Blocking mode: call entity and wait for response
entity_task = self.context.call_entity(entity_id, "run", run_request.to_dict())

return AgentTask(
entity_task=entity_task,
response_format=run_request.response_format,
Expand Down
77 changes: 76 additions & 1 deletion python/packages/azurefunctions/tests/test_orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from unittest.mock import Mock

import pytest
from agent_framework import AgentRunResponse, ChatMessage
from agent_framework import AgentRunResponse, ChatMessage, Role
from agent_framework_durabletask import DurableAIAgent
from azure.durable_functions.models.Task import TaskBase, TaskState

Expand Down Expand Up @@ -206,6 +206,81 @@ def test_get_agent_raises_for_unregistered_agent(self) -> None:
app.get_agent(Mock(), "MissingAgent")


class TestAzureFunctionsFireAndForget:
"""Test fire-and-forget mode for AzureFunctionsAgentExecutor."""

def test_fire_and_forget_calls_signal_entity(self, executor_with_uuid: tuple[Any, Mock, str]) -> None:
"""Verify wait_for_response=False calls signal_entity instead of call_entity."""
executor, context, _ = executor_with_uuid
context.signal_entity = Mock()
context.call_entity = Mock(return_value=_create_entity_task())

agent = DurableAIAgent(executor, "TestAgent")
thread = agent.get_new_thread()

# Run with wait_for_response=False
result = agent.run("Test message", thread=thread, wait_for_response=False)

# Verify signal_entity was called and call_entity was not
assert context.signal_entity.call_count == 1
assert context.call_entity.call_count == 0

# Should still return an AgentTask
assert isinstance(result, AgentTask)

def test_fire_and_forget_returns_completed_task(self, executor_with_uuid: tuple[Any, Mock, str]) -> None:
"""Verify wait_for_response=False returns pre-completed AgentTask."""
executor, context, _ = executor_with_uuid
context.signal_entity = Mock()

agent = DurableAIAgent(executor, "TestAgent")
thread = agent.get_new_thread()

result = agent.run("Test message", thread=thread, wait_for_response=False)

# Task should be immediately complete
assert isinstance(result, AgentTask)
assert result.is_completed

def test_fire_and_forget_returns_acceptance_response(self, executor_with_uuid: tuple[Any, Mock, str]) -> None:
"""Verify wait_for_response=False returns acceptance response."""
executor, context, _ = executor_with_uuid
context.signal_entity = Mock()

agent = DurableAIAgent(executor, "TestAgent")
thread = agent.get_new_thread()

result = agent.run("Test message", thread=thread, wait_for_response=False)

# Get the result
response = result.result
assert isinstance(response, AgentRunResponse)
assert len(response.messages) == 1
assert response.messages[0].role == Role.SYSTEM
# Check message contains key information
message_text = response.messages[0].text
assert "accepted" in message_text.lower()
assert "background" in message_text.lower()

def test_blocking_mode_still_works(self, executor_with_uuid: tuple[Any, Mock, str]) -> None:
"""Verify wait_for_response=True uses call_entity as before."""
executor, context, _ = executor_with_uuid
context.signal_entity = Mock()
context.call_entity = Mock(return_value=_create_entity_task())

agent = DurableAIAgent(executor, "TestAgent")
thread = agent.get_new_thread()

result = agent.run("Test message", thread=thread, wait_for_response=True)

# Verify call_entity was called and signal_entity was not
assert context.call_entity.call_count == 1
assert context.signal_entity.call_count == 0

# Should return an AgentTask
assert isinstance(result, AgentTask)


class TestOrchestrationIntegration:
"""Integration tests for orchestration scenarios."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ async def run(
response_format = run_request.response_format
enable_tool_calls = run_request.enable_tool_calls

logger.debug("[AgentEntity.run] Received Message: %s", run_request)
logger.debug("[AgentEntity.run] Received ThreadId %s Message: %s", thread_id, run_request)

state_request = DurableAgentStateRequest.from_run_request(run_request)
self.state.data.conversation_history.append(state_request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
from datetime import datetime, timezone
from typing import Any, Generic, TypeVar

from agent_framework import AgentRunResponse, AgentThread, ChatMessage, ErrorContent, Role, get_logger
from agent_framework import AgentRunResponse, AgentThread, ChatMessage, ErrorContent, Role, TextContent, get_logger
from durabletask.client import TaskHubGrpcClient
from durabletask.entities import EntityInstanceId
from durabletask.task import CompositeTask, OrchestrationContext, Task
from durabletask.task import CompletableTask, CompositeTask, OrchestrationContext, Task
from pydantic import BaseModel

from ._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS
Expand All @@ -33,16 +33,19 @@
TaskT = TypeVar("TaskT")


class DurableAgentTask(CompositeTask[AgentRunResponse]):
class DurableAgentTask(CompositeTask[AgentRunResponse], CompletableTask[AgentRunResponse]):
"""A custom Task that wraps entity calls and provides typed AgentRunResponse results.

This task wraps the underlying entity call task and intercepts its completion
to convert the raw result into a typed AgentRunResponse object.

When yielded in an orchestration, this task returns an AgentRunResponse:
response: AgentRunResponse = yield durable_agent_task
"""

def __init__(
self,
entity_task: Task[Any],
entity_task: CompletableTask[Any],
response_format: type[BaseModel] | None,
correlation_id: str,
):
Expand All @@ -55,7 +58,7 @@ def __init__(
"""
self._response_format = response_format
self._correlation_id = correlation_id
super().__init__([entity_task]) # type: ignore[misc]
super().__init__([entity_task]) # type: ignore

def on_child_completed(self, task: Task[Any]) -> None:
"""Handle completion of the underlying entity task.
Expand All @@ -69,11 +72,8 @@ def on_child_completed(self, task: Task[Any]) -> None:
return

if task.is_failed:
# Propagate the failure
self._exception = task.get_exception()
self._is_complete = True
if self._parent is not None:
self._parent.on_child_completed(self)
# Propagate the failure - pass the original exception directly
self.fail("call_entity Task failed", task.get_exception())
return

# Task succeeded - transform the raw result
Expand All @@ -94,18 +94,12 @@ def on_child_completed(self, task: Task[Any]) -> None:
)

# Set the typed AgentRunResponse as this task's result
self._result = response
self._is_complete = True

if self._parent is not None:
self._parent.on_child_completed(self)
self.complete(response)

except Exception:
logger.exception(
"[DurableAgentTask] Failed to convert result for correlation_id: %s",
self._correlation_id,
)
raise
except Exception as ex:
err_msg = "[DurableAgentTask] Failed to convert result for correlation_id: " + self._correlation_id
logger.exception(err_msg)
self.fail(err_msg, ex)


class DurableAgentExecutor(ABC, Generic[TaskT]):
Expand Down Expand Up @@ -155,16 +149,42 @@ def get_run_request(
message: str,
response_format: type[BaseModel] | None,
enable_tool_calls: bool,
wait_for_response: bool = True,
) -> RunRequest:
"""Create a RunRequest for the given parameters."""
correlation_id = self.generate_unique_id()
return RunRequest(
message=message,
response_format=response_format,
enable_tool_calls=enable_tool_calls,
wait_for_response=wait_for_response,
correlation_id=correlation_id,
)

def _create_acceptance_response(self, correlation_id: str) -> AgentRunResponse:
"""Create an acceptance response for fire-and-forget mode.

Args:
correlation_id: Correlation ID for tracking the request

Returns:
AgentRunResponse: Acceptance response with correlation ID
"""
acceptance_message = ChatMessage(
role=Role.SYSTEM,
contents=[
TextContent(
f"Request accepted for processing (correlation_id: {correlation_id}). "
f"Agent is executing in the background. "
f"Retrieve response via your configured streaming or callback mechanism."
)
],
)
return AgentRunResponse(
messages=[acceptance_message],
created_at=datetime.now(timezone.utc).isoformat(),
)


class ClientAgentExecutor(DurableAgentExecutor[AgentRunResponse]):
"""Execution strategy for external clients.
Expand Down Expand Up @@ -205,11 +225,20 @@ def run_durable_agent(
thread: Optional conversation thread (creates new if not provided)

Returns:
AgentRunResponse: The agent's response after execution completes
AgentRunResponse: The agent's response after execution completes, or an immediate
acknowledgement if wait_for_response is False
"""
# Signal the entity with the request
entity_id = self._signal_agent_entity(agent_name, run_request, thread)

# If fire-and-forget mode, return immediately without polling
if not run_request.wait_for_response:
logger.info(
"[ClientAgentExecutor] Fire-and-forget mode: request signaled (correlation: %s)",
run_request.correlation_id,
)
return self._create_acceptance_response(run_request.correlation_id)

# Poll for the response
agent_response = self._poll_for_agent_response(entity_id, run_request.correlation_id)

Expand Down Expand Up @@ -395,11 +424,16 @@ def __init__(self, context: OrchestrationContext):
self._context = context
logger.debug("[OrchestrationAgentExecutor] Initialized")

def generate_unique_id(self) -> str:
"""Create a new UUID that is safe for replay within an orchestration or operation."""
return self._context.new_uuid()

def get_run_request(
self,
message: str,
response_format: type[BaseModel] | None,
enable_tool_calls: bool,
wait_for_response: bool = True,
) -> RunRequest:
"""Get the current run request from the orchestration context.

Expand All @@ -410,6 +444,7 @@ def get_run_request(
message,
response_format,
enable_tool_calls,
wait_for_response,
)
request.orchestration_id = self._context.instance_id
return request
Expand Down Expand Up @@ -449,8 +484,22 @@ def run_durable_agent(
session_id,
)

# Call the entity and get the underlying task
entity_task: Task[Any] = self._context.call_entity(entity_id, "run", run_request.to_dict()) # type: ignore
# Branch based on wait_for_response
if not run_request.wait_for_response:
# Fire-and-forget mode: signal entity and return pre-completed task
logger.info(
"[OrchestrationAgentExecutor] Fire-and-forget mode: signaling entity (correlation: %s)",
run_request.correlation_id,
)
self._context.signal_entity(entity_id, "run", run_request.to_dict())

# Create a pre-completed task with acceptance response
acceptance_response = self._create_acceptance_response(run_request.correlation_id)
entity_task: CompletableTask[AgentRunResponse] = CompletableTask()
entity_task.complete(acceptance_response)
else:
# Blocking mode: call entity and wait for response
entity_task = self._context.call_entity(entity_id, "run", run_request.to_dict()) # type: ignore

# Wrap in DurableAgentTask for response transformation
return DurableAgentTask(
Expand Down
Loading
Loading