Skip to content

Commit 1b3ac4c

Browse files
authored
Enable model activity and tool activity retries (#8)
1 parent 0a8b3b0 commit 1b3ac4c

File tree

9 files changed

+549
-141
lines changed

9 files changed

+549
-141
lines changed

azure/durable_functions/decorators/durable_app.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT License.
33

4-
4+
from azure.durable_functions.models.RetryOptions import RetryOptions
55
from .metadata import OrchestrationTrigger, ActivityTrigger, EntityTrigger,\
66
DurableClient
77
from typing import Callable, Optional
@@ -270,7 +270,15 @@ def _setup_durable_openai_agent(self, model_provider):
270270
self._create_invoke_model_activity(model_provider)
271271
self._is_durable_openai_agent_setup = True
272272

273-
def durable_openai_agent_orchestrator(self, _func=None, *, model_provider=None):
273+
def durable_openai_agent_orchestrator(
274+
self,
275+
_func=None,
276+
*,
277+
model_provider=None,
278+
model_retry_options: Optional[RetryOptions] = RetryOptions(
279+
first_retry_interval_in_milliseconds=2000, max_number_of_attempts=5
280+
),
281+
):
274282
"""Decorate Azure Durable Functions orchestrators that use OpenAI Agents.
275283
276284
Parameters
@@ -292,7 +300,9 @@ def generator_wrapper_wrapper(func):
292300

293301
@wraps(func)
294302
def generator_wrapper(context):
295-
return durable_openai_agent_orchestrator_generator(func, context)
303+
return durable_openai_agent_orchestrator_generator(
304+
func, context, model_retry_options
305+
)
296306

297307
return generator_wrapper
298308

azure/durable_functions/openai_agents/context.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,41 @@
1-
import json
21
from typing import Any, Callable, Optional
32

43
from azure.durable_functions.models.DurableOrchestrationContext import (
54
DurableOrchestrationContext,
65
)
6+
from azure.durable_functions.models.RetryOptions import RetryOptions
77

88
from agents import RunContextWrapper, Tool
99
from agents.function_schema import function_schema
1010
from agents.tool import FunctionTool
11-
from .exceptions import YieldException
11+
from .task_tracker import TaskTracker
1212

1313

1414
class DurableAIAgentContext:
1515
"""Context for AI agents running in Azure Durable Functions orchestration."""
1616

17-
def __init__(self, context: DurableOrchestrationContext):
17+
def __init__(
18+
self,
19+
context: DurableOrchestrationContext,
20+
task_tracker: TaskTracker,
21+
model_retry_options: Optional[RetryOptions],
22+
):
1823
self._context = context
19-
self._activities_called = 0
20-
self._tasks_to_yield = []
21-
22-
def _get_activity_call_result(self, activity_name, input: str):
23-
task = self._context.call_activity(activity_name, input)
24-
25-
self._activities_called += 1
26-
27-
histories = self._context.histories
28-
completed_tasks = [entry for entry in histories if entry.event_type == 5]
29-
if len(completed_tasks) < self._activities_called:
30-
# yield immediately
31-
raise YieldException(task)
32-
else:
33-
# yield later
34-
self._tasks_to_yield.append(task)
35-
36-
result_json = completed_tasks[self._activities_called - 1].Result
37-
result = json.loads(result_json)
38-
return result
24+
self._task_tracker = task_tracker
25+
self._model_retry_options = model_retry_options
3926

4027
def call_activity(self, activity_name, input: str):
41-
"""Call an activity function and increment the activity counter."""
28+
"""Call an activity function and record the activity call."""
4229
task = self._context.call_activity(activity_name, input)
43-
self._activities_called += 1
30+
self._task_tracker.record_activity_call()
31+
return task
32+
33+
def call_activity_with_retry(
34+
self, activity_name, retry_options: RetryOptions, input: str = None
35+
):
36+
"""Call an activity function with retry options and record the activity call."""
37+
task = self._context.call_activity_with_retry(activity_name, retry_options, input)
38+
self._task_tracker.record_activity_call()
4439
return task
4540

4641
def set_custom_status(self, status: str):
@@ -51,24 +46,22 @@ def wait_for_external_event(self, event_name: str):
5146
"""Wait for an external event in the orchestration."""
5247
return self._context.wait_for_external_event(event_name)
5348

54-
def _yield_and_clear_tasks(self):
55-
"""Yield all accumulated tasks and clear the tasks list."""
56-
for task in self._tasks_to_yield:
57-
yield task
58-
self._tasks_to_yield.clear()
59-
6049
def activity_as_tool(
6150
self,
6251
activity_func: Callable,
6352
*,
6453
description: Optional[str] = None,
54+
retry_options: Optional[RetryOptions] = RetryOptions(
55+
first_retry_interval_in_milliseconds=2000, max_number_of_attempts=5
56+
),
6557
) -> Tool:
6658
"""Convert an Azure Durable Functions activity to an OpenAI Agents SDK Tool.
6759
6860
Args
6961
----
7062
activity_func: The Azure Functions activity function to convert
7163
description: Optional description override for the tool
64+
retry_options: The retry options for the activity function
7265
7366
Returns
7467
-------
@@ -78,7 +71,12 @@ def activity_as_tool(
7871
activity_name = activity_func._function._name
7972

8073
async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
81-
result = self._get_activity_call_result(activity_name, input)
74+
if retry_options:
75+
result = self._task_tracker.get_activity_call_result_with_retry(
76+
activity_name, retry_options, input
77+
)
78+
else:
79+
result = self._task_tracker.get_activity_call_result(activity_name, input)
8280
return result
8381

8482
schema = function_schema(

azure/durable_functions/openai_agents/model_invocation_activity.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from datetime import timedelta
55
from typing import Any, AsyncIterator, Optional, Union, cast
66

7-
import azure.functions as func
7+
from azure.durable_functions.models.RetryOptions import RetryOptions
88
from pydantic import BaseModel, Field
99
from agents import (
1010
AgentOutputSchema,
@@ -34,7 +34,7 @@
3434
from openai.types.responses.tool_param import Mcp
3535
from openai.types.responses.response_prompt_param import ResponsePromptParam
3636

37-
from .context import DurableAIAgentContext
37+
from .task_tracker import TaskTracker
3838

3939
try:
4040
from azure.durable_functions import ApplicationError
@@ -283,14 +283,18 @@ def make_tool(tool: ToolInput) -> Tool:
283283
) from e
284284

285285

286-
class _DurableModelStub(Model):
286+
class DurableActivityModel(Model):
287+
"""A model implementation that uses durable activities for model invocations."""
288+
287289
def __init__(
288290
self,
289291
model_name: Optional[str],
290-
context: DurableAIAgentContext,
292+
task_tracker: TaskTracker,
293+
retry_options: Optional[RetryOptions],
291294
) -> None:
292295
self.model_name = model_name
293-
self.context = context
296+
self.task_tracker = task_tracker
297+
self.retry_options = retry_options
294298

295299
async def get_response(
296300
self,
@@ -305,6 +309,7 @@ async def get_response(
305309
previous_response_id: Optional[str],
306310
prompt: Optional[ResponsePromptParam],
307311
) -> ModelResponse:
312+
"""Get a response from the model."""
308313
def make_tool_info(tool: Tool) -> ToolInput:
309314
if isinstance(
310315
tool,
@@ -375,9 +380,17 @@ def make_tool_info(tool: Tool) -> ToolInput:
375380

376381
activity_input_json = activity_input.to_json()
377382

378-
response = self.context._get_activity_call_result(
379-
"invoke_model_activity", activity_input_json
380-
)
383+
if self.retry_options:
384+
response = self.task_tracker.get_activity_call_result_with_retry(
385+
"invoke_model_activity",
386+
self.retry_options,
387+
activity_input_json,
388+
)
389+
else:
390+
response = self.task_tracker.get_activity_call_result(
391+
"invoke_model_activity", activity_input_json
392+
)
393+
381394
json_response = json.loads(response)
382395
model_response = ModelResponse(**json_response)
383396
return model_response
@@ -395,21 +408,5 @@ def stream_response(
395408
previous_response_id: Optional[str],
396409
prompt: Optional[ResponsePromptParam],
397410
) -> AsyncIterator[TResponseStreamEvent]:
411+
"""Stream a response from the model."""
398412
raise NotImplementedError("Durable model doesn't support streams yet")
399-
400-
401-
def create_invoke_model_activity(app: func.FunctionApp, model_provider: Optional[ModelProvider]):
402-
"""Create and register the invoke_model_activity function with the provided FunctionApp."""
403-
404-
@app.activity_trigger(input_name="input")
405-
async def invoke_model_activity(input: str):
406-
"""Activity that handles OpenAI model invocations."""
407-
activity_input = ActivityModelInput.from_json(input)
408-
409-
model_invoker = ModelInvoker(model_provider=model_provider)
410-
result = await model_invoker.invoke_model_activity(activity_input)
411-
412-
json_obj = ModelResponse.__pydantic_serializer__.to_json(result)
413-
return json_obj.decode()
414-
415-
return invoke_model_activity
Lines changed: 14 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,16 @@
1-
import inspect
2-
import json
3-
from typing import Any
1+
from functools import partial
2+
from typing import Optional
43
from agents import ModelProvider, ModelResponse
54
from agents.run import set_default_agent_runner
65
from azure.durable_functions.models.DurableOrchestrationContext import DurableOrchestrationContext
7-
from azure.durable_functions.openai_agents.model_invocation_activity\
8-
import ActivityModelInput, ModelInvoker
6+
from azure.durable_functions.models.RetryOptions import RetryOptions
7+
from .model_invocation_activity import ActivityModelInput, ModelInvoker
8+
from .task_tracker import TaskTracker
99
from .runner import DurableOpenAIRunner
10-
from .exceptions import YieldException
1110
from .context import DurableAIAgentContext
1211
from .event_loop import ensure_event_loop
1312

1413

15-
def _durable_serializer(obj: Any) -> str:
16-
# Strings are already "serialized"
17-
if type(obj) is str:
18-
return obj
19-
20-
# Serialize "Durable" and OpenAI models, and typed dictionaries
21-
if callable(getattr(obj, "to_json", None)):
22-
return obj.to_json()
23-
24-
# Serialize Pydantic models
25-
if callable(getattr(obj, "model_dump_json", None)):
26-
return obj.model_dump_json()
27-
28-
# Fallback to default JSON serialization
29-
return json.dumps(obj)
30-
31-
3214
async def durable_openai_agent_activity(input: str, model_provider: ModelProvider):
3315
"""Activity logic that handles OpenAI model invocations."""
3416
activity_input = ActivityModelInput.from_json(input)
@@ -42,53 +24,17 @@ async def durable_openai_agent_activity(input: str, model_provider: ModelProvide
4224

4325
def durable_openai_agent_orchestrator_generator(
4426
func,
45-
durable_orchestration_context: DurableOrchestrationContext):
27+
durable_orchestration_context: DurableOrchestrationContext,
28+
model_retry_options: Optional[RetryOptions],
29+
):
4630
"""Adapts the synchronous OpenAI Agents function to an Durable orchestrator generator."""
4731
ensure_event_loop()
48-
durable_ai_agent_context = DurableAIAgentContext(durable_orchestration_context)
32+
task_tracker = TaskTracker(durable_orchestration_context)
33+
durable_ai_agent_context = DurableAIAgentContext(
34+
durable_orchestration_context, task_tracker, model_retry_options
35+
)
4936
durable_openai_runner = DurableOpenAIRunner(context=durable_ai_agent_context)
5037
set_default_agent_runner(durable_openai_runner)
5138

52-
if inspect.isgeneratorfunction(func):
53-
gen = iter(func(durable_ai_agent_context))
54-
try:
55-
# prime the subiterator
56-
value = next(gen)
57-
yield from durable_ai_agent_context._yield_and_clear_tasks()
58-
while True:
59-
try:
60-
# send whatever was sent into us down to the subgenerator
61-
yield from durable_ai_agent_context._yield_and_clear_tasks()
62-
sent = yield value
63-
except GeneratorExit:
64-
# ensure the subgenerator is closed
65-
if hasattr(gen, "close"):
66-
gen.close()
67-
raise
68-
except BaseException as exc:
69-
# forward thrown exceptions if possible
70-
if hasattr(gen, "throw"):
71-
value = gen.throw(type(exc), exc, exc.__traceback__)
72-
else:
73-
raise
74-
else:
75-
# normal path: forward .send (or .__next__)
76-
if hasattr(gen, "send"):
77-
value = gen.send(sent)
78-
else:
79-
value = next(gen)
80-
except StopIteration as e:
81-
yield from durable_ai_agent_context._yield_and_clear_tasks()
82-
return _durable_serializer(e.value)
83-
except YieldException as e:
84-
yield from durable_ai_agent_context._yield_and_clear_tasks()
85-
yield e.task
86-
else:
87-
try:
88-
result = func(durable_ai_agent_context)
89-
return _durable_serializer(result)
90-
except YieldException as e:
91-
yield from durable_ai_agent_context._yield_and_clear_tasks()
92-
yield e.task
93-
finally:
94-
yield from durable_ai_agent_context._yield_and_clear_tasks()
39+
func_with_context = partial(func, durable_ai_agent_context)
40+
return task_tracker.execute_orchestrator_function(func_with_context)

azure/durable_functions/openai_agents/runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pydantic_core import to_json
1616

1717
from .context import DurableAIAgentContext
18-
from .model_invocation_activity import _DurableModelStub
18+
from .model_invocation_activity import DurableActivityModel
1919

2020
logger = logging.getLogger(__name__)
2121

@@ -58,9 +58,10 @@ def run_sync(
5858

5959
updated_run_config = replace(
6060
run_config,
61-
model=_DurableModelStub(
61+
model=DurableActivityModel(
6262
model_name=model_name,
63-
context=self.context,
63+
task_tracker=self.context._task_tracker,
64+
retry_options=self.context._model_retry_options,
6465
),
6566
)
6667

0 commit comments

Comments
 (0)