Skip to content

Commit 3df9160

Browse files
authored
Python: Add Durabletask samples and minor fixes (#3157)
* Add samples and minor fixes * Add redis sample and wait-for-completion * Add wait-for-completion support * ADd missing docs
1 parent 1e36ba3 commit 3df9160

File tree

48 files changed

+4219
-1151
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+4219
-1151
lines changed

python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
load_agent_response,
1818
)
1919
from azure.durable_functions.models import TaskBase
20+
from azure.durable_functions.models.actions.NoOpAction import NoOpAction
2021
from azure.durable_functions.models.Task import CompoundTask, TaskState
2122
from pydantic import BaseModel
2223

@@ -42,6 +43,25 @@ def __init__(
4243
_TypedCompoundTask = CompoundTask
4344

4445

46+
class PreCompletedTask(TaskBase):
47+
"""A simple task that is already completed with a result.
48+
49+
Used for fire-and-forget mode where we want to return immediately
50+
with an acceptance response without waiting for entity processing.
51+
"""
52+
53+
def __init__(self, result: Any):
54+
"""Initialize with a completed result.
55+
56+
Args:
57+
result: The result value for this completed task
58+
"""
59+
# Initialize with a NoOp action since we don't need actual orchestration actions
60+
super().__init__(-1, NoOpAction())
61+
# Immediately mark as completed with the result
62+
self.set_value(is_error=False, value=result)
63+
64+
4565
class AgentTask(_TypedCompoundTask):
4666
"""A custom Task that wraps entity calls and provides typed AgentRunResponse results.
4767
@@ -62,10 +82,13 @@ def __init__(
6282
response_format: Optional Pydantic model for response parsing
6383
correlation_id: Correlation ID for logging
6484
"""
65-
super().__init__([entity_task])
85+
# Set instance variables BEFORE calling super().__init__
86+
# because super().__init__ may trigger try_set_value for pre-completed tasks
6687
self._response_format = response_format
6788
self._correlation_id = correlation_id
6889

90+
super().__init__([entity_task])
91+
6992
# Override action_repr to expose the inner task's action directly
7093
# This ensures compatibility with ReplaySchema V3 which expects Action objects.
7194
self.action_repr = entity_task.action_repr
@@ -130,16 +153,27 @@ def get_run_request(
130153
message: str,
131154
response_format: type[BaseModel] | None,
132155
enable_tool_calls: bool,
156+
wait_for_response: bool = True,
133157
) -> RunRequest:
134158
"""Get the current run request from the orchestration context.
135159
160+
Args:
161+
message: The message to send to the agent
162+
response_format: Optional Pydantic model for response parsing
163+
enable_tool_calls: Whether to enable tool calls
164+
wait_for_response: Must be True for orchestration contexts
165+
136166
Returns:
137167
RunRequest: The current run request
168+
169+
Raises:
170+
ValueError: If wait_for_response=False (not supported in orchestrations)
138171
"""
139172
request = super().get_run_request(
140173
message,
141174
response_format,
142175
enable_tool_calls,
176+
wait_for_response,
143177
)
144178
request.orchestration_id = self.context.instance_id
145179
return request
@@ -166,7 +200,24 @@ def run_durable_agent(
166200
session_id,
167201
)
168202

169-
entity_task = self.context.call_entity(entity_id, "run", run_request.to_dict())
203+
# Branch based on wait_for_response
204+
if not run_request.wait_for_response:
205+
# Fire-and-forget mode: signal entity and return pre-completed task
206+
logger.debug(
207+
"[AzureFunctionsAgentExecutor] Fire-and-forget mode: signaling entity (correlation: %s)",
208+
run_request.correlation_id,
209+
)
210+
self.context.signal_entity(entity_id, "run", run_request.to_dict())
211+
212+
# Create acceptance response using base class helper
213+
acceptance_response = self._create_acceptance_response(run_request.correlation_id)
214+
215+
# Create a pre-completed task with the acceptance response
216+
entity_task = PreCompletedTask(acceptance_response)
217+
else:
218+
# Blocking mode: call entity and wait for response
219+
entity_task = self.context.call_entity(entity_id, "run", run_request.to_dict())
220+
170221
return AgentTask(
171222
entity_task=entity_task,
172223
response_format=run_request.response_format,

python/packages/azurefunctions/tests/test_orchestration.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from unittest.mock import Mock
77

88
import pytest
9-
from agent_framework import AgentRunResponse, ChatMessage
9+
from agent_framework import AgentRunResponse, ChatMessage, Role
1010
from agent_framework_durabletask import DurableAIAgent
1111
from azure.durable_functions.models.Task import TaskBase, TaskState
1212

@@ -206,6 +206,81 @@ def test_get_agent_raises_for_unregistered_agent(self) -> None:
206206
app.get_agent(Mock(), "MissingAgent")
207207

208208

209+
class TestAzureFunctionsFireAndForget:
210+
"""Test fire-and-forget mode for AzureFunctionsAgentExecutor."""
211+
212+
def test_fire_and_forget_calls_signal_entity(self, executor_with_uuid: tuple[Any, Mock, str]) -> None:
213+
"""Verify wait_for_response=False calls signal_entity instead of call_entity."""
214+
executor, context, _ = executor_with_uuid
215+
context.signal_entity = Mock()
216+
context.call_entity = Mock(return_value=_create_entity_task())
217+
218+
agent = DurableAIAgent(executor, "TestAgent")
219+
thread = agent.get_new_thread()
220+
221+
# Run with wait_for_response=False
222+
result = agent.run("Test message", thread=thread, wait_for_response=False)
223+
224+
# Verify signal_entity was called and call_entity was not
225+
assert context.signal_entity.call_count == 1
226+
assert context.call_entity.call_count == 0
227+
228+
# Should still return an AgentTask
229+
assert isinstance(result, AgentTask)
230+
231+
def test_fire_and_forget_returns_completed_task(self, executor_with_uuid: tuple[Any, Mock, str]) -> None:
232+
"""Verify wait_for_response=False returns pre-completed AgentTask."""
233+
executor, context, _ = executor_with_uuid
234+
context.signal_entity = Mock()
235+
236+
agent = DurableAIAgent(executor, "TestAgent")
237+
thread = agent.get_new_thread()
238+
239+
result = agent.run("Test message", thread=thread, wait_for_response=False)
240+
241+
# Task should be immediately complete
242+
assert isinstance(result, AgentTask)
243+
assert result.is_completed
244+
245+
def test_fire_and_forget_returns_acceptance_response(self, executor_with_uuid: tuple[Any, Mock, str]) -> None:
246+
"""Verify wait_for_response=False returns acceptance response."""
247+
executor, context, _ = executor_with_uuid
248+
context.signal_entity = Mock()
249+
250+
agent = DurableAIAgent(executor, "TestAgent")
251+
thread = agent.get_new_thread()
252+
253+
result = agent.run("Test message", thread=thread, wait_for_response=False)
254+
255+
# Get the result
256+
response = result.result
257+
assert isinstance(response, AgentRunResponse)
258+
assert len(response.messages) == 1
259+
assert response.messages[0].role == Role.SYSTEM
260+
# Check message contains key information
261+
message_text = response.messages[0].text
262+
assert "accepted" in message_text.lower()
263+
assert "background" in message_text.lower()
264+
265+
def test_blocking_mode_still_works(self, executor_with_uuid: tuple[Any, Mock, str]) -> None:
266+
"""Verify wait_for_response=True uses call_entity as before."""
267+
executor, context, _ = executor_with_uuid
268+
context.signal_entity = Mock()
269+
context.call_entity = Mock(return_value=_create_entity_task())
270+
271+
agent = DurableAIAgent(executor, "TestAgent")
272+
thread = agent.get_new_thread()
273+
274+
result = agent.run("Test message", thread=thread, wait_for_response=True)
275+
276+
# Verify call_entity was called and signal_entity was not
277+
assert context.call_entity.call_count == 1
278+
assert context.signal_entity.call_count == 0
279+
280+
# Should return an AgentTask
281+
assert isinstance(result, AgentTask)
282+
283+
209284
class TestOrchestrationIntegration:
210285
"""Integration tests for orchestration scenarios."""
211286

python/packages/durabletask/agent_framework_durabletask/_entities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ async def run(
142142
response_format = run_request.response_format
143143
enable_tool_calls = run_request.enable_tool_calls
144144

145-
logger.debug("[AgentEntity.run] Received Message: %s", run_request)
145+
logger.debug("[AgentEntity.run] Received ThreadId %s Message: %s", thread_id, run_request)
146146

147147
state_request = DurableAgentStateRequest.from_run_request(run_request)
148148
self.state.data.conversation_history.append(state_request)

python/packages/durabletask/agent_framework_durabletask/_executors.py

Lines changed: 73 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from datetime import datetime, timezone
1717
from typing import Any, Generic, TypeVar
1818

19-
from agent_framework import AgentRunResponse, AgentThread, ChatMessage, ErrorContent, Role, get_logger
19+
from agent_framework import AgentRunResponse, AgentThread, ChatMessage, ErrorContent, Role, TextContent, get_logger
2020
from durabletask.client import TaskHubGrpcClient
2121
from durabletask.entities import EntityInstanceId
22-
from durabletask.task import CompositeTask, OrchestrationContext, Task
22+
from durabletask.task import CompletableTask, CompositeTask, OrchestrationContext, Task
2323
from pydantic import BaseModel
2424

2525
from ._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS
@@ -33,16 +33,19 @@
3333
TaskT = TypeVar("TaskT")
3434

3535

36-
class DurableAgentTask(CompositeTask[AgentRunResponse]):
36+
class DurableAgentTask(CompositeTask[AgentRunResponse], CompletableTask[AgentRunResponse]):
3737
"""A custom Task that wraps entity calls and provides typed AgentRunResponse results.
3838
3939
This task wraps the underlying entity call task and intercepts its completion
4040
to convert the raw result into a typed AgentRunResponse object.
41+
42+
When yielded in an orchestration, this task returns an AgentRunResponse:
43+
response: AgentRunResponse = yield durable_agent_task
4144
"""
4245

4346
def __init__(
4447
self,
45-
entity_task: Task[Any],
48+
entity_task: CompletableTask[Any],
4649
response_format: type[BaseModel] | None,
4750
correlation_id: str,
4851
):
@@ -55,7 +58,7 @@ def __init__(
5558
"""
5659
self._response_format = response_format
5760
self._correlation_id = correlation_id
58-
super().__init__([entity_task]) # type: ignore[misc]
61+
super().__init__([entity_task]) # type: ignore
5962

6063
def on_child_completed(self, task: Task[Any]) -> None:
6164
"""Handle completion of the underlying entity task.
@@ -69,11 +72,8 @@ def on_child_completed(self, task: Task[Any]) -> None:
6972
return
7073

7174
if task.is_failed:
72-
# Propagate the failure
73-
self._exception = task.get_exception()
74-
self._is_complete = True
75-
if self._parent is not None:
76-
self._parent.on_child_completed(self)
75+
# Propagate the failure - pass the original exception directly
76+
self.fail("call_entity Task failed", task.get_exception())
7777
return
7878

7979
# Task succeeded - transform the raw result
@@ -94,18 +94,12 @@ def on_child_completed(self, task: Task[Any]) -> None:
9494
)
9595

9696
# Set the typed AgentRunResponse as this task's result
97-
self._result = response
98-
self._is_complete = True
99-
100-
if self._parent is not None:
101-
self._parent.on_child_completed(self)
97+
self.complete(response)
10298

103-
except Exception:
104-
logger.exception(
105-
"[DurableAgentTask] Failed to convert result for correlation_id: %s",
106-
self._correlation_id,
107-
)
108-
raise
99+
except Exception as ex:
100+
err_msg = "[DurableAgentTask] Failed to convert result for correlation_id: " + self._correlation_id
101+
logger.exception(err_msg)
102+
self.fail(err_msg, ex)
109103

110104

111105
class DurableAgentExecutor(ABC, Generic[TaskT]):
@@ -155,16 +149,42 @@ def get_run_request(
155149
message: str,
156150
response_format: type[BaseModel] | None,
157151
enable_tool_calls: bool,
152+
wait_for_response: bool = True,
158153
) -> RunRequest:
159154
"""Create a RunRequest for the given parameters."""
160155
correlation_id = self.generate_unique_id()
161156
return RunRequest(
162157
message=message,
163158
response_format=response_format,
164159
enable_tool_calls=enable_tool_calls,
160+
wait_for_response=wait_for_response,
165161
correlation_id=correlation_id,
166162
)
167163

164+
def _create_acceptance_response(self, correlation_id: str) -> AgentRunResponse:
165+
"""Create an acceptance response for fire-and-forget mode.
166+
167+
Args:
168+
correlation_id: Correlation ID for tracking the request
169+
170+
Returns:
171+
AgentRunResponse: Acceptance response with correlation ID
172+
"""
173+
acceptance_message = ChatMessage(
174+
role=Role.SYSTEM,
175+
contents=[
176+
TextContent(
177+
f"Request accepted for processing (correlation_id: {correlation_id}). "
178+
f"Agent is executing in the background. "
179+
f"Retrieve response via your configured streaming or callback mechanism."
180+
)
181+
],
182+
)
183+
return AgentRunResponse(
184+
messages=[acceptance_message],
185+
created_at=datetime.now(timezone.utc).isoformat(),
186+
)
187+
168188

169189
class ClientAgentExecutor(DurableAgentExecutor[AgentRunResponse]):
170190
"""Execution strategy for external clients.
@@ -205,11 +225,20 @@ def run_durable_agent(
205225
thread: Optional conversation thread (creates new if not provided)
206226
207227
Returns:
208-
AgentRunResponse: The agent's response after execution completes
228+
AgentRunResponse: The agent's response after execution completes, or an immediate
229+
acknowledgement if wait_for_response is False
209230
"""
210231
# Signal the entity with the request
211232
entity_id = self._signal_agent_entity(agent_name, run_request, thread)
212233

234+
# If fire-and-forget mode, return immediately without polling
235+
if not run_request.wait_for_response:
236+
logger.info(
237+
"[ClientAgentExecutor] Fire-and-forget mode: request signaled (correlation: %s)",
238+
run_request.correlation_id,
239+
)
240+
return self._create_acceptance_response(run_request.correlation_id)
241+
213242
# Poll for the response
214243
agent_response = self._poll_for_agent_response(entity_id, run_request.correlation_id)
215244

@@ -395,11 +424,16 @@ def __init__(self, context: OrchestrationContext):
395424
self._context = context
396425
logger.debug("[OrchestrationAgentExecutor] Initialized")
397426

427+
def generate_unique_id(self) -> str:
428+
"""Create a new UUID that is safe for replay within an orchestration or operation."""
429+
return self._context.new_uuid()
430+
398431
def get_run_request(
399432
self,
400433
message: str,
401434
response_format: type[BaseModel] | None,
402435
enable_tool_calls: bool,
436+
wait_for_response: bool = True,
403437
) -> RunRequest:
404438
"""Get the current run request from the orchestration context.
405439
@@ -410,6 +444,7 @@ def get_run_request(
410444
message,
411445
response_format,
412446
enable_tool_calls,
447+
wait_for_response,
413448
)
414449
request.orchestration_id = self._context.instance_id
415450
return request
@@ -449,8 +484,22 @@ def run_durable_agent(
449484
session_id,
450485
)
451486

452-
# Call the entity and get the underlying task
453-
entity_task: Task[Any] = self._context.call_entity(entity_id, "run", run_request.to_dict()) # type: ignore
487+
# Branch based on wait_for_response
488+
if not run_request.wait_for_response:
489+
# Fire-and-forget mode: signal entity and return pre-completed task
490+
logger.info(
491+
"[OrchestrationAgentExecutor] Fire-and-forget mode: signaling entity (correlation: %s)",
492+
run_request.correlation_id,
493+
)
494+
self._context.signal_entity(entity_id, "run", run_request.to_dict())
495+
496+
# Create a pre-completed task with acceptance response
497+
acceptance_response = self._create_acceptance_response(run_request.correlation_id)
498+
entity_task: CompletableTask[AgentRunResponse] = CompletableTask()
499+
entity_task.complete(acceptance_response)
500+
else:
501+
# Blocking mode: call entity and wait for response
502+
entity_task = self._context.call_entity(entity_id, "run", run_request.to_dict()) # type: ignore
454503

455504
# Wrap in DurableAgentTask for response transformation
456505
return DurableAgentTask(

0 commit comments

Comments
 (0)