diff --git a/docs/ref/realtime/config.md b/docs/ref/realtime/config.md
index 3e50f47ad..2445c6a34 100644
--- a/docs/ref/realtime/config.md
+++ b/docs/ref/realtime/config.md
@@ -11,6 +11,7 @@
 ## Audio Configuration
 
 ::: agents.realtime.config.RealtimeInputAudioTranscriptionConfig
+::: agents.realtime.config.RealtimeInputAudioNoiseReductionConfig
 ::: agents.realtime.config.RealtimeTurnDetectionConfig
 
 ## Guardrails Settings
diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py
index 4369b342b..8d39ad390 100644
--- a/src/agents/extensions/models/litellm_model.py
+++ b/src/agents/extensions/models/litellm_model.py
@@ -257,7 +257,15 @@ async def _fetch_response(
         stream: bool = False,
         prompt: Any | None = None,
     ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
-        converted_messages = Converter.items_to_messages(input)
+        # Preserve reasoning messages for tool calls when reasoning is enabled.
+        # This is needed for models like Claude 4 Sonnet/Opus, which support interleaved thinking.
+        preserve_thinking_blocks = (
+            model_settings.reasoning is not None and model_settings.reasoning.effort is not None
+        )
+
+        converted_messages = Converter.items_to_messages(
+            input, preserve_thinking_blocks=preserve_thinking_blocks
+        )
 
         if system_instructions:
             converted_messages.insert(
diff --git a/src/agents/models/chatcmpl_converter.py b/src/agents/models/chatcmpl_converter.py
index 77ff22ee0..96f02a5fe 100644
--- a/src/agents/models/chatcmpl_converter.py
+++ b/src/agents/models/chatcmpl_converter.py
@@ -39,7 +39,7 @@
     ResponseReasoningItemParam,
 )
 from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message
-from openai.types.responses.response_reasoning_item import Summary
+from openai.types.responses.response_reasoning_item import Content, Summary
 
 from ..agent_output import AgentOutputSchemaBase
 from ..exceptions import AgentsException, UserError
@@ -93,7 +93,9 @@ def convert_response_format(
     def message_to_output_items(cls, message: ChatCompletionMessage) -> list[TResponseOutputItem]:
         items: list[TResponseOutputItem] = []
 
-        # Handle reasoning content if available
+        # Check if message is agents.extensions.models.litellm_model.InternalChatCompletionMessage.
+        # We can't actually import it here because litellm is an optional dependency,
+        # so we use hasattr to check for reasoning_content and thinking_blocks.
         if hasattr(message, "reasoning_content") and message.reasoning_content:
             reasoning_item = ResponseReasoningItem(
                 id=FAKE_RESPONSES_ID,
@@ -101,16 +103,28 @@ def message_to_output_items(cls, message: ChatCompletionMessage) -> list[TRespon
                 type="reasoning",
             )
 
-            # Store full thinking blocks for Anthropic compatibility
+            # Store thinking blocks for Anthropic compatibility
             if hasattr(message, "thinking_blocks") and message.thinking_blocks:
-                # Store thinking blocks in the reasoning item's content
-                # Convert thinking blocks to Content objects
-                from openai.types.responses.response_reasoning_item import Content
-
-                reasoning_item.content = [
-                    Content(text=str(block.get("thinking", "")), type="reasoning_text")
-                    for block in message.thinking_blocks
-                ]
+                # Store thinking text in content and signature in encrypted_content
+                reasoning_item.content = []
+                signature = None
+                for block in message.thinking_blocks:
+                    if isinstance(block, dict):
+                        thinking_text = block.get("thinking", "")
+                        if thinking_text:
+                            reasoning_item.content.append(
+                                Content(text=thinking_text, type="reasoning_text")
+                            )
+                        # Store the signature if present
+                        if block.get("signature"):
+                            signature = block.get("signature")
+
+                # Store only the last signature in encrypted_content.
+                # If there are multiple thinking blocks, this could be a problem;
+                # in practice, there should only be one signature for the entire reasoning step.
+                # Tested with: claude-sonnet-4-20250514
+                if signature:
+                    reasoning_item.encrypted_content = signature
 
             items.append(reasoning_item)
 
@@ -301,10 +315,18 @@ def extract_all_content(
     def items_to_messages(
         cls,
         items: str | Iterable[TResponseInputItem],
+        preserve_thinking_blocks: bool = False,
     ) -> list[ChatCompletionMessageParam]:
         """
         Convert a sequence of 'Item' objects into a list of ChatCompletionMessageParam.
 
+        Args:
+            items: A string or iterable of response input items to convert.
+            preserve_thinking_blocks: Whether to preserve thinking blocks in tool calls
+                for reasoning models like Claude 4 Sonnet/Opus, which support interleaved
+                thinking. When True, thinking blocks are reconstructed and included in
+                assistant messages with tool calls.
+
         Rules:
         - EasyInputMessage or InputMessage (role=user) => ChatCompletionUserMessageParam
         - EasyInputMessage or InputMessage (role=system) => ChatCompletionSystemMessageParam
@@ -325,6 +347,7 @@ def items_to_messages(
 
         result: list[ChatCompletionMessageParam] = []
         current_assistant_msg: ChatCompletionAssistantMessageParam | None = None
+        pending_thinking_blocks: list[dict[str, str]] | None = None
 
         def flush_assistant_message() -> None:
             nonlocal current_assistant_msg
@@ -336,10 +359,11 @@ def flush_assistant_message() -> None:
             current_assistant_msg = None
 
         def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
-            nonlocal current_assistant_msg
+            nonlocal current_assistant_msg, pending_thinking_blocks
             if current_assistant_msg is None:
                 current_assistant_msg = ChatCompletionAssistantMessageParam(role="assistant")
                 current_assistant_msg["tool_calls"] = []
+
             return current_assistant_msg
 
         for item in items:
@@ -455,6 +479,13 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
 
             elif func_call := cls.maybe_function_tool_call(item):
                 asst = ensure_assistant_message()
+
+                # If we have pending thinking blocks, use them as the content.
+                # This is required for Anthropic API tool calls with interleaved thinking.
+                if pending_thinking_blocks:
+                    asst["content"] = pending_thinking_blocks  # type: ignore
+                    pending_thinking_blocks = None  # Clear after use
+
                 tool_calls = list(asst.get("tool_calls", []))
                 arguments = func_call["arguments"] if func_call["arguments"] else "{}"
                 new_tool_call = ChatCompletionMessageFunctionToolCallParam(
@@ -483,9 +514,28 @@ def ensure_assistant_message() -> ChatCompletionAssistantMessageParam:
                         f"Encountered an item_reference, which is not supported: {item_ref}"
                     )
 
-            # 7) reasoning message => not handled
-            elif cls.maybe_reasoning_message(item):
-                pass
+            # 7) reasoning message => extract thinking blocks if present
+            elif reasoning_item := cls.maybe_reasoning_message(item):
+                # Reconstruct thinking blocks from content (text) and encrypted_content (signature)
+                content_items = reasoning_item.get("content", [])
+                signature = reasoning_item.get("encrypted_content")
+
+                if content_items and preserve_thinking_blocks:
+                    # Reconstruct thinking blocks from content and signature
+                    pending_thinking_blocks = []
+                    for content_item in content_items:
+                        if (
+                            isinstance(content_item, dict)
+                            and content_item.get("type") == "reasoning_text"
+                        ):
+                            thinking_block = {
+                                "type": "thinking",
+                                "thinking": content_item.get("text", ""),
+                            }
+                            
# Add signature if available + if signature: + thinking_block["signature"] = signature + pending_thinking_blocks.append(thinking_block) # 8) If we haven't recognized it => fail or ignore else: diff --git a/src/agents/models/chatcmpl_stream_handler.py b/src/agents/models/chatcmpl_stream_handler.py index 359d47bb5..474bffe09 100644 --- a/src/agents/models/chatcmpl_stream_handler.py +++ b/src/agents/models/chatcmpl_stream_handler.py @@ -62,6 +62,9 @@ class StreamingState: # Fields for real-time function call streaming function_call_streaming: dict[int, bool] = field(default_factory=dict) function_call_output_idx: dict[int, int] = field(default_factory=dict) + # Store accumulated thinking text and signature for Anthropic compatibility + thinking_text: str = "" + thinking_signature: str | None = None class SequenceNumber: @@ -101,6 +104,19 @@ async def handle_stream( delta = chunk.choices[0].delta + # Handle thinking blocks from Anthropic (for preserving signatures) + if hasattr(delta, "thinking_blocks") and delta.thinking_blocks: + for block in delta.thinking_blocks: + if isinstance(block, dict): + # Accumulate thinking text + thinking_text = block.get("thinking", "") + if thinking_text: + state.thinking_text += thinking_text + # Store signature if present + signature = block.get("signature") + if signature: + state.thinking_signature = signature + # Handle reasoning content for reasoning summaries if hasattr(delta, "reasoning_content"): reasoning_content = delta.reasoning_content @@ -527,7 +543,19 @@ async def handle_stream( # include Reasoning item if it exists if state.reasoning_content_index_and_output: - outputs.append(state.reasoning_content_index_and_output[1]) + reasoning_item = state.reasoning_content_index_and_output[1] + # Store thinking text in content and signature in encrypted_content + if state.thinking_text: + # Add thinking text as a Content object + if not reasoning_item.content: + reasoning_item.content = [] + reasoning_item.content.append( + Content(text=state.thinking_text, type="reasoning_text") + ) + # Store signature in encrypted_content + if state.thinking_signature: + reasoning_item.encrypted_content = state.thinking_signature + outputs.append(reasoning_item) # include text or refusal content if they exist if state.text_content_index_and_output or state.refusal_content_index_and_output: diff --git a/src/agents/realtime/__init__.py b/src/agents/realtime/__init__.py index 7675c466f..3f0793fa1 100644 --- a/src/agents/realtime/__init__.py +++ b/src/agents/realtime/__init__.py @@ -3,6 +3,7 @@ RealtimeAudioFormat, RealtimeClientMessage, RealtimeGuardrailsSettings, + RealtimeInputAudioNoiseReductionConfig, RealtimeInputAudioTranscriptionConfig, RealtimeModelName, RealtimeModelTracingConfig, @@ -101,6 +102,7 @@ "RealtimeAudioFormat", "RealtimeClientMessage", "RealtimeGuardrailsSettings", + "RealtimeInputAudioNoiseReductionConfig", "RealtimeInputAudioTranscriptionConfig", "RealtimeModelName", "RealtimeModelTracingConfig", diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index 8b70c872f..ddbf48bab 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -61,6 +61,13 @@ class RealtimeInputAudioTranscriptionConfig(TypedDict): """An optional prompt to guide transcription.""" +class RealtimeInputAudioNoiseReductionConfig(TypedDict): + """Noise reduction configuration for input audio.""" + + type: NotRequired[Literal["near_field", "far_field"]] + """Noise reduction mode to apply to input audio.""" + + class 
RealtimeTurnDetectionConfig(TypedDict): """Turn detection config. Allows extra vendor keys if needed.""" @@ -119,6 +126,9 @@ class RealtimeSessionModelSettings(TypedDict): input_audio_transcription: NotRequired[RealtimeInputAudioTranscriptionConfig] """Configuration for transcribing input audio.""" + input_audio_noise_reduction: NotRequired[RealtimeInputAudioNoiseReductionConfig | None] + """Noise reduction configuration for input audio.""" + turn_detection: NotRequired[RealtimeTurnDetectionConfig] """Configuration for detecting conversation turns.""" diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 4d6cf398c..50aaf3c4b 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -825,14 +825,24 @@ def _get_session_config( "output_audio_format", DEFAULT_MODEL_SETTINGS.get("output_audio_format"), ) + input_audio_noise_reduction = model_settings.get( + "input_audio_noise_reduction", + DEFAULT_MODEL_SETTINGS.get("input_audio_noise_reduction"), + ) input_audio_config = None if any( value is not None - for value in [input_audio_format, input_audio_transcription, turn_detection] + for value in [ + input_audio_format, + input_audio_noise_reduction, + input_audio_transcription, + turn_detection, + ] ): input_audio_config = OpenAIRealtimeAudioInput( format=to_realtime_audio_format(input_audio_format), + noise_reduction=cast(Any, input_audio_noise_reduction), transcription=cast(Any, input_audio_transcription), turn_detection=cast(Any, turn_detection), ) diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index 34352df44..29b6fbd9a 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -1,3 +1,4 @@ +import json from typing import Any, cast from unittest.mock import AsyncMock, Mock, patch @@ -96,6 +97,88 @@ def mock_create_task_func(coro): assert model._websocket_task is not None assert model.model == "gpt-4o-realtime-preview" + @pytest.mark.asyncio + async def test_session_update_includes_noise_reduction(self, model, mock_websocket): + """Session.update should pass through input_audio_noise_reduction config.""" + config = { + "api_key": "test-api-key-123", + "initial_model_settings": { + "model_name": "gpt-4o-realtime-preview", + "input_audio_noise_reduction": {"type": "near_field"}, + }, + } + + sent_messages: list[dict[str, Any]] = [] + + async def async_websocket(*args, **kwargs): + async def send(payload: str): + sent_messages.append(json.loads(payload)) + return None + + mock_websocket.send.side_effect = send + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + + def mock_create_task_func(coro): + coro.close() + return mock_task + + mock_create_task.side_effect = mock_create_task_func + await model.connect(config) + + # Find the session.update events + session_updates = [m for m in sent_messages if m.get("type") == "session.update"] + assert len(session_updates) >= 1 + # Verify the last session.update contains the noise_reduction field + session = session_updates[-1]["session"] + assert session.get("audio", {}).get("input", {}).get("noise_reduction") == { + "type": "near_field" + } + + @pytest.mark.asyncio + async def test_session_update_omits_noise_reduction_when_not_provided( + self, model, mock_websocket + ): + """Session.update should omit input_audio_noise_reduction when not provided.""" + config = 
{
+            "api_key": "test-api-key-123",
+            "initial_model_settings": {
+                "model_name": "gpt-4o-realtime-preview",
+            },
+        }
+
+        sent_messages: list[dict[str, Any]] = []
+
+        async def async_websocket(*args, **kwargs):
+            async def send(payload: str):
+                sent_messages.append(json.loads(payload))
+                return None
+
+            mock_websocket.send.side_effect = send
+            return mock_websocket
+
+        with patch("websockets.connect", side_effect=async_websocket):
+            with patch("asyncio.create_task") as mock_create_task:
+                mock_task = AsyncMock()
+
+                def mock_create_task_func(coro):
+                    coro.close()
+                    return mock_task
+
+                mock_create_task.side_effect = mock_create_task_func
+                await model.connect(config)
+
+        # Find the session.update events
+        session_updates = [m for m in sent_messages if m.get("type") == "session.update"]
+        assert len(session_updates) >= 1
+        # Verify the last session.update omits the noise_reduction field
+        session = session_updates[-1]["session"]
+        assert "audio" in session and "input" in session["audio"]
+        assert "noise_reduction" not in session["audio"]["input"]
+
     @pytest.mark.asyncio
     async def test_connect_with_custom_headers_overrides_defaults(self, model, mock_websocket):
         """If custom headers are provided, use them verbatim without adding defaults."""
diff --git a/tests/test_anthropic_thinking_blocks.py b/tests/test_anthropic_thinking_blocks.py
index 9513c7833..933be2c0e 100644
--- a/tests/test_anthropic_thinking_blocks.py
+++ b/tests/test_anthropic_thinking_blocks.py
@@ -10,7 +10,10 @@
 
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, cast
+
+from openai.types.chat import ChatCompletionMessageToolCall
+from openai.types.chat.chat_completion_message_tool_call import Function
 
 from agents.extensions.models.litellm_model import InternalChatCompletionMessage
 from agents.models.chatcmpl_converter import Converter
@@ -99,3 +102,115 @@ def test_reasoning_items_preserved_in_message_conversion():
     thinking_block = reasoning_item.content[0]
     assert thinking_block.type == "reasoning_text"
     assert thinking_block.text == "I need to call the weather function for Paris"
+
+
+def test_anthropic_thinking_blocks_with_tool_calls():
+    """
+    Test interleaved thinking with tool calls for models with extended thinking.
+
+    This test verifies the Anthropic API's requirement that thinking blocks
+    come first in assistant message content when reasoning is enabled and tool
+    calls are present.
+    """
+    # Create a message with reasoning, thinking blocks and tool calls
+    message = InternalChatCompletionMessage(
+        role="assistant",
+        content="I'll check the weather for you.",
+        reasoning_content="The user wants weather information, I need to call the weather function",
+        thinking_blocks=[
+            {
+                "type": "thinking",
+                "thinking": (
+                    "The user is asking about weather. "
+                    "Let me use the weather tool to get this information."
+ ), + "signature": "TestSignature123", + } + ], + tool_calls=[ + ChatCompletionMessageToolCall( + id="call_123", + type="function", + function=Function(name="get_weather", arguments='{"city": "Tokyo"}'), + ) + ], + ) + + # Step 1: Convert message to output items + output_items = Converter.message_to_output_items(message) + + # Verify reasoning item exists and contains thinking blocks + reasoning_items = [ + item for item in output_items if hasattr(item, "type") and item.type == "reasoning" + ] + assert len(reasoning_items) == 1, "Should have exactly one reasoning item" + + reasoning_item = reasoning_items[0] + + # Verify thinking text is stored in content + assert hasattr(reasoning_item, "content") and reasoning_item.content, ( + "Reasoning item should have content" + ) + assert reasoning_item.content[0].type == "reasoning_text", ( + "Content should be reasoning_text type" + ) + + # Verify signature is stored in encrypted_content + assert hasattr(reasoning_item, "encrypted_content"), ( + "Reasoning item should have encrypted_content" + ) + assert reasoning_item.encrypted_content == "TestSignature123", "Signature should be preserved" + + # Verify tool calls are present + tool_call_items = [ + item for item in output_items if hasattr(item, "type") and item.type == "function_call" + ] + assert len(tool_call_items) == 1, "Should have exactly one tool call" + + # Step 2: Convert output items back to messages + # Convert items to dicts for the converter (simulating serialization/deserialization) + items_as_dicts: list[dict[str, Any]] = [] + for item in output_items: + if hasattr(item, "model_dump"): + items_as_dicts.append(item.model_dump()) + else: + items_as_dicts.append(cast(dict[str, Any], item)) + + messages = Converter.items_to_messages(items_as_dicts, preserve_thinking_blocks=True) # type: ignore[arg-type] + + # Find the assistant message with tool calls + assistant_messages = [ + msg for msg in messages if msg.get("role") == "assistant" and msg.get("tool_calls") + ] + assert len(assistant_messages) == 1, "Should have exactly one assistant message with tool calls" + + assistant_msg = assistant_messages[0] + + # Content must start with thinking blocks, not text + content = assistant_msg.get("content") + assert content is not None, "Assistant message should have content" + + assert isinstance(content, list) and len(content) > 0, ( + "Assistant message content should be a non-empty list" + ) + + first_content = content[0] + assert first_content.get("type") == "thinking", ( + f"First content must be 'thinking' type for Anthropic compatibility, " + f"but got '{first_content.get('type')}'" + ) + expected_thinking = ( + "The user is asking about weather. Let me use the weather tool to get this information." + ) + assert first_content.get("thinking") == expected_thinking, ( + "Thinking content should be preserved" + ) + # Signature should also be preserved + assert first_content.get("signature") == "TestSignature123", ( + "Signature should be preserved in thinking block" + ) + + # Verify tool calls are preserved + tool_calls = assistant_msg.get("tool_calls", []) + assert len(cast(list[Any], tool_calls)) == 1, "Tool calls should be preserved" + assert cast(list[Any], tool_calls)[0]["function"]["name"] == "get_weather"
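
A minimal usage sketch of the two features in this patch, assuming the SDK's existing public `Agent`/`ModelSettings` surface plus the realtime TypedDicts added above; the agent name and both model ids below are illustrative placeholders, not part of the patch:

    from openai.types.shared import Reasoning

    from agents import Agent, ModelSettings
    from agents.extensions.models.litellm_model import LitellmModel
    from agents.realtime import RealtimeInputAudioNoiseReductionConfig
    from agents.realtime.config import RealtimeSessionModelSettings

    # Interleaved thinking: a non-None reasoning effort is what makes
    # LitellmModel._fetch_response pass preserve_thinking_blocks=True to
    # Converter.items_to_messages, so thinking blocks (text plus signature)
    # survive tool-call round trips to Anthropic models.
    agent = Agent(
        name="weather-assistant",  # placeholder
        model=LitellmModel(model="anthropic/claude-sonnet-4-20250514"),
        model_settings=ModelSettings(reasoning=Reasoning(effort="high")),
    )

    # Noise reduction: the new TypedDict slots into RealtimeSessionModelSettings
    # and is emitted under session.audio.input.noise_reduction in session.update.
    noise_reduction: RealtimeInputAudioNoiseReductionConfig = {"type": "near_field"}
    session_settings: RealtimeSessionModelSettings = {
        "model_name": "gpt-4o-realtime-preview",
        "input_audio_noise_reduction": noise_reduction,
    }

With the reasoning effort set, Converter.items_to_messages reconstructs any pending thinking blocks and places them ahead of the tool calls in the assistant message content, which is the ordering the Anthropic API requires.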