haystack/components/generators/chat/openai.py (17 additions & 7 deletions)

@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import asyncio
 import json
 import os
 from datetime import datetime
@@ -505,13 +506,22 @@ async def _handle_async_stream_response(
     ) -> list[ChatMessage]:
         component_info = ComponentInfo.from_component(self)
         chunks: list[StreamingChunk] = []
-        async for chunk in chat_completion:  # pylint: disable=not-an-iterable
-            assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
-            chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(
-                chunk=chunk, previous_chunks=chunks, component_info=component_info
-            )
-            chunks.append(chunk_delta)
-            await callback(chunk_delta)
+        try:
+            async for chunk in chat_completion:  # pylint: disable=not-an-iterable
+                assert len(chunk.choices) <= 1, "Streaming responses should have at most one choice."
+                chunk_delta = _convert_chat_completion_chunk_to_streaming_chunk(
+                    chunk=chunk, previous_chunks=chunks, component_info=component_info
+                )
+                chunks.append(chunk_delta)
+                await callback(chunk_delta)
+        except asyncio.CancelledError:
+            # Close the stream when the task is cancelled. asyncio.shield ensures
+            # the close operation completes even if cancellation is requested again:
+            # https://docs.python.org/3/library/asyncio-task.html#shielding-from-cancellation
+            await asyncio.shield(chat_completion.close())
+            raise  # Re-raise to propagate the cancellation
 
         return [_convert_streaming_chunks_to_chat_message(chunks=chunks)]


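For context, the pattern in this hunk can be exercised in isolation. The sketch below is a minimal, self-contained illustration of shielding cleanup from cancellation, assuming hypothetical helpers: _chunks and close_stream stand in for the OpenAI AsyncStream's iterator and close() coroutine and are not real Haystack or OpenAI APIs.

import asyncio

async def close_stream() -> None:
    # Hypothetical stand-in for AsyncStream.close(): cleanup that must
    # finish even if the surrounding task is cancelled.
    await asyncio.sleep(0.1)
    print("stream closed")

async def _chunks():
    # Hypothetical stand-in for the chat completion chunk stream.
    for i in range(100):
        await asyncio.sleep(0.01)
        yield i

async def consume() -> None:
    try:
        async for _ in _chunks():
            pass
    except asyncio.CancelledError:
        # shield() lets close_stream() run to completion even if another
        # cancellation request arrives while we are awaiting it.
        await asyncio.shield(close_stream())
        raise  # propagate the cancellation to the caller

async def main() -> None:
    task = asyncio.create_task(consume())
    await asyncio.sleep(0.05)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("task cancelled")  # close_stream() has already completed

asyncio.run(main())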
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    The ``_handle_async_stream_response()`` method in ``OpenAIChatGenerator`` now handles ``asyncio.CancelledError``. When a streaming task is cancelled mid-stream, the generator closes the underlying stream, using ``asyncio.shield()`` to ensure the cleanup completes even during cancellation.
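A caller-side sketch of the behavior this note describes: cancelling a streaming run_async() call now also closes the underlying stream. This assumes a valid OPENAI_API_KEY in the environment; the model name, prompt, and timings are illustrative, not prescribed by the change.

import asyncio

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

async def main() -> None:
    generator = OpenAIChatGenerator(model="gpt-4")
    received: list[StreamingChunk] = []

    async def on_chunk(chunk: StreamingChunk) -> None:
        received.append(chunk)

    task = asyncio.create_task(
        generator.run_async(
            messages=[ChatMessage.from_user("Tell me a long story.")],
            streaming_callback=on_chunk,
        )
    )
    await asyncio.sleep(2.0)
    task.cancel()  # the generator's CancelledError handler closes the stream

    try:
        await task
    except asyncio.CancelledError:
        print(f"cancelled after {len(received)} chunks; stream closed cleanly")

asyncio.run(main())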
test/components/generators/chat/test_openai_async.py (81 additions & 1 deletion)

@@ -2,12 +2,13 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import asyncio
 import os
 from datetime import datetime
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
-from openai import AsyncOpenAI, OpenAIError
+from openai import AsyncOpenAI, AsyncStream, OpenAIError
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionChunk,
@@ -282,6 +283,85 @@ async def streaming_callback(chunk: StreamingChunk) -> None:
         assert tool_call.arguments == {"city": "Paris"}
         assert message.meta["finish_reason"] == "tool_calls"
 
+    @pytest.mark.asyncio
+    async def test_async_stream_closes_on_cancellation(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        generator = OpenAIChatGenerator(
+            api_key=Secret.from_token("test-api-key"),
+            api_base_url="test-base-url",
+            organization="test-organization",
+            timeout=30,
+            max_retries=5,
+        )
+
+        # mock the async stream that is passed to _handle_async_stream_response()
+        mock_stream = AsyncMock(spec=AsyncStream)
+        mock_stream.close = AsyncMock()
+
+        async def mock_chunk_generator():
+            for i in range(100):
+                yield MagicMock(
+                    choices=[
+                        MagicMock(
+                            index=0,
+                            delta=MagicMock(content=f"chunk{i}", role=None, tool_calls=None),
+                            finish_reason=None,
+                            logprobs=None,
+                        )
+                    ],
+                    model="gpt-4",
+                    usage=None,
+                )
+                await asyncio.sleep(0.01)  # delay between chunks
+
+        mock_stream.__aiter__ = lambda self: mock_chunk_generator()
+
+        received_chunks = []
+
+        async def test_callback(chunk: StreamingChunk):
+            received_chunks.append(chunk)
+
+        # the task that will be cancelled
+        task = asyncio.create_task(generator._handle_async_stream_response(mock_stream, test_callback))
+
+        # start the task, let it process a few chunks, then cancel
+        await asyncio.sleep(0.05)
+        task.cancel()
+
+        with pytest.raises(asyncio.CancelledError):
+            await task
+
+        mock_stream.close.assert_awaited_once()
+
+        # some chunks arrived before cancellation, but not all of them
+        assert len(received_chunks) > 0
+        assert len(received_chunks) < 100
+
+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY", None),
+        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    )
+    @pytest.mark.integration
+    @pytest.mark.asyncio
+    async def test_run_async_cancellation_integration(self):
+        generator = OpenAIChatGenerator(model="gpt-4")
+        messages = [ChatMessage.from_user("Write me an essay about the history of jazz music, at least 500 words.")]
+        received_chunks = []
+
+        async def streaming_callback(chunk: StreamingChunk):
+            received_chunks.append(chunk)
+
+        task = asyncio.create_task(generator.run_async(messages=messages, streaming_callback=streaming_callback))
+
+        await asyncio.sleep(2.0)
+        task.cancel()
+
+        with pytest.raises(asyncio.CancelledError):
+            await task
+
+        assert len(received_chunks) > 0
+        assert len(received_chunks) < 500
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
         reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
test/components/generators/chat/test_openai_responses.py (2 additions & 1 deletion)

@@ -780,7 +780,8 @@ def warm_up(self):
         component.warm_up()
         assert len(warm_up_calls) == call_count
 
-    def test_run(self, openai_mock_responses):
+    def test_run(self, openai_mock_responses, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         chat_messages = [ChatMessage.from_user("What's the capital of France")]
         component = OpenAIResponsesChatGenerator(
             model="gpt-4", generation_kwargs={"include": ["message.output_text.logprobs"]}