Skip to content

Commit ff69827

Browse files
committed
Make DurableAIAgentContext delegate to the underlying DurableOrchestrationContext automatically
1 parent 1b3ac4c commit ff69827

File tree

2 files changed

+336
-11
lines changed

2 files changed

+336
-11
lines changed

azure/durable_functions/openai_agents/context.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable, Optional
1+
from typing import Any, Callable, Optional, TYPE_CHECKING
22

33
from azure.durable_functions.models.DurableOrchestrationContext import (
44
DurableOrchestrationContext,
@@ -11,8 +11,38 @@
1111
from .task_tracker import TaskTracker
1212

1313

14-
class DurableAIAgentContext:
15-
"""Context for AI agents running in Azure Durable Functions orchestration."""
14+
if TYPE_CHECKING:
15+
# At type-check time we want all members / signatures for IDE & linters.
16+
_BaseDurableContext = DurableOrchestrationContext
17+
else:
18+
class _BaseDurableContext: # lightweight runtime stub
19+
"""Runtime stub base class for delegation; real context is wrapped.
20+
21+
At runtime we avoid inheriting from DurableOrchestrationContext so that
22+
attribute lookups for its members are delegated via __getattr__ to the
23+
wrapped ``_context`` instance.
24+
"""
25+
__slots__ = ()
26+
27+
28+
class DurableAIAgentContext(_BaseDurableContext):
29+
"""Context for AI agents running in Azure Durable Functions orchestration.
30+
31+
Design
32+
------
33+
* Static analysis / IDEs: Appears to subclass ``DurableOrchestrationContext`` so
34+
you get autocompletion and type hints (under TYPE_CHECKING branch).
35+
* Runtime: Inherits from a trivial stub. All durable orchestration operations
36+
are delegated to the real ``DurableOrchestrationContext`` instance provided
37+
as ``context`` and stored in ``_context``.
38+
39+
Consequences
40+
------------
41+
* ``isinstance(DurableAIAgentContext, DurableOrchestrationContext)`` is **False** at
42+
runtime (expected).
43+
* Delegation via ``__getattr__`` works for every member of the real context.
44+
* No reliance on internal initialization side-effects of the durable SDK.
45+
"""
1646

1747
def __init__(
1848
self,
@@ -38,14 +68,6 @@ def call_activity_with_retry(
3868
self._task_tracker.record_activity_call()
3969
return task
4070

41-
def set_custom_status(self, status: str):
42-
"""Set custom status for the orchestration."""
43-
self._context.set_custom_status(status)
44-
45-
def wait_for_external_event(self, event_name: str):
46-
"""Wait for an external event in the orchestration."""
47-
return self._context.wait_for_external_event(event_name)
48-
4971
def activity_as_tool(
5072
self,
5173
activity_func: Callable,
@@ -95,3 +117,14 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
95117
on_invoke_tool=run_activity,
96118
strict_json_schema=True,
97119
)
120+
121+
def __getattr__(self, name):
122+
"""Delegate missing attributes to the underlying DurableOrchestrationContext."""
123+
try:
124+
return getattr(self._context, name)
125+
except AttributeError:
126+
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
127+
128+
def __dir__(self):
129+
"""Improve introspection and tab-completion by including delegated attributes."""
130+
return sorted(set(dir(type(self)) + list(self.__dict__) + dir(self._context)))
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
import pytest
2+
from unittest.mock import Mock, patch
3+
4+
from azure.durable_functions.openai_agents.context import DurableAIAgentContext
5+
from azure.durable_functions.openai_agents.task_tracker import TaskTracker
6+
from azure.durable_functions.models.DurableOrchestrationContext import DurableOrchestrationContext
7+
from azure.durable_functions.models.RetryOptions import RetryOptions
8+
9+
from agents.tool import FunctionTool
10+
11+
12+
class TestDurableAIAgentContext:
13+
"""Test suite for DurableAIAgentContext class."""
14+
15+
def _create_mock_orchestration_context(self):
16+
"""Create a mock DurableOrchestrationContext for testing."""
17+
orchestration_context = Mock(spec=DurableOrchestrationContext)
18+
orchestration_context.call_activity = Mock(return_value="mock_task")
19+
orchestration_context.call_activity_with_retry = Mock(return_value="mock_task_with_retry")
20+
orchestration_context.instance_id = "test_instance_id"
21+
orchestration_context.current_utc_datetime = "2023-01-01T00:00:00Z"
22+
orchestration_context.is_replaying = False
23+
return orchestration_context
24+
25+
def _create_mock_task_tracker(self):
26+
"""Create a mock TaskTracker for testing."""
27+
task_tracker = Mock(spec=TaskTracker)
28+
task_tracker.record_activity_call = Mock()
29+
task_tracker.get_activity_call_result = Mock(return_value="activity_result")
30+
task_tracker.get_activity_call_result_with_retry = Mock(return_value="retry_activity_result")
31+
return task_tracker
32+
33+
def test_init_creates_context_successfully(self):
34+
"""Test that __init__ creates a DurableAIAgentContext successfully."""
35+
orchestration_context = self._create_mock_orchestration_context()
36+
task_tracker = self._create_mock_task_tracker()
37+
retry_options = RetryOptions(1000, 3)
38+
39+
ai_context = DurableAIAgentContext(orchestration_context, task_tracker, retry_options)
40+
41+
assert isinstance(ai_context, DurableAIAgentContext)
42+
assert not isinstance(ai_context, DurableOrchestrationContext)
43+
44+
def test_call_activity_delegates_and_records(self):
45+
"""Test that call_activity delegates to context and records activity call."""
46+
orchestration_context = self._create_mock_orchestration_context()
47+
task_tracker = self._create_mock_task_tracker()
48+
49+
ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
50+
result = ai_context.call_activity("test_activity", "test_input")
51+
52+
orchestration_context.call_activity.assert_called_once_with("test_activity", "test_input")
53+
task_tracker.record_activity_call.assert_called_once()
54+
assert result == "mock_task"
55+
56+
def test_call_activity_with_retry_delegates_and_records(self):
57+
"""Test that call_activity_with_retry delegates to context and records activity call."""
58+
orchestration_context = self._create_mock_orchestration_context()
59+
task_tracker = self._create_mock_task_tracker()
60+
retry_options = RetryOptions(1000, 3)
61+
62+
ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
63+
result = ai_context.call_activity_with_retry("test_activity", retry_options, "test_input")
64+
65+
orchestration_context.call_activity_with_retry.assert_called_once_with(
66+
"test_activity", retry_options, "test_input"
67+
)
68+
task_tracker.record_activity_call.assert_called_once()
69+
assert result == "mock_task_with_retry"
70+
71+
@patch('azure.durable_functions.openai_agents.context.function_schema')
72+
@patch('azure.durable_functions.openai_agents.context.FunctionTool')
73+
def test_activity_as_tool_creates_function_tool(self, mock_function_tool, mock_function_schema):
74+
"""Test that activity_as_tool creates a FunctionTool with correct parameters."""
75+
orchestration_context = self._create_mock_orchestration_context()
76+
task_tracker = self._create_mock_task_tracker()
77+
78+
# Mock the activity function
79+
mock_activity_func = Mock()
80+
mock_activity_func._function._name = "test_activity"
81+
mock_activity_func._function._func = lambda x: x
82+
83+
# Mock the schema
84+
mock_schema = Mock()
85+
mock_schema.name = "test_activity"
86+
mock_schema.description = "Test activity description"
87+
mock_schema.params_json_schema = {"type": "object"}
88+
mock_function_schema.return_value = mock_schema
89+
90+
# Mock FunctionTool
91+
mock_tool = Mock(spec=FunctionTool)
92+
mock_function_tool.return_value = mock_tool
93+
94+
ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
95+
retry_options = RetryOptions(1000, 3)
96+
97+
result = ai_context.activity_as_tool(
98+
mock_activity_func,
99+
description="Custom description",
100+
retry_options=retry_options
101+
)
102+
103+
# Verify function_schema was called correctly
104+
mock_function_schema.assert_called_once_with(
105+
func=mock_activity_func._function._func,
106+
name_override="test_activity",
107+
docstring_style=None,
108+
description_override="Custom description",
109+
use_docstring_info=True,
110+
strict_json_schema=True,
111+
)
112+
113+
# Verify FunctionTool was created correctly
114+
mock_function_tool.assert_called_once()
115+
call_args = mock_function_tool.call_args
116+
assert call_args[1]['name'] == "test_activity"
117+
assert call_args[1]['description'] == "Test activity description"
118+
assert call_args[1]['params_json_schema'] == {"type": "object"}
119+
assert call_args[1]['strict_json_schema'] is True
120+
assert callable(call_args[1]['on_invoke_tool'])
121+
122+
assert result is mock_tool
123+
124+
@patch('azure.durable_functions.openai_agents.context.function_schema')
125+
@patch('azure.durable_functions.openai_agents.context.FunctionTool')
126+
def test_activity_as_tool_with_default_retry_options(self, mock_function_tool, mock_function_schema):
127+
"""Test that activity_as_tool uses default retry options when none provided."""
128+
orchestration_context = self._create_mock_orchestration_context()
129+
task_tracker = self._create_mock_task_tracker()
130+
131+
mock_activity_func = Mock()
132+
mock_activity_func._function._name = "test_activity"
133+
mock_activity_func._function._func = lambda x: x
134+
135+
mock_schema = Mock()
136+
mock_schema.name = "test_activity"
137+
mock_schema.description = "Test description"
138+
mock_schema.params_json_schema = {"type": "object"}
139+
mock_function_schema.return_value = mock_schema
140+
141+
mock_tool = Mock(spec=FunctionTool)
142+
mock_function_tool.return_value = mock_tool
143+
144+
ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
145+
146+
# Call with default retry options
147+
result = ai_context.activity_as_tool(mock_activity_func)
148+
149+
# Should still create the tool successfully
150+
assert result is mock_tool
151+
mock_function_tool.assert_called_once()
152+
153+
@patch('azure.durable_functions.openai_agents.context.function_schema')
154+
@patch('azure.durable_functions.openai_agents.context.FunctionTool')
155+
def test_activity_as_tool_run_activity_with_retry(self, mock_function_tool, mock_function_schema):
156+
"""Test that the run_activity function calls task tracker with retry options."""
157+
orchestration_context = self._create_mock_orchestration_context()
158+
task_tracker = self._create_mock_task_tracker()
159+
160+
mock_activity_func = Mock()
161+
mock_activity_func._function._name = "test_activity"
162+
mock_activity_func._function._func = lambda x: x
163+
164+
mock_schema = Mock()
165+
mock_schema.name = "test_activity"
166+
mock_schema.description = ""
167+
mock_schema.params_json_schema = {"type": "object"}
168+
mock_function_schema.return_value = mock_schema
169+
170+
mock_tool = Mock(spec=FunctionTool)
171+
mock_function_tool.return_value = mock_tool
172+
173+
ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
174+
retry_options = RetryOptions(1000, 3)
175+
176+
ai_context.activity_as_tool(mock_activity_func, retry_options=retry_options)
177+
178+
# Get the run_activity function that was passed to FunctionTool
179+
call_args = mock_function_tool.call_args
180+
run_activity = call_args[1]['on_invoke_tool']
181+
182+
# Create a mock context wrapper
183+
mock_ctx = Mock()
184+
185+
# Call the run_activity function
186+
import asyncio
187+
result = asyncio.run(run_activity(mock_ctx, "test_input"))
188+
189+
# Verify the task tracker was called with retry options
190+
task_tracker.get_activity_call_result_with_retry.assert_called_once_with(
191+
"test_activity", retry_options, "test_input"
192+
)
193+
assert result == "retry_activity_result"
194+
195+
@patch('azure.durable_functions.openai_agents.context.function_schema')
196+
@patch('azure.durable_functions.openai_agents.context.FunctionTool')
197+
def test_activity_as_tool_run_activity_without_retry(self, mock_function_tool, mock_function_schema):
198+
"""Test that the run_activity function calls task tracker without retry when retry_options is None."""
199+
orchestration_context = self._create_mock_orchestration_context()
200+
task_tracker = self._create_mock_task_tracker()
201+
202+
mock_activity_func = Mock()
203+
mock_activity_func._function._name = "test_activity"
204+
mock_activity_func._function._func = lambda x: x
205+
206+
mock_schema = Mock()
207+
mock_schema.name = "test_activity"
208+
mock_schema.description = ""
209+
mock_schema.params_json_schema = {"type": "object"}
210+
mock_function_schema.return_value = mock_schema
211+
212+
mock_tool = Mock(spec=FunctionTool)
213+
mock_function_tool.return_value = mock_tool
214+
215+
ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
216+
217+
ai_context.activity_as_tool(mock_activity_func, retry_options=None)
218+
219+
# Get the run_activity function that was passed to FunctionTool
220+
call_args = mock_function_tool.call_args
221+
run_activity = call_args[1]['on_invoke_tool']
222+
223+
# Create a mock context wrapper
224+
mock_ctx = Mock()
225+
226+
# Call the run_activity function
227+
import asyncio
228+
result = asyncio.run(run_activity(mock_ctx, "test_input"))
229+
230+
# Verify the task tracker was called without retry options
231+
task_tracker.get_activity_call_result.assert_called_once_with(
232+
"test_activity", "test_input"
233+
)
234+
assert result == "activity_result"
235+
236+
def test_context_delegation_methods_work(self):
237+
"""Test that common context methods work through delegation."""
238+
orchestration_context = self._create_mock_orchestration_context()
239+
task_tracker = self._create_mock_task_tracker()
240+
241+
# Add some mock methods to the orchestration context
242+
orchestration_context.wait_for_external_event = Mock(return_value="external_event_task")
243+
orchestration_context.create_timer = Mock(return_value="timer_task")
244+
245+
ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
246+
247+
# These should work through delegation
248+
result1 = ai_context.wait_for_external_event("test_event")
249+
result2 = ai_context.create_timer("2023-01-01T00:00:00Z")
250+
251+
assert result1 == "external_event_task"
252+
assert result2 == "timer_task"
253+
orchestration_context.wait_for_external_event.assert_called_once_with("test_event")
254+
orchestration_context.create_timer.assert_called_once_with("2023-01-01T00:00:00Z")
255+
256+
def test_getattr_delegates_to_context(self):
257+
"""Test that __getattr__ delegates attribute access to the underlying context."""
258+
orchestration_context = self._create_mock_orchestration_context()
259+
task_tracker = self._create_mock_task_tracker()
260+
261+
ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
262+
263+
# Test delegation of various attributes
264+
assert ai_context.instance_id == "test_instance_id"
265+
assert ai_context.current_utc_datetime == "2023-01-01T00:00:00Z"
266+
assert ai_context.is_replaying is False
267+
268+
def test_getattr_raises_attribute_error_for_nonexistent_attributes(self):
269+
"""Test that __getattr__ raises AttributeError for non-existent attributes."""
270+
orchestration_context = self._create_mock_orchestration_context()
271+
task_tracker = self._create_mock_task_tracker()
272+
273+
ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
274+
275+
with pytest.raises(AttributeError, match="'DurableAIAgentContext' object has no attribute 'nonexistent_attr'"):
276+
_ = ai_context.nonexistent_attr
277+
278+
def test_dir_includes_delegated_attributes(self):
279+
"""Test that __dir__ includes attributes from the underlying context."""
280+
orchestration_context = self._create_mock_orchestration_context()
281+
task_tracker = self._create_mock_task_tracker()
282+
283+
ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
284+
dir_result = dir(ai_context)
285+
286+
# Should include delegated attributes from the underlying context
287+
assert 'instance_id' in dir_result
288+
assert 'current_utc_datetime' in dir_result
289+
assert 'is_replaying' in dir_result
290+
# Should also include public methods
291+
assert 'call_activity' in dir_result
292+
assert 'activity_as_tool' in dir_result

0 commit comments

Comments
 (0)