Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions azure/durable_functions/decorators/durable_app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.


from azure.durable_functions.models.RetryOptions import RetryOptions
from .metadata import OrchestrationTrigger, ActivityTrigger, EntityTrigger,\
DurableClient
from typing import Callable, Optional
Expand Down Expand Up @@ -270,7 +270,15 @@ def _setup_durable_openai_agent(self, model_provider):
self._create_invoke_model_activity(model_provider)
self._is_durable_openai_agent_setup = True

def durable_openai_agent_orchestrator(self, _func=None, *, model_provider=None):
def durable_openai_agent_orchestrator(
self,
_func=None,
*,
model_provider=None,
model_retry_options: Optional[RetryOptions] = RetryOptions(
first_retry_interval_in_milliseconds=2000, max_number_of_attempts=5
),
):
"""Decorate Azure Durable Functions orchestrators that use OpenAI Agents.

Parameters
Expand All @@ -292,7 +300,9 @@ def generator_wrapper_wrapper(func):

@wraps(func)
def generator_wrapper(context):
return durable_openai_agent_orchestrator_generator(func, context)
return durable_openai_agent_orchestrator_generator(
func, context, model_retry_options
)

return generator_wrapper

Expand Down
62 changes: 30 additions & 32 deletions azure/durable_functions/openai_agents/context.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,41 @@
import json
from typing import Any, Callable, Optional

from azure.durable_functions.models.DurableOrchestrationContext import (
DurableOrchestrationContext,
)
from azure.durable_functions.models.RetryOptions import RetryOptions

from agents import RunContextWrapper, Tool
from agents.function_schema import function_schema
from agents.tool import FunctionTool
from .exceptions import YieldException
from .task_tracker import TaskTracker


class DurableAIAgentContext:
"""Context for AI agents running in Azure Durable Functions orchestration."""

def __init__(self, context: DurableOrchestrationContext):
def __init__(
self,
context: DurableOrchestrationContext,
task_tracker: TaskTracker,
model_retry_options: Optional[RetryOptions],
):
self._context = context
self._activities_called = 0
self._tasks_to_yield = []

def _get_activity_call_result(self, activity_name, input: str):
task = self._context.call_activity(activity_name, input)

self._activities_called += 1

histories = self._context.histories
completed_tasks = [entry for entry in histories if entry.event_type == 5]
if len(completed_tasks) < self._activities_called:
# yield immediately
raise YieldException(task)
else:
# yield later
self._tasks_to_yield.append(task)

result_json = completed_tasks[self._activities_called - 1].Result
result = json.loads(result_json)
return result
self._task_tracker = task_tracker
self._model_retry_options = model_retry_options

def call_activity(self, activity_name, input: str):
"""Call an activity function and increment the activity counter."""
"""Call an activity function and record the activity call."""
task = self._context.call_activity(activity_name, input)
self._activities_called += 1
self._task_tracker.record_activity_call()
return task

def call_activity_with_retry(
self, activity_name, retry_options: RetryOptions, input: str = None
):
"""Call an activity function with retry options and record the activity call."""
task = self._context.call_activity_with_retry(activity_name, retry_options, input)
self._task_tracker.record_activity_call()
return task

def set_custom_status(self, status: str):
Expand All @@ -51,24 +46,22 @@ def wait_for_external_event(self, event_name: str):
"""Wait for an external event in the orchestration."""
return self._context.wait_for_external_event(event_name)

def _yield_and_clear_tasks(self):
"""Yield all accumulated tasks and clear the tasks list."""
for task in self._tasks_to_yield:
yield task
self._tasks_to_yield.clear()

def activity_as_tool(
self,
activity_func: Callable,
*,
description: Optional[str] = None,
retry_options: Optional[RetryOptions] = RetryOptions(
first_retry_interval_in_milliseconds=2000, max_number_of_attempts=5
),
) -> Tool:
"""Convert an Azure Durable Functions activity to an OpenAI Agents SDK Tool.

