diff --git a/docs/ref/realtime/config.md b/docs/ref/realtime/config.md index 3e50f47ad..2445c6a34 100644 --- a/docs/ref/realtime/config.md +++ b/docs/ref/realtime/config.md @@ -11,6 +11,7 @@ ## Audio Configuration ::: agents.realtime.config.RealtimeInputAudioTranscriptionConfig +::: agents.realtime.config.RealtimeInputAudioNoiseReductionConfig ::: agents.realtime.config.RealtimeTurnDetectionConfig ## Guardrails Settings diff --git a/src/agents/realtime/__init__.py b/src/agents/realtime/__init__.py index 7675c466f..3f0793fa1 100644 --- a/src/agents/realtime/__init__.py +++ b/src/agents/realtime/__init__.py @@ -3,6 +3,7 @@ RealtimeAudioFormat, RealtimeClientMessage, RealtimeGuardrailsSettings, + RealtimeInputAudioNoiseReductionConfig, RealtimeInputAudioTranscriptionConfig, RealtimeModelName, RealtimeModelTracingConfig, @@ -101,6 +102,7 @@ "RealtimeAudioFormat", "RealtimeClientMessage", "RealtimeGuardrailsSettings", + "RealtimeInputAudioNoiseReductionConfig", "RealtimeInputAudioTranscriptionConfig", "RealtimeModelName", "RealtimeModelTracingConfig", diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index 8b70c872f..ddbf48bab 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -61,6 +61,13 @@ class RealtimeInputAudioTranscriptionConfig(TypedDict): """An optional prompt to guide transcription.""" +class RealtimeInputAudioNoiseReductionConfig(TypedDict): + """Noise reduction configuration for input audio.""" + + type: NotRequired[Literal["near_field", "far_field"]] + """Noise reduction mode to apply to input audio.""" + + class RealtimeTurnDetectionConfig(TypedDict): """Turn detection config. Allows extra vendor keys if needed.""" @@ -119,6 +126,9 @@ class RealtimeSessionModelSettings(TypedDict): input_audio_transcription: NotRequired[RealtimeInputAudioTranscriptionConfig] """Configuration for transcribing input audio.""" + input_audio_noise_reduction: NotRequired[RealtimeInputAudioNoiseReductionConfig | None] + """Noise reduction configuration for input audio.""" + turn_detection: NotRequired[RealtimeTurnDetectionConfig] """Configuration for detecting conversation turns.""" diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index 4d6cf398c..50aaf3c4b 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -825,14 +825,24 @@ def _get_session_config( "output_audio_format", DEFAULT_MODEL_SETTINGS.get("output_audio_format"), ) + input_audio_noise_reduction = model_settings.get( + "input_audio_noise_reduction", + DEFAULT_MODEL_SETTINGS.get("input_audio_noise_reduction"), + ) input_audio_config = None if any( value is not None - for value in [input_audio_format, input_audio_transcription, turn_detection] + for value in [ + input_audio_format, + input_audio_noise_reduction, + input_audio_transcription, + turn_detection, + ] ): input_audio_config = OpenAIRealtimeAudioInput( format=to_realtime_audio_format(input_audio_format), + noise_reduction=cast(Any, input_audio_noise_reduction), transcription=cast(Any, input_audio_transcription), turn_detection=cast(Any, turn_detection), ) diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index 34352df44..29b6fbd9a 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -1,3 +1,4 @@ +import json from typing import Any, cast from unittest.mock import AsyncMock, Mock, patch @@ -96,6 +97,88 @@ def mock_create_task_func(coro): assert model._websocket_task is not None assert model.model == "gpt-4o-realtime-preview" + @pytest.mark.asyncio + async def test_session_update_includes_noise_reduction(self, model, mock_websocket): + """Session.update should pass through input_audio_noise_reduction config.""" + config = { + "api_key": "test-api-key-123", + "initial_model_settings": { + "model_name": "gpt-4o-realtime-preview", + "input_audio_noise_reduction": {"type": "near_field"}, + }, + } + + sent_messages: list[dict[str, Any]] = [] + + async def async_websocket(*args, **kwargs): + async def send(payload: str): + sent_messages.append(json.loads(payload)) + return None + + mock_websocket.send.side_effect = send + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + + def mock_create_task_func(coro): + coro.close() + return mock_task + + mock_create_task.side_effect = mock_create_task_func + await model.connect(config) + + # Find the session.update events + session_updates = [m for m in sent_messages if m.get("type") == "session.update"] + assert len(session_updates) >= 1 + # Verify the last session.update contains the noise_reduction field + session = session_updates[-1]["session"] + assert session.get("audio", {}).get("input", {}).get("noise_reduction") == { + "type": "near_field" + } + + @pytest.mark.asyncio + async def test_session_update_omits_noise_reduction_when_not_provided( + self, model, mock_websocket + ): + """Session.update should omit input_audio_noise_reduction when not provided.""" + config = { + "api_key": "test-api-key-123", + "initial_model_settings": { + "model_name": "gpt-4o-realtime-preview", + }, + } + + sent_messages: list[dict[str, Any]] = [] + + async def async_websocket(*args, **kwargs): + async def send(payload: str): + sent_messages.append(json.loads(payload)) + return None + + mock_websocket.send.side_effect = send + return mock_websocket + + with patch("websockets.connect", side_effect=async_websocket): + with patch("asyncio.create_task") as mock_create_task: + mock_task = AsyncMock() + + def mock_create_task_func(coro): + coro.close() + return mock_task + + mock_create_task.side_effect = mock_create_task_func + await model.connect(config) + + # Find the session.update events + session_updates = [m for m in sent_messages if m.get("type") == "session.update"] + assert len(session_updates) >= 1 + # Verify the last session.update omits the noise_reduction field + session = session_updates[-1]["session"] + assert "audio" in session and "input" in session["audio"] + assert "noise_reduction" not in session["audio"]["input"] + @pytest.mark.asyncio async def test_connect_with_custom_headers_overrides_defaults(self, model, mock_websocket): """If custom headers are provided, use them verbatim without adding defaults."""