Skip to content

Commit 93cd420

Browse files
committed
Add test_activity_as_tool_extracts_activity_name_from_trigger test
1 parent bea39ad commit 93cd420

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

tests/openai_agents/test_context.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,48 @@ def test_activity_as_tool_run_activity_without_retry(self, mock_function_tool, m
234234
)
235235
assert result == "activity_result"
236236

237+
@patch('azure.durable_functions.openai_agents.context.function_schema')
238+
@patch('azure.durable_functions.openai_agents.context.FunctionTool')
239+
def test_activity_as_tool_extracts_activity_name_from_trigger(self, mock_function_tool, mock_function_schema):
240+
"""Test that the run_activity function calls task tracker with the activity name specified in the trigger."""
241+
orchestration_context = self._create_mock_orchestration_context()
242+
task_tracker = self._create_mock_task_tracker()
243+
244+
mock_activity_func = Mock()
245+
mock_activity_func._function._name = "test_activity"
246+
mock_activity_func._function._trigger.activity = "activity_name_from_trigger"
247+
mock_activity_func._function._func = lambda x: x
248+
249+
mock_schema = Mock()
250+
mock_schema.name = "test_activity"
251+
mock_schema.description = ""
252+
mock_schema.params_json_schema = {"type": "object"}
253+
mock_function_schema.return_value = mock_schema
254+
255+
mock_tool = Mock(spec=FunctionTool)
256+
mock_function_tool.return_value = mock_tool
257+
258+
ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
259+
260+
ai_context.create_activity_tool(mock_activity_func, retry_options=None)
261+
262+
# Get the run_activity function that was passed to FunctionTool
263+
call_args = mock_function_tool.call_args
264+
run_activity = call_args[1]['on_invoke_tool']
265+
266+
# Create a mock context wrapper
267+
mock_ctx = Mock()
268+
269+
# Call the run_activity function
270+
import asyncio
271+
result = asyncio.run(run_activity(mock_ctx, "test_input"))
272+
273+
# Verify the task tracker was called without retry options
274+
task_tracker.get_activity_call_result.assert_called_once_with(
275+
"activity_name_from_trigger", "test_input"
276+
)
277+
assert result == "activity_result"
278+
237279
def test_context_delegation_methods_work(self):
238280
"""Test that common context methods work through delegation."""
239281
orchestration_context = self._create_mock_orchestration_context()

0 commit comments

Comments
 (0)