diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index cd227e7aba..11c51b958f 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import asyncio import json import os from datetime import datetime @@ -505,13 +506,22 @@ async def _handle_async_stream_response( ) -> list[ChatMessage]: component_info = ComponentInfo.from_component(self) chunks: list[StreamingChunk] = [] - async for chunk in chat_completion: # pylint: disable=not-an-iterable - assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice." - chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk( - chunk=chunk, previous_chunks=chunks, component_info=component_info - ) - chunks.append(chunk_delta) - await callback(chunk_delta) + try: + async for chunk in chat_completion: # pylint: disable=not-an-iterable + assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice." 
+ chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk( + chunk=chunk, previous_chunks=chunks, component_info=component_info + ) + chunks.append(chunk_delta) + await callback(chunk_delta) + + except asyncio.CancelledError: + await asyncio.shield(chat_completion.close()) + # close the stream when task is cancelled + # asyncio.shield ensures the close operation completes + # https://docs.python.org/3/library/asyncio-task.html#shielding-from-cancellation + raise # Re-raise to propagate cancellation + return [_convert_streaming_chunks_to_chat_message(chunks=chunks)] diff --git a/releasenotes/notes/fix-OpenAIChatGenerator-handles-asyncio.CancelledError-closing-response-stream-37e1e85255e1dc41.yaml b/releasenotes/notes/fix-OpenAIChatGenerator-handles-asyncio.CancelledError-closing-response-stream-37e1e85255e1dc41.yaml new file mode 100644 index 0000000000..4484f8fe59 --- /dev/null +++ b/releasenotes/notes/fix-OpenAIChatGenerator-handles-asyncio.CancelledError-closing-response-stream-37e1e85255e1dc41.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + The ``_handle_async_stream_response()`` method in ``OpenAIChatGenerator`` now handles ``asyncio.CancelledError`` exceptions. When a streaming task is cancelled mid-stream, the response stream is now closed gracefully, using ``asyncio.shield()`` to ensure the cleanup operation completes even during cancellation. 
diff --git a/test/components/generators/chat/test_openai_async.py b/test/components/generators/chat/test_openai_async.py index 50da9777f9..a2203bb450 100644 --- a/test/components/generators/chat/test_openai_async.py +++ b/test/components/generators/chat/test_openai_async.py @@ -2,12 +2,13 @@ # # SPDX-License-Identifier: Apache-2.0 +import asyncio import os from datetime import datetime from unittest.mock import AsyncMock, MagicMock, patch import pytest -from openai import AsyncOpenAI, OpenAIError +from openai import AsyncOpenAI, AsyncStream, OpenAIError from openai.types.chat import ( ChatCompletion, ChatCompletionChunk, @@ -282,6 +283,85 @@ async def streaming_callback(chunk: StreamingChunk) -> None: assert tool_call.arguments == {"city": "Paris"} assert message.meta["finish_reason"] == "tool_calls" + @pytest.mark.asyncio + async def test_async_stream_closes_on_cancellation(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + generator = OpenAIChatGenerator( + api_key=Secret.from_token("test-api-key"), + api_base_url="test-base-url", + organization="test-organization", + timeout=30, + max_retries=5, + ) + + # mocked the async stream that will be passed to the _handle_async_stream_response() method + mock_stream = AsyncMock(spec=AsyncStream) + mock_stream.close = AsyncMock() + + async def mock_chunk_generator(): + for i in range(100): + yield MagicMock( + choices=[ + MagicMock( + index=0, + delta=MagicMock(content=f"chunk{i}", role=None, tool_calls=None), + finish_reason=None, + logprobs=None, + ) + ], + model="gpt-4", + usage=None, + ) + await asyncio.sleep(0.01) # delay between chunks + + mock_stream.__aiter__ = lambda self: mock_chunk_generator() + + received_chunks = [] + + async def test_callback(chunk: StreamingChunk): + received_chunks.append(chunk) + + # the task that will be cancelled + task = asyncio.create_task(generator._handle_async_stream_response(mock_stream, test_callback)) + + # trigger the task, process a few chunks, then 
cancel + await asyncio.sleep(0.05) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + mock_stream.close.assert_awaited_once() + + # we received some chunks before cancellation but not all of them + assert len(received_chunks) > 0 + assert len(received_chunks) < 100 + + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", + ) + @pytest.mark.integration + @pytest.mark.asyncio + async def test_run_async_cancellation_integration(self): + generator = OpenAIChatGenerator(model="gpt-4") + messages = [ChatMessage.from_user("Write me an essay about the history of jazz music, at least 500 words.")] + received_chunks = [] + + async def streaming_callback(chunk: StreamingChunk): + received_chunks.append(chunk) + + task = asyncio.create_task(generator.run_async(messages=messages, streaming_callback=streaming_callback)) + + await asyncio.sleep(2.0) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert len(received_chunks) > 0 + assert len(received_chunks) < 500 + @pytest.mark.skipif( not os.environ.get("OPENAI_API_KEY", None), reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", diff --git a/test/components/generators/chat/test_openai_responses.py b/test/components/generators/chat/test_openai_responses.py index ca2ccc2c17..fae59f18d1 100644 --- a/test/components/generators/chat/test_openai_responses.py +++ b/test/components/generators/chat/test_openai_responses.py @@ -780,7 +780,8 @@ def warm_up(self): component.warm_up() assert len(warm_up_calls) == call_count - def test_run(self, openai_mock_responses): + def test_run(self, openai_mock_responses, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") chat_messages = [ChatMessage.from_user("What's the capital of France")] component = OpenAIResponsesChatGenerator( model="gpt-4", 
generation_kwargs={"include": ["message.output_text.logprobs"]}