Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions azure/durable_functions/openai_agents/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, TYPE_CHECKING

from azure.durable_functions.models.DurableOrchestrationContext import (
DurableOrchestrationContext,
Expand All @@ -11,8 +11,39 @@
from .task_tracker import TaskTracker


class DurableAIAgentContext:
"""Context for AI agents running in Azure Durable Functions orchestration."""
if TYPE_CHECKING:
# At type-check time we want all members / signatures for IDE & linters.
_BaseDurableContext = DurableOrchestrationContext
else:
class _BaseDurableContext: # lightweight runtime stub
"""Runtime stub base class for delegation; real context is wrapped.

At runtime we avoid inheriting from DurableOrchestrationContext so that
attribute lookups for its members are delegated via __getattr__ to the
wrapped ``_context`` instance.
"""

__slots__ = ()


class DurableAIAgentContext(_BaseDurableContext):
"""Context for AI agents running in Azure Durable Functions orchestration.

Design
------
* Static analysis / IDEs: Appears to subclass ``DurableOrchestrationContext`` so
you get autocompletion and type hints (under TYPE_CHECKING branch).
* Runtime: Inherits from a trivial stub. All durable orchestration operations
are delegated to the real ``DurableOrchestrationContext`` instance provided
as ``context`` and stored in ``_context``.

Consequences
------------
* ``isinstance(DurableAIAgentContext, DurableOrchestrationContext)`` is **False** at
runtime (expected).
* Delegation via ``__getattr__`` works for every member of the real context.
* No reliance on internal initialization side-effects of the durable SDK.
"""

def __init__(
self,
Expand All @@ -38,14 +69,6 @@ def call_activity_with_retry(
self._task_tracker.record_activity_call()
return task

def set_custom_status(self, status: str):
"""Set custom status for the orchestration."""
self._context.set_custom_status(status)

def wait_for_external_event(self, event_name: str):
"""Wait for an external event in the orchestration."""
return self._context.wait_for_external_event(event_name)

def activity_as_tool(
self,
activity_func: Callable,
Expand Down Expand Up @@ -95,3 +118,14 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
on_invoke_tool=run_activity,
strict_json_schema=True,
)

def __getattr__(self, name):
"""Delegate missing attributes to the underlying DurableOrchestrationContext."""
try:
return getattr(self._context, name)
except AttributeError:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

def __dir__(self):
"""Improve introspection and tab-completion by including delegated attributes."""
return sorted(set(dir(type(self)) + list(self.__dict__) + dir(self._context)))
292 changes: 292 additions & 0 deletions tests/openai_agents/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
import pytest
from unittest.mock import Mock, patch

from azure.durable_functions.openai_agents.context import DurableAIAgentContext
from azure.durable_functions.openai_agents.task_tracker import TaskTracker
from azure.durable_functions.models.DurableOrchestrationContext import DurableOrchestrationContext
from azure.durable_functions.models.RetryOptions import RetryOptions

from agents.tool import FunctionTool


class TestDurableAIAgentContext:
"""Test suite for DurableAIAgentContext class."""

def _create_mock_orchestration_context(self):
"""Create a mock DurableOrchestrationContext for testing."""
orchestration_context = Mock(spec=DurableOrchestrationContext)
orchestration_context.call_activity = Mock(return_value="mock_task")
orchestration_context.call_activity_with_retry = Mock(return_value="mock_task_with_retry")
orchestration_context.instance_id = "test_instance_id"
orchestration_context.current_utc_datetime = "2023-01-01T00:00:00Z"
orchestration_context.is_replaying = False
return orchestration_context

def _create_mock_task_tracker(self):
"""Create a mock TaskTracker for testing."""
task_tracker = Mock(spec=TaskTracker)
task_tracker.record_activity_call = Mock()
task_tracker.get_activity_call_result = Mock(return_value="activity_result")
task_tracker.get_activity_call_result_with_retry = Mock(return_value="retry_activity_result")
return task_tracker

def test_init_creates_context_successfully(self):
"""Test that __init__ creates a DurableAIAgentContext successfully."""
orchestration_context = self._create_mock_orchestration_context()
task_tracker = self._create_mock_task_tracker()
retry_options = RetryOptions(1000, 3)

ai_context = DurableAIAgentContext(orchestration_context, task_tracker, retry_options)

assert isinstance(ai_context, DurableAIAgentContext)
assert not isinstance(ai_context, DurableOrchestrationContext)

def test_call_activity_delegates_and_records(self):
"""Test that call_activity delegates to context and records activity call."""
orchestration_context = self._create_mock_orchestration_context()
task_tracker = self._create_mock_task_tracker()

ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
result = ai_context.call_activity("test_activity", "test_input")

orchestration_context.call_activity.assert_called_once_with("test_activity", "test_input")
task_tracker.record_activity_call.assert_called_once()
assert result == "mock_task"

def test_call_activity_with_retry_delegates_and_records(self):
"""Test that call_activity_with_retry delegates to context and records activity call."""
orchestration_context = self._create_mock_orchestration_context()
task_tracker = self._create_mock_task_tracker()
retry_options = RetryOptions(1000, 3)

ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
result = ai_context.call_activity_with_retry("test_activity", retry_options, "test_input")

orchestration_context.call_activity_with_retry.assert_called_once_with(
"test_activity", retry_options, "test_input"
)
task_tracker.record_activity_call.assert_called_once()
assert result == "mock_task_with_retry"

@patch('azure.durable_functions.openai_agents.context.function_schema')
@patch('azure.durable_functions.openai_agents.context.FunctionTool')
def test_activity_as_tool_creates_function_tool(self, mock_function_tool, mock_function_schema):
"""Test that activity_as_tool creates a FunctionTool with correct parameters."""
orchestration_context = self._create_mock_orchestration_context()
task_tracker = self._create_mock_task_tracker()

# Mock the activity function
mock_activity_func = Mock()
mock_activity_func._function._name = "test_activity"
mock_activity_func._function._func = lambda x: x

# Mock the schema
mock_schema = Mock()
mock_schema.name = "test_activity"
mock_schema.description = "Test activity description"
mock_schema.params_json_schema = {"type": "object"}
mock_function_schema.return_value = mock_schema

# Mock FunctionTool
mock_tool = Mock(spec=FunctionTool)
mock_function_tool.return_value = mock_tool

ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
retry_options = RetryOptions(1000, 3)

result = ai_context.activity_as_tool(
mock_activity_func,
description="Custom description",
retry_options=retry_options
)

# Verify function_schema was called correctly
mock_function_schema.assert_called_once_with(
func=mock_activity_func._function._func,
name_override="test_activity",
docstring_style=None,
description_override="Custom description",
use_docstring_info=True,
strict_json_schema=True,
)

# Verify FunctionTool was created correctly
mock_function_tool.assert_called_once()
call_args = mock_function_tool.call_args
assert call_args[1]['name'] == "test_activity"
assert call_args[1]['description'] == "Test activity description"
assert call_args[1]['params_json_schema'] == {"type": "object"}
assert call_args[1]['strict_json_schema'] is True
assert callable(call_args[1]['on_invoke_tool'])

assert result is mock_tool

@patch('azure.durable_functions.openai_agents.context.function_schema')
@patch('azure.durable_functions.openai_agents.context.FunctionTool')
def test_activity_as_tool_with_default_retry_options(self, mock_function_tool, mock_function_schema):
"""Test that activity_as_tool uses default retry options when none provided."""
orchestration_context = self._create_mock_orchestration_context()
task_tracker = self._create_mock_task_tracker()

mock_activity_func = Mock()
mock_activity_func._function._name = "test_activity"
mock_activity_func._function._func = lambda x: x

mock_schema = Mock()
mock_schema.name = "test_activity"
mock_schema.description = "Test description"
mock_schema.params_json_schema = {"type": "object"}
mock_function_schema.return_value = mock_schema

mock_tool = Mock(spec=FunctionTool)
mock_function_tool.return_value = mock_tool

ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)

# Call with default retry options
result = ai_context.activity_as_tool(mock_activity_func)

# Should still create the tool successfully
assert result is mock_tool
mock_function_tool.assert_called_once()

@patch('azure.durable_functions.openai_agents.context.function_schema')
@patch('azure.durable_functions.openai_agents.context.FunctionTool')
def test_activity_as_tool_run_activity_with_retry(self, mock_function_tool, mock_function_schema):
"""Test that the run_activity function calls task tracker with retry options."""
orchestration_context = self._create_mock_orchestration_context()
task_tracker = self._create_mock_task_tracker()

mock_activity_func = Mock()
mock_activity_func._function._name = "test_activity"
mock_activity_func._function._func = lambda x: x

mock_schema = Mock()
mock_schema.name = "test_activity"
mock_schema.description = ""
mock_schema.params_json_schema = {"type": "object"}
mock_function_schema.return_value = mock_schema

mock_tool = Mock(spec=FunctionTool)
mock_function_tool.return_value = mock_tool

ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
retry_options = RetryOptions(1000, 3)

ai_context.activity_as_tool(mock_activity_func, retry_options=retry_options)

# Get the run_activity function that was passed to FunctionTool
call_args = mock_function_tool.call_args
run_activity = call_args[1]['on_invoke_tool']

# Create a mock context wrapper
mock_ctx = Mock()

# Call the run_activity function
import asyncio
result = asyncio.run(run_activity(mock_ctx, "test_input"))

# Verify the task tracker was called with retry options
task_tracker.get_activity_call_result_with_retry.assert_called_once_with(
"test_activity", retry_options, "test_input"
)
assert result == "retry_activity_result"

@patch('azure.durable_functions.openai_agents.context.function_schema')
@patch('azure.durable_functions.openai_agents.context.FunctionTool')
def test_activity_as_tool_run_activity_without_retry(self, mock_function_tool, mock_function_schema):
"""Test that the run_activity function calls task tracker without retry when retry_options is None."""
orchestration_context = self._create_mock_orchestration_context()
task_tracker = self._create_mock_task_tracker()

mock_activity_func = Mock()
mock_activity_func._function._name = "test_activity"
mock_activity_func._function._func = lambda x: x

mock_schema = Mock()
mock_schema.name = "test_activity"
mock_schema.description = ""
mock_schema.params_json_schema = {"type": "object"}
mock_function_schema.return_value = mock_schema

mock_tool = Mock(spec=FunctionTool)
mock_function_tool.return_value = mock_tool

ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)

ai_context.activity_as_tool(mock_activity_func, retry_options=None)

# Get the run_activity function that was passed to FunctionTool
call_args = mock_function_tool.call_args
run_activity = call_args[1]['on_invoke_tool']

# Create a mock context wrapper
mock_ctx = Mock()

# Call the run_activity function
import asyncio
result = asyncio.run(run_activity(mock_ctx, "test_input"))

# Verify the task tracker was called without retry options
task_tracker.get_activity_call_result.assert_called_once_with(
"test_activity", "test_input"
)
assert result == "activity_result"

def test_context_delegation_methods_work(self):
"""Test that common context methods work through delegation."""
orchestration_context = self._create_mock_orchestration_context()
task_tracker = self._create_mock_task_tracker()

# Add some mock methods to the orchestration context
orchestration_context.wait_for_external_event = Mock(return_value="external_event_task")
orchestration_context.create_timer = Mock(return_value="timer_task")

ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)

