diff --git a/azure/durable_functions/decorators/durable_app.py b/azure/durable_functions/decorators/durable_app.py index 0ef92d02..26026bab 100644 --- a/azure/durable_functions/decorators/durable_app.py +++ b/azure/durable_functions/decorators/durable_app.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - +from azure.durable_functions.models.RetryOptions import RetryOptions from .metadata import OrchestrationTrigger, ActivityTrigger, EntityTrigger,\ DurableClient from typing import Callable, Optional @@ -270,7 +270,15 @@ def _setup_durable_openai_agent(self, model_provider): self._create_invoke_model_activity(model_provider) self._is_durable_openai_agent_setup = True - def durable_openai_agent_orchestrator(self, _func=None, *, model_provider=None): + def durable_openai_agent_orchestrator( + self, + _func=None, + *, + model_provider=None, + model_retry_options: Optional[RetryOptions] = RetryOptions( + first_retry_interval_in_milliseconds=2000, max_number_of_attempts=5 + ), + ): """Decorate Azure Durable Functions orchestrators that use OpenAI Agents. Parameters @@ -292,7 +300,9 @@ def generator_wrapper_wrapper(func): @wraps(func) def generator_wrapper(context): - return durable_openai_agent_orchestrator_generator(func, context) + return durable_openai_agent_orchestrator_generator( + func, context, model_retry_options + ) return generator_wrapper diff --git a/azure/durable_functions/openai_agents/context.py b/azure/durable_functions/openai_agents/context.py index e0106b0f..973ab01d 100644 --- a/azure/durable_functions/openai_agents/context.py +++ b/azure/durable_functions/openai_agents/context.py @@ -1,46 +1,41 @@ -import json from typing import Any, Callable, Optional from azure.durable_functions.models.DurableOrchestrationContext import ( DurableOrchestrationContext, ) +from azure.durable_functions.models.RetryOptions import RetryOptions from agents import RunContextWrapper, Tool from agents.function_schema import function_schema from agents.tool import FunctionTool -from .exceptions import YieldException +from .task_tracker import TaskTracker class DurableAIAgentContext: """Context for AI agents running in Azure Durable Functions orchestration.""" - def __init__(self, context: DurableOrchestrationContext): + def __init__( + self, + context: DurableOrchestrationContext, + task_tracker: TaskTracker, + model_retry_options: Optional[RetryOptions], + ): self._context = context - self._activities_called = 0 - self._tasks_to_yield = [] - - def _get_activity_call_result(self, activity_name, input: str): - task = self._context.call_activity(activity_name, input) - - self._activities_called += 1 - - histories = self._context.histories - completed_tasks = [entry for entry in histories if entry.event_type == 5] - if len(completed_tasks) < self._activities_called: - # yield immediately - raise YieldException(task) - else: - # yield later - self._tasks_to_yield.append(task) - - result_json = completed_tasks[self._activities_called - 1].Result - result = json.loads(result_json) - return result + self._task_tracker = task_tracker + self._model_retry_options = model_retry_options def call_activity(self, activity_name, input: str): - """Call an activity function and increment the activity counter.""" + """Call an activity function and record the activity call.""" task = self._context.call_activity(activity_name, input) - self._activities_called += 1 + self._task_tracker.record_activity_call() + return task + + def call_activity_with_retry( + self, activity_name, retry_options: RetryOptions, input: str = None + ): + """Call an activity function with retry options and record the activity call.""" + task = self._context.call_activity_with_retry(activity_name, retry_options, input) + self._task_tracker.record_activity_call() return task def set_custom_status(self, status: str): @@ -51,17 +46,14 @@ def wait_for_external_event(self, event_name: str): """Wait for an external event in the orchestration.""" return self._context.wait_for_external_event(event_name) - def _yield_and_clear_tasks(self): - """Yield all accumulated tasks and clear the tasks list.""" - for task in self._tasks_to_yield: - yield task - self._tasks_to_yield.clear() - def activity_as_tool( self, activity_func: Callable, *, description: Optional[str] = None, + retry_options: Optional[RetryOptions] = RetryOptions( + first_retry_interval_in_milliseconds=2000, max_number_of_attempts=5 + ), ) -> Tool: """Convert an Azure Durable Functions activity to an OpenAI Agents SDK Tool. @@ -69,6 +61,7 @@ def activity_as_tool( ---- activity_func: The Azure Functions activity function to convert description: Optional description override for the tool + retry_options: The retry options for the activity function Returns ------- @@ -78,7 +71,12 @@ def activity_as_tool( activity_name = activity_func._function._name async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any: - result = self._get_activity_call_result(activity_name, input) + if retry_options: + result = self._task_tracker.get_activity_call_result_with_retry( + activity_name, retry_options, input + ) + else: + result = self._task_tracker.get_activity_call_result(activity_name, input) return result schema = function_schema( diff --git a/azure/durable_functions/openai_agents/model_invocation_activity.py b/azure/durable_functions/openai_agents/model_invocation_activity.py index 9f4f1287..31b48730 100644 --- a/azure/durable_functions/openai_agents/model_invocation_activity.py +++ b/azure/durable_functions/openai_agents/model_invocation_activity.py @@ -4,7 +4,7 @@ from datetime import timedelta from typing import Any, AsyncIterator, Optional, Union, cast -import azure.functions as func +from azure.durable_functions.models.RetryOptions import RetryOptions from pydantic import BaseModel, Field from agents import ( AgentOutputSchema, @@ -34,7 +34,7 @@ from openai.types.responses.tool_param import Mcp from openai.types.responses.response_prompt_param import ResponsePromptParam -from .context import DurableAIAgentContext +from .task_tracker import TaskTracker try: from azure.durable_functions import ApplicationError @@ -283,14 +283,18 @@ def make_tool(tool: ToolInput) -> Tool: ) from e -class _DurableModelStub(Model): +class DurableActivityModel(Model): + """A model implementation that uses durable activities for model invocations.""" + def __init__( self, model_name: Optional[str], - context: DurableAIAgentContext, + task_tracker: TaskTracker, + retry_options: Optional[RetryOptions], ) -> None: self.model_name = model_name - self.context = context + self.task_tracker = task_tracker + self.retry_options = retry_options async def get_response( self, @@ -305,6 +309,7 @@ async def get_response( previous_response_id: Optional[str], prompt: Optional[ResponsePromptParam], ) -> ModelResponse: + """Get a response from the model.""" def make_tool_info(tool: Tool) -> ToolInput: if isinstance( tool, @@ -375,9 +380,17 @@ def make_tool_info(tool: Tool) -> ToolInput: activity_input_json = activity_input.to_json() - response = self.context._get_activity_call_result( - "invoke_model_activity", activity_input_json - ) + if self.retry_options: + response = self.task_tracker.get_activity_call_result_with_retry( + "invoke_model_activity", + self.retry_options, + activity_input_json, + ) + else: + response = self.task_tracker.get_activity_call_result( + "invoke_model_activity", activity_input_json + ) + json_response = json.loads(response) model_response = ModelResponse(**json_response) return model_response @@ -395,21 +408,5 @@ def stream_response( previous_response_id: Optional[str], prompt: Optional[ResponsePromptParam], ) -> AsyncIterator[TResponseStreamEvent]: + """Stream a response from the model.""" raise NotImplementedError("Durable model doesn't support streams yet") - - -def create_invoke_model_activity(app: func.FunctionApp, model_provider: Optional[ModelProvider]): - """Create and register the invoke_model_activity function with the provided FunctionApp.""" - - @app.activity_trigger(input_name="input") - async def invoke_model_activity(input: str): - """Activity that handles OpenAI model invocations.""" - activity_input = ActivityModelInput.from_json(input) - - model_invoker = ModelInvoker(model_provider=model_provider) - result = await model_invoker.invoke_model_activity(activity_input) - - json_obj = ModelResponse.__pydantic_serializer__.to_json(result) - return json_obj.decode() - - return invoke_model_activity diff --git a/azure/durable_functions/openai_agents/orchestrator_generator.py b/azure/durable_functions/openai_agents/orchestrator_generator.py index 770d36dc..dc74a411 100644 --- a/azure/durable_functions/openai_agents/orchestrator_generator.py +++ b/azure/durable_functions/openai_agents/orchestrator_generator.py @@ -1,34 +1,16 @@ -import inspect -import json -from typing import Any +from functools import partial +from typing import Optional from agents import ModelProvider, ModelResponse from agents.run import set_default_agent_runner from azure.durable_functions.models.DurableOrchestrationContext import DurableOrchestrationContext -from azure.durable_functions.openai_agents.model_invocation_activity\ - import ActivityModelInput, ModelInvoker +from azure.durable_functions.models.RetryOptions import RetryOptions +from .model_invocation_activity import ActivityModelInput, ModelInvoker +from .task_tracker import TaskTracker from .runner import DurableOpenAIRunner -from .exceptions import YieldException from .context import DurableAIAgentContext from .event_loop import ensure_event_loop -def _durable_serializer(obj: Any) -> str: - # Strings are already "serialized" - if type(obj) is str: - return obj - - # Serialize "Durable" and OpenAI models, and typed dictionaries - if callable(getattr(obj, "to_json", None)): - return obj.to_json() - - # Serialize Pydantic models - if callable(getattr(obj, "model_dump_json", None)): - return obj.model_dump_json() - - # Fallback to default JSON serialization - return json.dumps(obj) - - async def durable_openai_agent_activity(input: str, model_provider: ModelProvider): """Activity logic that handles OpenAI model invocations.""" activity_input = ActivityModelInput.from_json(input) @@ -42,53 +24,17 @@ async def durable_openai_agent_activity(input: str, model_provider: ModelProvide def durable_openai_agent_orchestrator_generator( func, - durable_orchestration_context: DurableOrchestrationContext): + durable_orchestration_context: DurableOrchestrationContext, + model_retry_options: Optional[RetryOptions], +): """Adapts the synchronous OpenAI Agents function to an Durable orchestrator generator.""" ensure_event_loop() - durable_ai_agent_context = DurableAIAgentContext(durable_orchestration_context) + task_tracker = TaskTracker(durable_orchestration_context) + durable_ai_agent_context = DurableAIAgentContext( + durable_orchestration_context, task_tracker, model_retry_options + ) durable_openai_runner = DurableOpenAIRunner(context=durable_ai_agent_context) set_default_agent_runner(durable_openai_runner) - if inspect.isgeneratorfunction(func): - gen = iter(func(durable_ai_agent_context)) - try: - # prime the subiterator - value = next(gen) - yield from durable_ai_agent_context._yield_and_clear_tasks() - while True: - try: - # send whatever was sent into us down to the subgenerator - yield from durable_ai_agent_context._yield_and_clear_tasks() - sent = yield value - except GeneratorExit: - # ensure the subgenerator is closed - if hasattr(gen, "close"): - gen.close() - raise - except BaseException as exc: - # forward thrown exceptions if possible - if hasattr(gen, "throw"): - value = gen.throw(type(exc), exc, exc.__traceback__) - else: - raise - else: - # normal path: forward .send (or .__next__) - if hasattr(gen, "send"): - value = gen.send(sent) - else: - value = next(gen) - except StopIteration as e: - yield from durable_ai_agent_context._yield_and_clear_tasks() - return _durable_serializer(e.value) - except YieldException as e: - yield from durable_ai_agent_context._yield_and_clear_tasks() - yield e.task - else: - try: - result = func(durable_ai_agent_context) - return _durable_serializer(result) - except YieldException as e: - yield from durable_ai_agent_context._yield_and_clear_tasks() - yield e.task - finally: - yield from durable_ai_agent_context._yield_and_clear_tasks() + func_with_context = partial(func, durable_ai_agent_context) + return task_tracker.execute_orchestrator_function(func_with_context) diff --git a/azure/durable_functions/openai_agents/runner.py b/azure/durable_functions/openai_agents/runner.py index 57f609df..ca4a131b 100644 --- a/azure/durable_functions/openai_agents/runner.py +++ b/azure/durable_functions/openai_agents/runner.py @@ -15,7 +15,7 @@ from pydantic_core import to_json from .context import DurableAIAgentContext -from .model_invocation_activity import _DurableModelStub +from .model_invocation_activity import DurableActivityModel logger = logging.getLogger(__name__) @@ -58,9 +58,10 @@ def run_sync( updated_run_config = replace( run_config, - model=_DurableModelStub( + model=DurableActivityModel( model_name=model_name, - context=self.context, + task_tracker=self.context._task_tracker, + retry_options=self.context._model_retry_options, ), ) diff --git a/azure/durable_functions/openai_agents/task_tracker.py b/azure/durable_functions/openai_agents/task_tracker.py new file mode 100644 index 00000000..1f346de7 --- /dev/null +++ b/azure/durable_functions/openai_agents/task_tracker.py @@ -0,0 +1,169 @@ +import json +import inspect +from typing import Any + +from azure.durable_functions.models.DurableOrchestrationContext import ( + DurableOrchestrationContext, +) +from azure.durable_functions.models.history.HistoryEventType import HistoryEventType +from azure.durable_functions.models.RetryOptions import RetryOptions + +from .exceptions import YieldException + + +class TaskTracker: + """Tracks activity calls and handles task result processing for durable AI agents.""" + + def __init__(self, context: DurableOrchestrationContext): + self._context = context + self._activities_called = 0 + self._tasks_to_yield = [] + + def _get_activity_result_or_raise(self, task): + """Return the activity result if available; otherwise raise ``YieldException`` to defer. + + The first time an activity is scheduled its result won't yet exist in the + orchestration history, so we raise ``YieldException`` with the task so the + orchestrator can yield it. On replay, once the corresponding TASK_COMPLETED + history event is present, we capture the result and queue the task for a + later yield (to preserve ordering) while returning the deserialized value. + """ + self.record_activity_call() + + histories = self._context.histories + completed_tasks = [ + entry for entry in histories + if entry.event_type == HistoryEventType.TASK_COMPLETED + ] + if len(completed_tasks) < self._activities_called: + # Result not yet available in history -> raise to signal a yield now + raise YieldException(task) + # Result exists (replay). Queue task to be yielded after returning value. + # + # We cannot just yield it now because this method can be called from + # deeply nested code paths that we don't control (such as the + # OpenAI Agents SDK internals), and yielding here would lead to + # unintended behavior. Instead, we queue the task to be yielded + # later and return the result recorded in the history, so the + # code invoking this method can continue executing normally. + self._tasks_to_yield.append(task) + + result_json = completed_tasks[self._activities_called - 1].Result + result = json.loads(result_json) + return result + + def get_activity_call_result(self, activity_name, input: str): + """Call an activity and return its result or raise ``YieldException`` if pending.""" + task = self._context.call_activity(activity_name, input) + return self._get_activity_result_or_raise(task) + + def get_activity_call_result_with_retry( + self, activity_name, retry_options: RetryOptions, input: str + ): + """Call an activity with retry and return its result or raise YieldException if pending.""" + task = self._context.call_activity_with_retry(activity_name, retry_options, input) + return self._get_activity_result_or_raise(task) + + def record_activity_call(self): + """Record that an activity was called.""" + self._activities_called += 1 + + def _yield_and_clear_tasks(self): + """Yield all accumulated tasks and clear the tasks list.""" + for task in self._tasks_to_yield: + yield task + self._tasks_to_yield.clear() + + def execute_orchestrator_function(self, func): + """Execute the orchestrator function with comprehensive task and exception handling. + + The orchestrator function can exhibit any combination of the following behaviors: + - Execute regular code and return a value or raise an exception + - Invoke get_activity_call_result or get_activity_call_result_with_retry, which leads to + either interrupting the orchestrator function immediately (because of YieldException), + or queueing the task for later yielding while continuing execution + - Invoke DurableAIAgentContext.call_activity or call_activity_with_retry (which must lead + to corresponding record_activity_call invocations) + - Yield tasks (typically produced by DurableAIAgentContext methods like call_activity, + wait_for_external_event, etc.), which may or may not interrupt orchestrator function + execution + - Mix all of the above in any combination + + This method converts both YieldException and regular yields into a sequence of yields + preserving the order, while also capturing return values through the generator protocol. + For example, if the orchestrator function yields task A, then queues task B for yielding, + then raises YieldException wrapping task C, this method makes sure that the resulting + sequence of yields is: (A, B, C). + + Args + ---- + func: The orchestrator function to execute (generator or regular function) + + Yields + ------ + Tasks yielded by the orchestrator function and tasks wrapped in YieldException + + Returns + ------- + The return value from the orchestrator function + """ + if inspect.isgeneratorfunction(func): + gen = iter(func()) + try: + # prime the subiterator + value = next(gen) + yield from self._yield_and_clear_tasks() + while True: + try: + # send whatever was sent into us down to the subgenerator + yield from self._yield_and_clear_tasks() + sent = yield value + except GeneratorExit: + # ensure the subgenerator is closed + if hasattr(gen, "close"): + gen.close() + raise + except BaseException as exc: + # forward thrown exceptions if possible + if hasattr(gen, "throw"): + value = gen.throw(type(exc), exc, exc.__traceback__) + else: + raise + else: + # normal path: forward .send (or .__next__) + if hasattr(gen, "send"): + value = gen.send(sent) + else: + value = next(gen) + except StopIteration as e: + yield from self._yield_and_clear_tasks() + return TaskTracker._durable_serializer(e.value) + except YieldException as e: + yield from self._yield_and_clear_tasks() + yield e.task + else: + try: + result = func() + return TaskTracker._durable_serializer(result) + except YieldException as e: + yield from self._yield_and_clear_tasks() + yield e.task + finally: + yield from self._yield_and_clear_tasks() + + @staticmethod + def _durable_serializer(obj: Any) -> str: + # Strings are already "serialized" + if type(obj) is str: + return obj + + # Serialize "Durable" and OpenAI models, and typed dictionaries + if callable(getattr(obj, "to_json", None)): + return obj.to_json() + + # Serialize Pydantic models + if callable(getattr(obj, "model_dump_json", None)): + return obj.model_dump_json() + + # Fallback to default JSON serialization + return json.dumps(obj) diff --git a/tests/openai_agents/__init__.py b/tests/openai_agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/openai_agents/test_task_tracker.py b/tests/openai_agents/test_task_tracker.py new file mode 100644 index 00000000..46d8e7f7 --- /dev/null +++ b/tests/openai_agents/test_task_tracker.py @@ -0,0 +1,288 @@ +import pytest +import json +from unittest.mock import Mock + +from azure.durable_functions.openai_agents.task_tracker import TaskTracker +from azure.durable_functions.openai_agents.exceptions import YieldException +from azure.durable_functions.models.DurableOrchestrationContext import DurableOrchestrationContext +from azure.durable_functions.models.history.HistoryEvent import HistoryEvent +from azure.durable_functions.models.history.HistoryEventType import HistoryEventType +from azure.durable_functions.models.RetryOptions import RetryOptions + + +class MockTask: + """Mock Task object for testing.""" + + def __init__(self, activity_name: str, input_data: str): + self.activity_name = activity_name + self.input = input_data + self.id = f"task_{activity_name}" + + +def create_mock_context(task_completed_results=None): + """Create a mock DurableOrchestrationContext with configurable history. + + Args: + ---- + task_completed_results: List of objects to be serialized as JSON results. + Each object will be json.dumps() serialized automatically. + """ + context = Mock(spec=DurableOrchestrationContext) + + # Create history events for completed tasks + histories = [] + if task_completed_results: + for i, result_object in enumerate(task_completed_results): + history_event = Mock(spec=HistoryEvent) + history_event.event_type = HistoryEventType.TASK_COMPLETED + history_event.Result = json.dumps(result_object) + histories.append(history_event) + + context.histories = histories + + # Mock call_activity method + def mock_call_activity(activity_name, input_data): + return MockTask(activity_name, input_data) + + context.call_activity = Mock(side_effect=mock_call_activity) + + # Mock call_activity_with_retry method + def mock_call_activity_with_retry(activity_name, retry_options, input_data): + return MockTask(activity_name, input_data) + + context.call_activity_with_retry = Mock(side_effect=mock_call_activity_with_retry) + + return context + + +class TestTaskTracker: + """Tests for the TaskTracker implementation.""" + + def _consume_generator_with_return_value(self, generator): + """Consume a generator and capture both yielded items and return value. + + Returns + ------- + tuple + (yielded_items, return_value) where return_value is None if no return value + """ + yielded_items = [] + return_value = None + try: + while True: + yielded_items.append(next(generator)) + except StopIteration as e: + return_value = e.value + return yielded_items, return_value + + def test_get_activity_call_result_returns_result_when_history_available(self): + """Test get_activity_call_result returns result when history is available.""" + context = create_mock_context(task_completed_results=["test_result"]) + tracker = TaskTracker(context) + + result = tracker.get_activity_call_result("test_activity", "test_input") + assert result == "test_result" + + def test_get_activity_call_result_raises_yield_exception_when_no_history(self): + """Test get_activity_call_result raises YieldException when no history.""" + context = create_mock_context(task_completed_results=[]) + tracker = TaskTracker(context) + + with pytest.raises(YieldException) as exc_info: + tracker.get_activity_call_result("test_activity", "test_input") + + task = exc_info.value.task + assert task.activity_name == "test_activity" + assert task.input == "test_input" + + def test_get_activity_call_result_with_retry_returns_result_when_history_available(self): + """Test get_activity_call_result_with_retry returns result when history is available.""" + context = create_mock_context(task_completed_results=["result"]) + tracker = TaskTracker(context) + retry_options = RetryOptions(1000, 3) + + result = tracker.get_activity_call_result_with_retry("activity", retry_options, "input") + assert result == "result" + + def test_get_activity_call_result_with_retry_raises_yield_exception_when_no_history(self): + """Test get_activity_call_result_with_retry raises YieldException when no history.""" + context = create_mock_context(task_completed_results=[]) + tracker = TaskTracker(context) + retry_options = RetryOptions(1000, 3) + + with pytest.raises(YieldException) as exc_info: + tracker.get_activity_call_result_with_retry("activity", retry_options, "input") + + task = exc_info.value.task + assert task.activity_name == "activity" + assert task.input == "input" + + def test_multiple_activity_calls_with_partial_history(self): + """Test sequential activity calls with partial history available.""" + context = create_mock_context(task_completed_results=["result1", "result2"]) + tracker = TaskTracker(context) + + # First call returns result1 + result1 = tracker.get_activity_call_result("activity1", "input1") + assert result1 == "result1" + + # Second call returns result2 + result2 = tracker.get_activity_call_result("activity2", "input2") + assert result2 == "result2" + + # Third call raises YieldException (no more history) + with pytest.raises(YieldException): + tracker.get_activity_call_result("activity3", "input3") + + def test_execute_orchestrator_function_return_value(self): + """Test execute_orchestrator_function with orchestrator function that returns a value.""" + context = create_mock_context() + tracker = TaskTracker(context) + + expected_result = "orchestrator_result" + + def test_orchestrator(): + return expected_result + + result_gen = tracker.execute_orchestrator_function(test_orchestrator) + yielded_items, return_value = self._consume_generator_with_return_value(result_gen) + + # Should yield nothing and return the value + assert yielded_items == [] + assert return_value == expected_result + + def test_execute_orchestrator_function_get_activity_call_result_incomplete(self): + """Test execute_orchestrator_function with orchestrator function that tries to get an activity result before this activity call completes (not a replay).""" + context = create_mock_context() # No history available + tracker = TaskTracker(context) + + def test_orchestrator(): + return tracker.get_activity_call_result("activity", "test_input") + + result_gen = tracker.execute_orchestrator_function(test_orchestrator) + yielded_items, return_value = self._consume_generator_with_return_value(result_gen) + + # Should yield a task with this activity name + assert yielded_items[0].activity_name == "activity" + assert len(yielded_items) == 1 + assert return_value is None + + def test_execute_orchestrator_function_get_complete_activity_result(self): + """Test execute_orchestrator_function with orchestrator function that gets a complete activity call result (replay).""" + context = create_mock_context(task_completed_results=["activity_result"]) + tracker = TaskTracker(context) + + def test_orchestrator(): + return tracker.get_activity_call_result("activity", "test_input") + + result_gen = tracker.execute_orchestrator_function(test_orchestrator) + yielded_items, return_value = self._consume_generator_with_return_value(result_gen) + + # Should yield the queued task and return the result + assert yielded_items[0].activity_name == "activity" + assert len(yielded_items) == 1 + assert return_value == "activity_result" + + def test_execute_orchestrator_function_yields_tasks(self): + """Test execute_orchestrator_function with orchestrator function that yields tasks.""" + context = create_mock_context() + tracker = TaskTracker(context) + + def test_orchestrator(): + yield "task_1" + yield "task_2" + return "final_result" + + result_gen = tracker.execute_orchestrator_function(test_orchestrator) + yielded_items, return_value = self._consume_generator_with_return_value(result_gen) + + # Should yield the tasks in order and return the final result + assert yielded_items[0] == "task_1" + assert yielded_items[1] == "task_2" + assert len(yielded_items) == 2 + assert return_value == "final_result" + + def test_execute_orchestrator_function_context_activity_call_incomplete(self): + """Test execute_orchestrator_function with orchestrator function that tries to get an activity result before this activity call completes (not a replay) after a DurableAIAgentContext.call_activity invocation.""" + context = create_mock_context(task_completed_results=["result1"]) + tracker = TaskTracker(context) + + def test_orchestrator(): + # Simulate invoking DurableAIAgentContext.call_activity and yielding the resulting task + tracker.record_activity_call() + yield "task" # Produced "result1" + + return tracker.get_activity_call_result("activity", "input") # Incomplete, should raise YieldException that will be translated to yield + + result_gen = tracker.execute_orchestrator_function(test_orchestrator) + yielded_items, return_value = self._consume_generator_with_return_value(result_gen) + + # Should yield the incomplete task + assert yielded_items[0] == "task" + assert yielded_items[1].activity_name == "activity" + assert len(yielded_items) == 2 + assert return_value == None + + def test_execute_orchestrator_function_context_activity_call_complete(self): + """Test execute_orchestrator_function with orchestrator function that gets a complete activity call result (replay) after a DurableAIAgentContext.call_activity invocation.""" + context = create_mock_context(task_completed_results=["result1", "result2"]) + tracker = TaskTracker(context) + + def test_orchestrator(): + # Simulate invoking DurableAIAgentContext.call_activity and yielding the resulting task + tracker.record_activity_call() + yield "task" # Produced "result1" + + return tracker.get_activity_call_result("activity", "input") # Complete, should return "result2" + + result_gen = tracker.execute_orchestrator_function(test_orchestrator) + yielded_items, return_value = self._consume_generator_with_return_value(result_gen) + + # Should yield the queued task and return the result + assert yielded_items[0] == "task" + assert yielded_items[1].activity_name == "activity" + assert len(yielded_items) == 2 + assert return_value == "result2" + + def test_execute_orchestrator_function_mixed_behaviors_combination(self): + """Test execute_orchestrator_function mixing all documented behaviors.""" + context = create_mock_context(task_completed_results=[ + "result1", + "result2", + "result3", + "result4" + ]) + tracker = TaskTracker(context) + + def test_orchestrator(): + activity1_result = tracker.get_activity_call_result("activity1", "input1") + + # Simulate invoking DurableAIAgentContext.call_activity("activity2") and yielding the resulting task + tracker.record_activity_call() + yield "yielded task from activity2" # Produced "result2" + + # Yield a regular task, possibly returned from DurableAIAgentContext methods like wait_for_external_event, etc. + yield "another yielded task" + + activity3_result = tracker.get_activity_call_result("activity3", "input3") + + # Simulate invoking DurableAIAgentContext.call_activity("activity4") and yielding the resulting task + tracker.record_activity_call() + yield "yielded task from activity4" # Produced "result4" + + return f"activity1={activity1_result};activity3={activity3_result}" + + result_gen = tracker.execute_orchestrator_function(test_orchestrator) + yielded_items, return_value = self._consume_generator_with_return_value(result_gen) + + # Verify yield order + assert yielded_items[0].activity_name == "activity1" + assert yielded_items[1] == "yielded task from activity2" + assert yielded_items[2] == "another yielded task" + assert yielded_items[3].activity_name == "activity3" + assert yielded_items[4] == "yielded task from activity4" + assert len(yielded_items) == 5 + + # Verify return value + expected_return = "activity1=result1;activity3=result3" + assert return_value == expected_return diff --git a/tests/orchestrator/openai_agents/test_openai_agents.py b/tests/orchestrator/openai_agents/test_openai_agents.py index 4cd6f522..d930fdcd 100644 --- a/tests/orchestrator/openai_agents/test_openai_agents.py +++ b/tests/orchestrator/openai_agents/test_openai_agents.py @@ -1,9 +1,8 @@ -from typing import Optional, TypedDict - import azure.durable_functions as df import azure.functions as func import json import pydantic +from typing import TypedDict from agents import Agent, Runner from azure.durable_functions.models import OrchestratorState from azure.durable_functions.models.actions import CallActivityAction @@ -17,7 +16,7 @@ @app.function_name("openai_agent_hello_world") @app.orchestration_trigger(context_name="context") -@app.durable_openai_agent_orchestrator +@app.durable_openai_agent_orchestrator(model_retry_options=None) def openai_agent_hello_world(context): agent = Agent( name="Assistant", @@ -47,12 +46,12 @@ def get_weather(city: str) -> Weather: @app.function_name("openai_agent_use_tool") @app.orchestration_trigger(context_name="context") -@app.durable_openai_agent_orchestrator +@app.durable_openai_agent_orchestrator(model_retry_options=None) def openai_agent_use_tool(context): agent = Agent( name="Assistant", instructions="You only respond in haikus.", - tools=[context.activity_as_tool(get_weather)] + tools=[context.activity_as_tool(get_weather, retry_options=None)] ) result = Runner.run_sync(agent, "Tell me the weather in Seattle.", ) @@ -61,7 +60,7 @@ def openai_agent_use_tool(context): @app.function_name("openai_agent_return_string_type") @app.orchestration_trigger(context_name="context") -@app.durable_openai_agent_orchestrator +@app.durable_openai_agent_orchestrator(model_retry_options=None) def openai_agent_return_string_type(context): return "Hello World" @@ -74,7 +73,7 @@ def to_json(self) -> str: @app.function_name("openai_agent_return_durable_model_type") @app.orchestration_trigger(context_name="context") -@app.durable_openai_agent_orchestrator +@app.durable_openai_agent_orchestrator(model_retry_options=None) def openai_agent_return_durable_model_type(context): model = DurableModel(property="value") @@ -85,7 +84,7 @@ class TypedDictionaryModel(TypedDict): @app.function_name("openai_agent_return_typed_dictionary_model_type") @app.orchestration_trigger(context_name="context") -@app.durable_openai_agent_orchestrator +@app.durable_openai_agent_orchestrator(model_retry_options=None) def openai_agent_return_typed_dictionary_model_type(context): model = TypedDictionaryModel(property="value") @@ -96,7 +95,7 @@ class OpenAIPydanticModel(BaseModel): @app.function_name("openai_agent_return_openai_pydantic_model_type") @app.orchestration_trigger(context_name="context") -@app.durable_openai_agent_orchestrator +@app.durable_openai_agent_orchestrator(model_retry_options=None) def openai_agent_return_openai_pydantic_model_type(context): model = OpenAIPydanticModel(property="value") @@ -107,7 +106,7 @@ class PydanticModel(pydantic.BaseModel): @app.function_name("openai_agent_return_pydantic_model_type") @app.orchestration_trigger(context_name="context") -@app.durable_openai_agent_orchestrator +@app.durable_openai_agent_orchestrator(model_retry_options=None) def openai_agent_return_pydantic_model_type(context): model = PydanticModel(property="value")