diff --git a/azure/durable_functions/openai_agents/context.py b/azure/durable_functions/openai_agents/context.py
index 2517fe74..6b8b3516 100644
--- a/azure/durable_functions/openai_agents/context.py
+++ b/azure/durable_functions/openai_agents/context.py
@@ -1,3 +1,4 @@
+import json
 from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
 from azure.durable_functions.models.DurableOrchestrationContext import (
@@ -138,13 +139,29 @@ def create_activity_tool(
         else:
             activity_name = activity_func._function._name
 
+        input_name = None
+        if (activity_func._function._trigger is not None
+                and hasattr(activity_func._function._trigger, 'name')):
+            input_name = activity_func._function._trigger.name
+
         async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
+            # Parse the JSON input and extract the named value if input_name is specified
+            activity_input = input
+            if input_name:
+                try:
+                    parsed_input = json.loads(input)
+                    if isinstance(parsed_input, dict) and input_name in parsed_input:
+                        activity_input = parsed_input[input_name]
+                # If parsing fails or the named parameter is not found, pass the original input
+                except (json.JSONDecodeError, TypeError):
+                    pass
+
             if retry_options:
                 result = self._task_tracker.get_activity_call_result_with_retry(
-                    activity_name, retry_options, input
+                    activity_name, retry_options, activity_input
                 )
             else:
-                result = self._task_tracker.get_activity_call_result(activity_name, input)
+                result = self._task_tracker.get_activity_call_result(activity_name, activity_input)
             return result
 
         schema = function_schema(
diff --git a/azure/durable_functions/openai_agents/model_invocation_activity.py b/azure/durable_functions/openai_agents/model_invocation_activity.py
index 4eb897ff..078823c9 100644
--- a/azure/durable_functions/openai_agents/model_invocation_activity.py
+++ b/azure/durable_functions/openai_agents/model_invocation_activity.py
@@ -163,7 +163,16 @@ class ActivityModelInput(BaseModel):
 
     def to_json(self) -> str:
         """Convert the ActivityModelInput to a JSON string."""
-        return self.model_dump_json()
+        try:
+            return self.model_dump_json(warnings=False)
+        except Exception:
+            # Fall back to basic JSON serialization
+            try:
+                return json.dumps(self.model_dump(warnings=False), default=str)
+            except Exception as fallback_error:
+                raise ValueError(
+                    f"Unable to serialize ActivityModelInput: {fallback_error}"
+                ) from fallback_error
 
     @classmethod
     def from_json(cls, json_str: str) -> 'ActivityModelInput':
@@ -310,6 +319,7 @@ async def get_response(
         *,
         previous_response_id: Optional[str],
         prompt: Optional[ResponsePromptParam],
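+        # Added, presumably, to match the updated Model.get_response signature
+        # in openai-agents 0.3.0 (pinned in requirements.txt below).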
+        conversation_id: Optional[str] = None,
     ) -> ModelResponse:
         """Get a response from the model."""
         def make_tool_info(tool: Tool) -> ToolInput:
diff --git a/azure/durable_functions/openai_agents/orchestrator_generator.py b/azure/durable_functions/openai_agents/orchestrator_generator.py
index 56d5aa01..167c88a7 100644
--- a/azure/durable_functions/openai_agents/orchestrator_generator.py
+++ b/azure/durable_functions/openai_agents/orchestrator_generator.py
@@ -19,8 +19,27 @@ async def durable_openai_agent_activity(input: str, model_provider: ModelProvide
     model_invoker = ModelInvoker(model_provider=model_provider)
     result = await model_invoker.invoke_model_activity(activity_input)
 
-    json_obj = ModelResponse.__pydantic_serializer__.to_json(result)
-    return json_obj.decode()
+    # Use the public Pydantic API when possible: prefer model_dump_json if result is a BaseModel.
+    # Otherwise handle common types (str/bytes) directly and fall back to json.dumps for dicts/lists.
+    import json as _json
+
+    if hasattr(result, "model_dump_json"):
+        # Pydantic v2 BaseModel
+        json_str = result.model_dump_json()
+    else:
+        if isinstance(result, bytes):
+            json_str = result.decode()
+        elif isinstance(result, str):
+            json_str = result
+        else:
+            # Try the internal serializer as a last resort, but fall back to json.dumps
+            try:
+                json_bytes = ModelResponse.__pydantic_serializer__.to_json(result)
+                json_str = json_bytes.decode()
+            except Exception:
+                json_str = _json.dumps(result)
+
+    return json_str
 
 
 def durable_openai_agent_orchestrator_generator(
diff --git a/azure/durable_functions/openai_agents/task_tracker.py b/azure/durable_functions/openai_agents/task_tracker.py
index 1f346de7..f4bcdb65 100644
--- a/azure/durable_functions/openai_agents/task_tracker.py
+++ b/azure/durable_functions/openai_agents/task_tracker.py
@@ -52,13 +52,13 @@ def _get_activity_result_or_raise(self, task):
         result = json.loads(result_json)
         return result
 
-    def get_activity_call_result(self, activity_name, input: str):
+    def get_activity_call_result(self, activity_name, input: Any):
         """Call an activity and return its result or raise ``YieldException`` if pending."""
         task = self._context.call_activity(activity_name, input)
         return self._get_activity_result_or_raise(task)
 
     def get_activity_call_result_with_retry(
-        self, activity_name, retry_options: RetryOptions, input: str
+        self, activity_name, retry_options: RetryOptions, input: Any
     ):
         """Call an activity with retry and return its result or raise YieldException if pending."""
         task = self._context.call_activity_with_retry(activity_name, retry_options, input)
diff --git a/samples-v2/openai_agents/function_app.py b/samples-v2/openai_agents/function_app.py
index 91edf4dc..c83ce7d2 100644
--- a/samples-v2/openai_agents/function_app.py
+++ b/samples-v2/openai_agents/function_app.py
@@ -1,4 +1,5 @@
 import os
+import random
 
 import azure.functions as func
 import azure.durable_functions as df
@@ -102,4 +103,13 @@ def tools(context):
     import basic.tools
     return basic.tools.main()
 
+@app.activity_trigger(input_name="max")
+async def random_number_tool(max: int) -> int:
+    """Return a random integer between 0 and the given maximum."""
+    return random.randint(0, max)
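+
+# The orchestrator below exposes random_number_tool to the agent via
+# create_activity_tool; the tool-call arguments (e.g. {"max": 100}) are
+# unwrapped to the activity's input_name parameter by the context.py change above.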
+@app.orchestration_trigger(context_name="context")
+@app.durable_openai_agent_orchestrator
+def message_filter(context):
+    import handoffs.message_filter
+    return handoffs.message_filter.main(context.create_activity_tool(random_number_tool))
 
diff --git a/samples-v2/openai_agents/handoffs/message_filter.py b/samples-v2/openai_agents/handoffs/message_filter.py
new file mode 100644
index 00000000..21871ca2
--- /dev/null
+++ b/samples-v2/openai_agents/handoffs/message_filter.py
@@ -0,0 +1,173 @@
+from __future__ import annotations
+
+import json
+
+from agents import Agent, HandoffInputData, Runner, function_tool, handoff
+from agents.extensions import handoff_filters
+from agents.models import is_gpt_5_default
+
+
+def spanish_handoff_message_filter(handoff_message_data: HandoffInputData) -> HandoffInputData:
+    if is_gpt_5_default():
+        print("gpt-5 is enabled, so we're not filtering the input history")
+        # When using gpt-5, removing some of the items could break things, so we do this filtering only for other models
+        return HandoffInputData(
+            input_history=handoff_message_data.input_history,
+            pre_handoff_items=tuple(handoff_message_data.pre_handoff_items),
+            new_items=tuple(handoff_message_data.new_items),
+        )
+
+    # First, we'll remove any tool-related messages from the message history
+    handoff_message_data = handoff_filters.remove_all_tools(handoff_message_data)
+
+    # Second, we'll also remove the first two items from the history, just for demonstration
+    history = (
+        tuple(handoff_message_data.input_history[2:])
+        if isinstance(handoff_message_data.input_history, tuple)
+        else handoff_message_data.input_history
+    )
+
+    # Or, you can use the HandoffInputData.clone(kwargs) method
+    return HandoffInputData(
+        input_history=history,
+        pre_handoff_items=tuple(handoff_message_data.pre_handoff_items),
+        new_items=tuple(handoff_message_data.new_items),
+    )
+
+
+def main(random_number_tool):
+    first_agent = Agent(
+        name="Assistant",
+        instructions="Be extremely concise.",
+        tools=[random_number_tool],
+    )
+
+    spanish_agent = Agent(
+        name="Spanish Assistant",
+        instructions="You only speak Spanish and are extremely concise.",
+        handoff_description="A Spanish-speaking assistant.",
+    )
+
+    second_agent = Agent(
+        name="Assistant",
+        instructions=(
+            "Be a helpful assistant. If the user speaks Spanish, handoff to the Spanish assistant."
+        ),
+        handoffs=[handoff(spanish_agent, input_filter=spanish_handoff_message_filter)],
+    )
+
+    # 1. Send a regular message to the first agent
+    result = Runner.run_sync(first_agent, input="Hi, my name is Sora.")
+
+    print("Step 1 done")
+
+    # 2. Ask it to generate a number
+    result = Runner.run_sync(
+        first_agent,
+        input=result.to_input_list()
+        + [{"content": "Can you generate a random number between 0 and 100?", "role": "user"}],
+    )
+
+    print("Step 2 done")
+
+    # 3. Call the second agent
+    result = Runner.run_sync(
+        second_agent,
+        input=result.to_input_list()
+        + [
+            {
+                "content": "I live in New York City. Whats the population of the city?",
+                "role": "user",
+            }
+        ],
+    )
+
+    print("Step 3 done")
+
+    # 4. Cause a handoff to occur
+    result = Runner.run_sync(
+        second_agent,
+        input=result.to_input_list()
+        + [
+            {
+                "content": "Por favor habla en español. ¿Cuál es mi nombre y dónde vivo?",
+                "role": "user",
+            }
+        ],
+    )
+
+    print("Step 4 done")
+
+    print("\n===Final messages===\n")
+
+    # 5. That should have caused spanish_handoff_message_filter to be called, which means the
+    # output should be missing the first two messages, and have no tool calls.
+    # Let's print the messages to see what happened
+    for message in result.to_input_list():
+        print(json.dumps(message, indent=2))
+        # tool_calls = message.tool_calls if isinstance(message, AssistantMessage) else None
+
+        # print(f"{message.role}: {message.content}\n - Tool calls: {tool_calls or 'None'}")
+    """
+    $python examples/handoffs/message_filter.py
+    Step 1 done
+    Step 2 done
+    Step 3 done
+    Step 4 done
+
+    ===Final messages===
+
+    {
+      "content": "Can you generate a random number between 0 and 100?",
+      "role": "user"
+    }
+    {
+      "id": "...",
+      "content": [
+        {
+          "annotations": [],
+          "text": "Sure! Here's a random number between 0 and 100: **42**.",
+          "type": "output_text"
+        }
+      ],
+      "role": "assistant",
+      "status": "completed",
+      "type": "message"
+    }
+    {
+      "content": "I live in New York City. Whats the population of the city?",
+      "role": "user"
+    }
+    {
+      "id": "...",
+      "content": [
+        {
+          "annotations": [],
+          "text": "As of the most recent estimates, the population of New York City is approximately 8.6 million people. However, this number is constantly changing due to various factors such as migration and birth rates. For the latest and most accurate information, it's always a good idea to check the official data from sources like the U.S. Census Bureau.",
+          "type": "output_text"
+        }
+      ],
+      "role": "assistant",
+      "status": "completed",
+      "type": "message"
+    }
+    {
+      "content": "Por favor habla en espa\u00f1ol. \u00bfCu\u00e1l es mi nombre y d\u00f3nde vivo?",
+      "role": "user"
+    }
+    {
+      "id": "...",
+      "content": [
+        {
+          "annotations": [],
+          "text": "No tengo acceso a esa informaci\u00f3n personal, solo s\u00e9 lo que me has contado: vives en Nueva York.",
+          "type": "output_text"
+        }
+      ],
+      "role": "assistant",
+      "status": "completed",
+      "type": "message"
+    }
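+
+    Note: spanish_handoff_message_filter dropped the first two messages (the
+    "Hi, my name is Sora." exchange) and removed all tool-call items, which is
+    why the transcript above starts at the random-number request.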
+    """
+
+    return result.final_output
diff --git a/samples-v2/openai_agents/requirements.txt b/samples-v2/openai_agents/requirements.txt
index a31252d5..89ebebe2 100644
--- a/samples-v2/openai_agents/requirements.txt
+++ b/samples-v2/openai_agents/requirements.txt
@@ -5,6 +5,6 @@ azure-functions
 azure-functions-durable
 azure-identity
 
-openai==1.98.0
-openai-agents==0.2.4
+openai==1.107.3
+openai-agents==0.3.0
 pydantic
diff --git a/samples-v2/openai_agents/test_orchestrators.py b/samples-v2/openai_agents/test_orchestrators.py
index 3ca4160f..2b363d09 100755
--- a/samples-v2/openai_agents/test_orchestrators.py
+++ b/samples-v2/openai_agents/test_orchestrators.py
@@ -22,7 +22,8 @@
     "non_strict_output_type",
     "previous_response_id",
     "remote_image",
-    "tools"
+    "tools",
+    "message_filter",
 ]
 
 BASE_URL = "http://localhost:7071/api/orchestrators"
diff --git a/tests/openai_agents/test_context.py b/tests/openai_agents/test_context.py
index cb887005..6b9d9389 100644
--- a/tests/openai_agents/test_context.py
+++ b/tests/openai_agents/test_context.py
@@ -30,6 +30,81 @@ def _create_mock_task_tracker(self):
         task_tracker.get_activity_call_result_with_retry = Mock(return_value="retry_activity_result")
         return task_tracker
 
+    def _create_mock_activity_func(self, name="test_activity", input_name=None,
+                                   activity_name=None):
+        """Create a mock activity function with configurable parameters."""
+        mock_activity_func = Mock()
+        mock_activity_func._function._name = name
+        mock_activity_func._function._func = lambda x: x
+
+        if input_name is not None:
+            # Create a trigger with the given input_name
+            mock_activity_func._function._trigger = Mock()
+            mock_activity_func._function._trigger.activity = activity_name
+            mock_activity_func._function._trigger.name = input_name
+        else:
+            # No trigger means no input_name
+            mock_activity_func._function._trigger = None
+
+        return mock_activity_func
+
+    def _setup_activity_tool_mocks(self, mock_function_tool, mock_function_schema,
+                                   activity_name="test_activity", description=""):
+        """Set up common mocks for function_schema and FunctionTool."""
+        mock_schema = Mock()
+        mock_schema.name = activity_name
+        mock_schema.description = description
+        mock_schema.params_json_schema = {"type": "object"}
+        mock_function_schema.return_value = mock_schema
+
+        mock_tool = Mock(spec=FunctionTool)
+        mock_function_tool.return_value = mock_tool
+
+        return mock_tool
+
+    def _invoke_activity_tool(self, run_activity, input_data):
+        """Helper to invoke the activity tool with asyncio."""
+        mock_ctx = Mock()
+        import asyncio
+        return asyncio.run(run_activity(mock_ctx, input_data))
+
+    def _test_activity_tool_input_processing(self, input_name=None, input_data="",
+                                             expected_input_parameter_value="",
+                                             retry_options=None,
+                                             activity_name="test_activity"):
+        """Framework method that runs a complete input processing test."""
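+        # input_data is the raw JSON string the agent passes to the tool;
+        # expected_input_parameter_value is what should reach the activity after
+        # the input_name-based unwrapping in create_activity_tool.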
+        with patch('azure.durable_functions.openai_agents.context.function_schema') \
+                as mock_function_schema, \
+                patch('azure.durable_functions.openai_agents.context.FunctionTool') \
+                as mock_function_tool:
+
+            # Setup
+            orchestration_context = self._create_mock_orchestration_context()
+            task_tracker = self._create_mock_task_tracker()
+            mock_activity_func = self._create_mock_activity_func(
+                name=activity_name, input_name=input_name)
+            self._setup_activity_tool_mocks(
+                mock_function_tool, mock_function_schema, activity_name)
+
+            # Create the context and the tool
+            ai_context = DurableAIAgentContext(orchestration_context, task_tracker, None)
+            ai_context.create_activity_tool(mock_activity_func, retry_options=retry_options)
+
+            # Get and invoke the run_activity function
+            call_args = mock_function_tool.call_args
+            run_activity = call_args[1]['on_invoke_tool']
+            self._invoke_activity_tool(run_activity, input_data)
+
+            # Verify the expected call was made
+            if retry_options:
+                task_tracker.get_activity_call_result_with_retry.assert_called_once_with(
+                    activity_name, retry_options, expected_input_parameter_value
+                )
+            else:
+                task_tracker.get_activity_call_result.assert_called_once_with(
+                    activity_name, expected_input_parameter_value
+                )
+
     def test_init_creates_context_successfully(self):
         """Test that __init__ creates a DurableAIAgentContext successfully."""
         orchestration_context = self._create_mock_orchestration_context()
@@ -276,6 +351,60 @@ def test_activity_as_tool_extracts_activity_name_from_trigger(self, mock_functio
         )
         assert result == "activity_result"
 
+    def test_create_activity_tool_parses_json_input_with_input_name(self):
+        """Test JSON input parsing and named value extraction with input_name."""
+        self._test_activity_tool_input_processing(
+            input_name="max",
+            input_data='{"max": 100}',
+            expected_input_parameter_value=100,
+            activity_name="random_number_tool"
+        )
+
+    def test_create_activity_tool_handles_non_json_input_gracefully(self):
+        """Test that non-JSON input passes through unchanged with input_name."""
+        self._test_activity_tool_input_processing(
+            input_name="param",
+            input_data="not json",
+            expected_input_parameter_value="not json"
+        )
+
+    def test_create_activity_tool_handles_json_missing_named_parameter(self):
+        """Test that JSON input without the named parameter passes through unchanged."""
+        json_input = '{"other_param": 200}'
+        self._test_activity_tool_input_processing(
+            input_name="expected_param",
+            input_data=json_input,
+            expected_input_parameter_value=json_input
+        )
+
+    def test_create_activity_tool_handles_malformed_json_gracefully(self):
+        """Test that malformed JSON passes through unchanged."""
+        malformed_json = '{"param": 100'  # Missing closing brace
+        self._test_activity_tool_input_processing(
+            input_name="param",
+            input_data=malformed_json,
+            expected_input_parameter_value=malformed_json
+        )
+
+    def test_create_activity_tool_json_parsing_works_with_retry_options(self):
+        """Test that JSON parsing works correctly with retry options."""
+        retry_options = RetryOptions(1000, 3)
+        self._test_activity_tool_input_processing(
+            input_name="value",
+            input_data='{"value": "test_data"}',
+            expected_input_parameter_value="test_data",
+            retry_options=retry_options
+        )
+
+    def test_create_activity_tool_no_input_name_passes_through_json(self):
+        """Test that JSON input passes through unchanged when there is no input_name."""
+        json_input = '{"param": 100}'
+        self._test_activity_tool_input_processing(
+            input_name=None,  # No input_name
+            input_data=json_input,
+            expected_input_parameter_value=json_input
+        )
+
     def test_context_delegation_methods_work(self):
         """Test that common context methods work through delegation."""
         orchestration_context = self._create_mock_orchestration_context()