# These should work through delegation
result1 = ai_context.wait_for_external_event("test_event")
result2 = ai_context.create_timer("2023-01-01T00:00:00Z")

assert result1 == "external_event_task"
assert result2 == "timer_task"
orchestration_context.wait_for_external_event.assert_called_once_with("test_event")
orchestration_context.create_timer.assert_called_once_with("2023-01-01T00:00:00Z")

def test_getattr_delegates_to_context(self):
"""Test that __getattr__ delegates attribute access to the underlying context."""
orchestration_context = self._create_mock_orchestration_context()
task_tracker = self._create_mock_task_tracker()

ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)

# Test delegation of various attributes
assert ai_context.instance_id == "test_instance_id"
assert ai_context.current_utc_datetime == "2023-01-01T00:00:00Z"
assert ai_context.is_replaying is False

def test_getattr_raises_attribute_error_for_nonexistent_attributes(self):
"""Test that __getattr__ raises AttributeError for non-existent attributes."""
orchestration_context = self._create_mock_orchestration_context()
task_tracker = self._create_mock_task_tracker()

ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)

with pytest.raises(AttributeError, match="'DurableAIAgentContext' object has no attribute 'nonexistent_attr'"):
_ = ai_context.nonexistent_attr

def test_dir_includes_delegated_attributes(self):
"""Test that __dir__ includes attributes from the underlying context."""
orchestration_context = self._create_mock_orchestration_context()
task_tracker = self._create_mock_task_tracker()

ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
dir_result = dir(ai_context)

# Should include delegated attributes from the underlying context
assert 'instance_id' in dir_result
assert 'current_utc_datetime' in dir_result
assert 'is_replaying' in dir_result
# Should also include public methods
assert 'call_activity' in dir_result
assert 'activity_as_tool' in dir_result