diff --git a/azure/durable_functions/openai_agents/context.py b/azure/durable_functions/openai_agents/context.py index a9de2f43..585e4f9c 100644 --- a/azure/durable_functions/openai_agents/context.py +++ b/azure/durable_functions/openai_agents/context.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, TYPE_CHECKING from azure.durable_functions.models.DurableOrchestrationContext import ( DurableOrchestrationContext, @@ -11,8 +11,39 @@ from .task_tracker import TaskTracker -class DurableAIAgentContext: - """Context for AI agents running in Azure Durable Functions orchestration.""" +if TYPE_CHECKING: + # At type-check time we want all members / signatures for IDE & linters. + _BaseDurableContext = DurableOrchestrationContext +else: + class _BaseDurableContext: # lightweight runtime stub + """Runtime stub base class for delegation; real context is wrapped. + + At runtime we avoid inheriting from DurableOrchestrationContext so that + attribute lookups for its members are delegated via __getattr__ to the + wrapped ``_context`` instance. + """ + + __slots__ = () + + +class DurableAIAgentContext(_BaseDurableContext): + """Context for AI agents running in Azure Durable Functions orchestration. + + Design + ------ + * Static analysis / IDEs: Appears to subclass ``DurableOrchestrationContext`` so + you get autocompletion and type hints (under TYPE_CHECKING branch). + * Runtime: Inherits from a trivial stub. All durable orchestration operations + are delegated to the real ``DurableOrchestrationContext`` instance provided + as ``context`` and stored in ``_context``. + + Consequences + ------------ + * ``isinstance(DurableAIAgentContext, DurableOrchestrationContext)`` is **False** at + runtime (expected). + * Delegation via ``__getattr__`` works for every member of the real context. + * No reliance on internal initialization side-effects of the durable SDK. + """ def __init__( self, @@ -38,14 +69,6 @@ def call_activity_with_retry( self._task_tracker.record_activity_call() return task - def set_custom_status(self, status: str): - """Set custom status for the orchestration.""" - self._context.set_custom_status(status) - - def wait_for_external_event(self, event_name: str): - """Wait for an external event in the orchestration.""" - return self._context.wait_for_external_event(event_name) - def create_activity_tool( self, activity_func: Callable, @@ -101,3 +124,14 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any: on_invoke_tool=run_activity, strict_json_schema=True, ) + + def __getattr__(self, name): + """Delegate missing attributes to the underlying DurableOrchestrationContext.""" + try: + return getattr(self._context, name) + except AttributeError: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __dir__(self): + """Improve introspection and tab-completion by including delegated attributes.""" + return sorted(set(dir(type(self)) + list(self.__dict__) + dir(self._context))) diff --git a/tests/openai_agents/test_context.py b/tests/openai_agents/test_context.py new file mode 100644 index 00000000..cb887005 --- /dev/null +++ b/tests/openai_agents/test_context.py @@ -0,0 +1,335 @@ +import pytest +from unittest.mock import Mock, patch + +from azure.durable_functions.openai_agents.context import DurableAIAgentContext +from azure.durable_functions.openai_agents.task_tracker import TaskTracker +from azure.durable_functions.models.DurableOrchestrationContext import DurableOrchestrationContext +from azure.durable_functions.models.RetryOptions import RetryOptions + +from agents.tool import FunctionTool + + +class TestDurableAIAgentContext: + """Test suite for DurableAIAgentContext class.""" + + def _create_mock_orchestration_context(self): + """Create a mock DurableOrchestrationContext for testing.""" + orchestration_context = Mock(spec=DurableOrchestrationContext) + orchestration_context.call_activity = Mock(return_value="mock_task") + orchestration_context.call_activity_with_retry = Mock(return_value="mock_task_with_retry") + orchestration_context.instance_id = "test_instance_id" + orchestration_context.current_utc_datetime = "2023-01-01T00:00:00Z" + orchestration_context.is_replaying = False + return orchestration_context + + def _create_mock_task_tracker(self): + """Create a mock TaskTracker for testing.""" + task_tracker = Mock(spec=TaskTracker) + task_tracker.record_activity_call = Mock() + task_tracker.get_activity_call_result = Mock(return_value="activity_result") + task_tracker.get_activity_call_result_with_retry = Mock(return_value="retry_activity_result") + return task_tracker + + def test_init_creates_context_successfully(self): + """Test that __init__ creates a DurableAIAgentContext successfully.""" + orchestration_context = self._create_mock_orchestration_context() + task_tracker = self._create_mock_task_tracker() + retry_options = RetryOptions(1000, 3) + + ai_context = DurableAIAgentContext(orchestration_context, task_tracker, retry_options) + + assert isinstance(ai_context, DurableAIAgentContext) + assert not isinstance(ai_context, DurableOrchestrationContext) + + def test_call_activity_delegates_and_records(self): + """Test that call_activity delegates to context and records activity call.""" + orchestration_context = self._create_mock_orchestration_context() + task_tracker = self._create_mock_task_tracker() + + ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None) + result = ai_context.call_activity("test_activity", "test_input") + + orchestration_context.call_activity.assert_called_once_with("test_activity", "test_input") + task_tracker.record_activity_call.assert_called_once() + assert result == "mock_task" + + def test_call_activity_with_retry_delegates_and_records(self): + """Test that call_activity_with_retry delegates to context and records activity call.""" + orchestration_context = self._create_mock_orchestration_context() + task_tracker = self._create_mock_task_tracker() + retry_options = RetryOptions(1000, 3) + + ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None) + result = ai_context.call_activity_with_retry("test_activity", retry_options, "test_input") + + orchestration_context.call_activity_with_retry.assert_called_once_with( + "test_activity", retry_options, "test_input" + ) + task_tracker.record_activity_call.assert_called_once() + assert result == "mock_task_with_retry" + + @patch('azure.durable_functions.openai_agents.context.function_schema') + @patch('azure.durable_functions.openai_agents.context.FunctionTool') + def test_activity_as_tool_creates_function_tool(self, mock_function_tool, mock_function_schema): + """Test that create_activity_tool creates a FunctionTool with correct parameters.""" + orchestration_context = self._create_mock_orchestration_context() + task_tracker = self._create_mock_task_tracker() + + # Mock the activity function + mock_activity_func = Mock() + mock_activity_func._function._name = "test_activity" + mock_activity_func._function._func = lambda x: x + + # Mock the schema + mock_schema = Mock() + mock_schema.name = "test_activity" + mock_schema.description = "Test activity description" + mock_schema.params_json_schema = {"type": "object"} + mock_function_schema.return_value = mock_schema + + # Mock FunctionTool + mock_tool = Mock(spec=FunctionTool) + mock_function_tool.return_value = mock_tool + + ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None) + retry_options = RetryOptions(1000, 3) + + result = ai_context.create_activity_tool( + mock_activity_func, + description="Custom description", + retry_options=retry_options + ) + + # Verify function_schema was called correctly + mock_function_schema.assert_called_once_with( + func=mock_activity_func._function._func, + docstring_style=None, + description_override="Custom description", + use_docstring_info=True, + strict_json_schema=True, + ) + + # Verify FunctionTool was created correctly + mock_function_tool.assert_called_once() + call_args = mock_function_tool.call_args + assert call_args[1]['name'] == "test_activity" + assert call_args[1]['description'] == "Test activity description" + assert call_args[1]['params_json_schema'] == {"type": "object"} + assert call_args[1]['strict_json_schema'] is True + assert callable(call_args[1]['on_invoke_tool']) + + assert result is mock_tool + + @patch('azure.durable_functions.openai_agents.context.function_schema') + @patch('azure.durable_functions.openai_agents.context.FunctionTool') + def test_activity_as_tool_with_default_retry_options(self, mock_function_tool, mock_function_schema): + """Test that create_activity_tool uses default retry options when none provided.""" + orchestration_context = self._create_mock_orchestration_context() + task_tracker = self._create_mock_task_tracker() + + mock_activity_func = Mock() + mock_activity_func._function._name = "test_activity" + mock_activity_func._function._func = lambda x: x + + mock_schema = Mock() + mock_schema.name = "test_activity" + mock_schema.description = "Test description" + mock_schema.params_json_schema = {"type": "object"} + mock_function_schema.return_value = mock_schema + + mock_tool = Mock(spec=FunctionTool) + mock_function_tool.return_value = mock_tool + + ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None) + + # Call with default retry options + result = ai_context.create_activity_tool(mock_activity_func) + + # Should still create the tool successfully + assert result is mock_tool + mock_function_tool.assert_called_once() + + @patch('azure.durable_functions.openai_agents.context.function_schema') + @patch('azure.durable_functions.openai_agents.context.FunctionTool') + def test_activity_as_tool_run_activity_with_retry(self, mock_function_tool, mock_function_schema): + """Test that the run_activity function calls task tracker with retry options.""" + orchestration_context = self._create_mock_orchestration_context() + task_tracker = self._create_mock_task_tracker() + + mock_activity_func = Mock() + mock_activity_func._function._name = "test_activity" + mock_activity_func._function._trigger = None + mock_activity_func._function._func = lambda x: x + + mock_schema = Mock() + mock_schema.name = "test_activity" + mock_schema.description = "" + mock_schema.params_json_schema = {"type": "object"} + mock_function_schema.return_value = mock_schema + + mock_tool = Mock(spec=FunctionTool) + mock_function_tool.return_value = mock_tool + + ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None) + retry_options = RetryOptions(1000, 3) + + ai_context.create_activity_tool(mock_activity_func, retry_options=retry_options) + + # Get the run_activity function that was passed to FunctionTool + call_args = mock_function_tool.call_args + run_activity = call_args[1]['on_invoke_tool'] + + # Create a mock context wrapper + mock_ctx = Mock() + + # Call the run_activity function + import asyncio + result = asyncio.run(run_activity(mock_ctx, "test_input")) + + # Verify the task tracker was called with retry options + task_tracker.get_activity_call_result_with_retry.assert_called_once_with( + "test_activity", retry_options, "test_input" + ) + assert result == "retry_activity_result" + + @patch('azure.durable_functions.openai_agents.context.function_schema') + @patch('azure.durable_functions.openai_agents.context.FunctionTool') + def test_activity_as_tool_run_activity_without_retry(self, mock_function_tool, mock_function_schema): + """Test that the run_activity function calls task tracker without retry when retry_options is None.""" + orchestration_context = self._create_mock_orchestration_context() + task_tracker = self._create_mock_task_tracker() + + mock_activity_func = Mock() + mock_activity_func._function._name = "test_activity" + mock_activity_func._function._trigger = None + mock_activity_func._function._func = lambda x: x + + mock_schema = Mock() + mock_schema.name = "test_activity" + mock_schema.description = "" + mock_schema.params_json_schema = {"type": "object"} + mock_function_schema.return_value = mock_schema + + mock_tool = Mock(spec=FunctionTool) + mock_function_tool.return_value = mock_tool + + ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None) + + ai_context.create_activity_tool(mock_activity_func, retry_options=None) + + # Get the run_activity function that was passed to FunctionTool + call_args = mock_function_tool.call_args + run_activity = call_args[1]['on_invoke_tool'] + + # Create a mock context wrapper + mock_ctx = Mock() + + # Call the run_activity function + import asyncio + result = asyncio.run(run_activity(mock_ctx, "test_input")) + + # Verify the task tracker was called without retry options + task_tracker.get_activity_call_result.assert_called_once_with( + "test_activity", "test_input" + ) + assert result == "activity_result" + + @patch('azure.durable_functions.openai_agents.context.function_schema') + @patch('azure.durable_functions.openai_agents.context.FunctionTool') + def test_activity_as_tool_extracts_activity_name_from_trigger(self, mock_function_tool, mock_function_schema): + """Test that the run_activity function calls task tracker with the activity name specified in the trigger.""" + orchestration_context = self._create_mock_orchestration_context() + task_tracker = self._create_mock_task_tracker() + + mock_activity_func = Mock() + mock_activity_func._function._name = "test_activity" + mock_activity_func._function._trigger.activity = "activity_name_from_trigger" + mock_activity_func._function._func = lambda x: x + + mock_schema = Mock() + mock_schema.name = "test_activity" + mock_schema.description = "" + mock_schema.params_json_schema = {"type": "object"} + mock_function_schema.return_value = mock_schema + + mock_tool = Mock(spec=FunctionTool) + mock_function_tool.return_value = mock_tool + + ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None) + + ai_context.create_activity_tool(mock_activity_func, retry_options=None) + + # Get the run_activity function that was passed to FunctionTool + call_args = mock_function_tool.call_args + run_activity = call_args[1]['on_invoke_tool'] + + # Create a mock context wrapper + mock_ctx = Mock() + + # Call the run_activity function + import asyncio + result = asyncio.run(run_activity(mock_ctx, "test_input")) + + # Verify the task tracker was called without retry options + task_tracker.get_activity_call_result.assert_called_once_with( + "activity_name_from_trigger", "test_input" + ) + assert result == "activity_result" + + def test_context_delegation_methods_work(self): + """Test that common context methods work through delegation.""" + orchestration_context = self._create_mock_orchestration_context() + task_tracker = self._create_mock_task_tracker() + + # Add some mock methods to the orchestration context + orchestration_context.wait_for_external_event = Mock(return_value="external_event_task") + orchestration_context.create_timer = Mock(return_value="timer_task") + + ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None) + + # These should work through delegation + result1 = ai_context.wait_for_external_event("test_event") + result2 = ai_context.create_timer("2023-01-01T00:00:00Z") + + assert result1 == "external_event_task" + assert result2 == "timer_task" + orchestration_context.wait_for_external_event.assert_called_once_with("test_event") + orchestration_context.create_timer.assert_called_once_with("2023-01-01T00:00:00Z") + + def test_getattr_delegates_to_context(self): + """Test that __getattr__ delegates attribute access to the underlying context.""" + orchestration_context = self._create_mock_orchestration_context() + task_tracker = self._create_mock_task_tracker() + + ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None) + + # Test delegation of various attributes + assert ai_context.instance_id == "test_instance_id" + assert ai_context.current_utc_datetime == "2023-01-01T00:00:00Z" + assert ai_context.is_replaying is False + + def test_getattr_raises_attribute_error_for_nonexistent_attributes(self): + """Test that __getattr__ raises AttributeError for non-existent attributes.""" + orchestration_context = self._create_mock_orchestration_context() + task_tracker = self._create_mock_task_tracker() + + ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None) + + with pytest.raises(AttributeError, match="'DurableAIAgentContext' object has no attribute 'nonexistent_attr'"): + _ = ai_context.nonexistent_attr + + def test_dir_includes_delegated_attributes(self): + """Test that __dir__ includes attributes from the underlying context.""" + orchestration_context = self._create_mock_orchestration_context() + task_tracker = self._create_mock_task_tracker() + + ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None) + dir_result = dir(ai_context) + + # Should include delegated attributes from the underlying context + assert 'instance_id' in dir_result + assert 'current_utc_datetime' in dir_result + assert 'is_replaying' in dir_result + # Should also include public methods + assert 'call_activity' in dir_result + assert 'create_activity_tool' in dir_result