diff --git a/src/agents/voice/input.py b/src/agents/voice/input.py index 8613d27ac..8cbc8b735 100644 --- a/src/agents/voice/input.py +++ b/src/agents/voice/input.py @@ -77,12 +77,13 @@ class StreamedAudioInput: """ def __init__(self): - self.queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = asyncio.Queue() + self.queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32] | None] = asyncio.Queue() - async def add_audio(self, audio: npt.NDArray[np.int16 | np.float32]): + async def add_audio(self, audio: npt.NDArray[np.int16 | np.float32] | None): """Adds more audio data to the stream. Args: - audio: The audio data to add. Must be a numpy array of int16 or float32. + audio: The audio data to add. Must be a numpy array of int16 or float32, or None. + If None is passed, it indicates the end of the stream. """ await self.queue.put(audio) diff --git a/src/agents/voice/models/openai_stt.py b/src/agents/voice/models/openai_stt.py index 733406f04..19e91d9be 100644 --- a/src/agents/voice/models/openai_stt.py +++ b/src/agents/voice/models/openai_stt.py @@ -88,7 +88,7 @@ def __init__( self._trace_include_sensitive_data = trace_include_sensitive_data self._trace_include_sensitive_audio_data = trace_include_sensitive_audio_data - self._input_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] = input.queue + self._input_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32] | None] = input.queue self._output_queue: asyncio.Queue[str | ErrorSentinel | SessionCompleteSentinel] = ( asyncio.Queue() ) @@ -245,7 +245,7 @@ async def _handle_events(self) -> None: await self._output_queue.put(SessionCompleteSentinel()) async def _stream_audio( - self, audio_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32]] + self, audio_queue: asyncio.Queue[npt.NDArray[np.int16 | np.float32] | None] ) -> None: assert self._websocket is not None, "Websocket not initialized" self._start_turn() diff --git a/tests/voice/test_input.py b/tests/voice/test_input.py index d41d870d7..fbef84c1b 100644 
--- a/tests/voice/test_input.py +++ b/tests/voice/test_input.py @@ -121,7 +121,14 @@ async def test_streamed_audio_input(self): # Verify the queue contents assert streamed_input.queue.qsize() == 2 # Test non-blocking get - assert np.array_equal(streamed_input.queue.get_nowait(), audio1) + retrieved_audio1 = streamed_input.queue.get_nowait() + # Satisfy type checker + assert retrieved_audio1 is not None + assert np.array_equal(retrieved_audio1, audio1) + # Test blocking get - assert np.array_equal(await streamed_input.queue.get(), audio2) + retrieved_audio2 = await streamed_input.queue.get() + # Satisfy type checker + assert retrieved_audio2 is not None + assert np.array_equal(retrieved_audio2, audio2) assert streamed_input.queue.empty()