Args
----
activity_func: The Azure Functions activity function to convert
description: Optional description override for the tool
retry_options: The retry options for the activity function

Returns
-------
Expand All @@ -78,7 +71,12 @@ def activity_as_tool(
activity_name = activity_func._function._name

async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
result = self._get_activity_call_result(activity_name, input)
if retry_options:
result = self._task_tracker.get_activity_call_result_with_retry(
activity_name, retry_options, input
)
else:
result = self._task_tracker.get_activity_call_result(activity_name, input)
return result

schema = function_schema(
Expand Down
47 changes: 22 additions & 25 deletions azure/durable_functions/openai_agents/model_invocation_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import timedelta
from typing import Any, AsyncIterator, Optional, Union, cast

import azure.functions as func
from azure.durable_functions.models.RetryOptions import RetryOptions
from pydantic import BaseModel, Field
from agents import (
AgentOutputSchema,
Expand Down Expand Up @@ -34,7 +34,7 @@
from openai.types.responses.tool_param import Mcp
from openai.types.responses.response_prompt_param import ResponsePromptParam

from .context import DurableAIAgentContext
from .task_tracker import TaskTracker

try:
from azure.durable_functions import ApplicationError
Expand Down Expand Up @@ -283,14 +283,18 @@ def make_tool(tool: ToolInput) -> Tool:
) from e


class _DurableModelStub(Model):
class DurableActivityModel(Model):
"""A model implementation that uses durable activities for model invocations."""

def __init__(
self,
model_name: Optional[str],
context: DurableAIAgentContext,
task_tracker: TaskTracker,
retry_options: Optional[RetryOptions],
) -> None:
self.model_name = model_name
self.context = context
self.task_tracker = task_tracker
self.retry_options = retry_options

async def get_response(
self,
Expand All @@ -305,6 +309,7 @@ async def get_response(
previous_response_id: Optional[str],
prompt: Optional[ResponsePromptParam],
) -> ModelResponse:
"""Get a response from the model."""
def make_tool_info(tool: Tool) -> ToolInput:
if isinstance(
tool,
Expand Down Expand Up @@ -375,9 +380,17 @@ def make_tool_info(tool: Tool) -> ToolInput:

activity_input_json = activity_input.to_json()

response = self.context._get_activity_call_result(
"invoke_model_activity", activity_input_json
)
if self.retry_options:
response = self.task_tracker.get_activity_call_result_with_retry(
"invoke_model_activity",
self.retry_options,
activity_input_json,
)
else:
response = self.task_tracker.get_activity_call_result(
"invoke_model_activity", activity_input_json
)

json_response = json.loads(response)
model_response = ModelResponse(**json_response)
return model_response
Expand All @@ -395,21 +408,5 @@ def stream_response(
previous_response_id: Optional[str],
prompt: Optional[ResponsePromptParam],
) -> AsyncIterator[TResponseStreamEvent]:
"""Stream a response from the model."""
raise NotImplementedError("Durable model doesn't support streams yet")


def create_invoke_model_activity(app: func.FunctionApp, model_provider: Optional[ModelProvider]):
"""Create and register the invoke_model_activity function with the provided FunctionApp."""

@app.activity_trigger(input_name="input")
async def invoke_model_activity(input: str):
"""Activity that handles OpenAI model invocations."""
activity_input = ActivityModelInput.from_json(input)

model_invoker = ModelInvoker(model_provider=model_provider)
result = await model_invoker.invoke_model_activity(activity_input)

json_obj = ModelResponse.__pydantic_serializer__.to_json(result)
return json_obj.decode()

return invoke_model_activity
82 changes: 14 additions & 68 deletions azure/durable_functions/openai_agents/orchestrator_generator.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,16 @@
import inspect
import json
from typing import Any
from functools import partial
from typing import Optional
from agents import ModelProvider, ModelResponse
from agents.run import set_default_agent_runner
from azure.durable_functions.models.DurableOrchestrationContext import DurableOrchestrationContext
from azure.durable_functions.openai_agents.model_invocation_activity\
import ActivityModelInput, ModelInvoker
from azure.durable_functions.models.RetryOptions import RetryOptions
from .model_invocation_activity import ActivityModelInput, ModelInvoker
from .task_tracker import TaskTracker
from .runner import DurableOpenAIRunner
from .exceptions import YieldException
from .context import DurableAIAgentContext
from .event_loop import ensure_event_loop


def _durable_serializer(obj: Any) -> str:
# Strings are already "serialized"
if type(obj) is str:
return obj

# Serialize "Durable" and OpenAI models, and typed dictionaries
if callable(getattr(obj, "to_json", None)):
return obj.to_json()

# Serialize Pydantic models
if callable(getattr(obj, "model_dump_json", None)):
return obj.model_dump_json()

# Fallback to default JSON serialization
return json.dumps(obj)


async def durable_openai_agent_activity(input: str, model_provider: ModelProvider):
"""Activity logic that handles OpenAI model invocations."""
activity_input = ActivityModelInput.from_json(input)
Expand All @@ -42,53 +24,17 @@ async def durable_openai_agent_activity(input: str, model_provider: ModelProvide

def durable_openai_agent_orchestrator_generator(
func,
durable_orchestration_context: DurableOrchestrationContext):
durable_orchestration_context: DurableOrchestrationContext,
model_retry_options: Optional[RetryOptions],
):
"""Adapts the synchronous OpenAI Agents function to an Durable orchestrator generator."""
ensure_event_loop()
durable_ai_agent_context = DurableAIAgentContext(durable_orchestration_context)
task_tracker = TaskTracker(durable_orchestration_context)
durable_ai_agent_context = DurableAIAgentContext(
durable_orchestration_context, task_tracker, model_retry_options
)
durable_openai_runner = DurableOpenAIRunner(context=durable_ai_agent_context)
set_default_agent_runner(durable_openai_runner)

if inspect.isgeneratorfunction(func):
gen = iter(func(durable_ai_agent_context))
try:
# prime the subiterator
value = next(gen)
yield from durable_ai_agent_context._yield_and_clear_tasks()
while True:
try:
# send whatever was sent into us down to the subgenerator
yield from durable_ai_agent_context._yield_and_clear_tasks()
sent = yield value
except GeneratorExit:
# ensure the subgenerator is closed
if hasattr(gen, "close"):
gen.close()
raise
except BaseException as exc:
# forward thrown exceptions if possible
if hasattr(gen, "throw"):
value = gen.throw(type(exc), exc, exc.__traceback__)
else:
raise
else:
# normal path: forward .send (or .__next__)
if hasattr(gen, "send"):
value = gen.send(sent)
else:
value = next(gen)
except StopIteration as e:
yield from durable_ai_agent_context._yield_and_clear_tasks()
return _durable_serializer(e.value)
except YieldException as e:
yield from durable_ai_agent_context._yield_and_clear_tasks()
yield e.task
else:
try:
result = func(durable_ai_agent_context)
return _durable_serializer(result)
except YieldException as e:
yield from durable_ai_agent_context._yield_and_clear_tasks()
yield e.task
finally:
yield from durable_ai_agent_context._yield_and_clear_tasks()
func_with_context = partial(func, durable_ai_agent_context)
return task_tracker.execute_orchestrator_function(func_with_context)
7 changes: 4 additions & 3 deletions azure/durable_functions/openai_agents/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pydantic_core import to_json

from .context import DurableAIAgentContext
from .model_invocation_activity import _DurableModelStub
from .model_invocation_activity import DurableActivityModel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,9 +58,10 @@ def run_sync(

updated_run_config = replace(
run_config,
model=_DurableModelStub(
model=DurableActivityModel(
model_name=model_name,
context=self.context,
task_tracker=self.context._task_tracker,
retry_options=self.context._model_retry_options,
),
)

Expand Down
Loading