Skip to content

Commit e7ee4da

Browse files
authored
Several fixes to make Completion.acreate(stream=True) work (#172)
* Added a failing test case for async completion stream * Consume async generator with async for * Consume the stream in chunks as sent by API, to avoid "empty" parts The api will send chunks like ``` b'data: {"id": "cmpl-6W18L0k1kFoHUoSsJOwcPq7DKBaGX", "object": "text_completion", "created": 1673088873, "choices": [{"text": "_", "index": 0, "logprobs": null, "finish_reason": null}], "model": "ada"}\n\n' ``` The default iterator will break on each `\n` character, whereas iter_chunks will just output parts as they arrive * Add another test using global aiosession * Manually consume aiohttp_session asyncontextmanager to ensure that session is only closed once the response stream is finished Previously we'd exit the with statement before the response stream is consumed by the caller, therefore, unless we're using a global ClientSession, the session is closed (and thus the request) before it should be. * Ensure we close the session even if the caller raises an exception while consuming the stream
1 parent 71cee6a commit e7ee4da

File tree

3 files changed

+56
-17
lines changed

3 files changed

+56
-17
lines changed

openai/api_requestor.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ def _make_session() -> requests.Session:
8989
return s
9090

9191

92-
def parse_stream_helper(line):
92+
def parse_stream_helper(line: bytes):
9393
if line:
94-
if line == b"data: [DONE]":
94+
if line.strip() == b"data: [DONE]":
9595
# return here will cause GeneratorExit exception in urllib3
9696
# and it will close http connection with TCP Reset
9797
return None
@@ -111,7 +111,7 @@ def parse_stream(rbody):
111111

112112

113113
async def parse_stream_async(rbody: aiohttp.StreamReader):
114-
async for line in rbody:
114+
async for line, _ in rbody.iter_chunks():
115115
_line = parse_stream_helper(line)
116116
if _line is not None:
117117
yield _line
@@ -294,18 +294,31 @@ async def arequest(
294294
request_id: Optional[str] = None,
295295
request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
296296
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
297-
async with aiohttp_session() as session:
298-
result = await self.arequest_raw(
299-
method.lower(),
300-
url,
301-
session,
302-
params=params,
303-
supplied_headers=headers,
304-
files=files,
305-
request_id=request_id,
306-
request_timeout=request_timeout,
307-
)
308-
resp, got_stream = await self._interpret_async_response(result, stream)
297+
ctx = aiohttp_session()
298+
session = await ctx.__aenter__()
299+
result = await self.arequest_raw(
300+
method.lower(),
301+
url,
302+
session,
303+
params=params,
304+
supplied_headers=headers,
305+
files=files,
306+
request_id=request_id,
307+
request_timeout=request_timeout,
308+
)
309+
resp, got_stream = await self._interpret_async_response(result, stream)
310+
if got_stream:
311+
312+
async def wrap_resp():
313+
try:
314+
async for r in resp:
315+
yield r
316+
finally:
317+
await ctx.__aexit__(None, None, None)
318+
319+
return wrap_resp(), got_stream, self.api_key
320+
else:
321+
await ctx.__aexit__(None, None, None)
309322
return resp, got_stream, self.api_key
310323

311324
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
@@ -507,7 +520,9 @@ def request_raw(
507520
except requests.exceptions.Timeout as e:
508521
raise error.Timeout("Request timed out: {}".format(e)) from e
509522
except requests.exceptions.RequestException as e:
510-
raise error.APIConnectionError("Error communicating with OpenAI: {}".format(e)) from e
523+
raise error.APIConnectionError(
524+
"Error communicating with OpenAI: {}".format(e)
525+
) from e
511526
util.log_info(
512527
"OpenAI API response",
513528
path=abs_url,

openai/api_resources/abstract/engine_api_resource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ async def acreate(
236236
engine=engine,
237237
plain_old_data=cls.plain_old_data,
238238
)
239-
for line in response
239+
async for line in response
240240
)
241241
else:
242242
obj = util.convert_to_openai_object(

openai/tests/asyncio/test_endpoints.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import openai
77
from openai import error
8+
from aiohttp import ClientSession
89

910

1011
pytestmark = [pytest.mark.asyncio]
@@ -63,3 +64,26 @@ async def test_timeout_does_not_error():
6364
model="ada",
6465
request_timeout=10,
6566
)
67+
68+
69+
async def test_completions_stream_finishes_global_session():
70+
async with ClientSession() as session:
71+
openai.aiosession.set(session)
72+
73+
# A query that should be fast
74+
parts = []
75+
async for part in await openai.Completion.acreate(
76+
prompt="test", model="ada", request_timeout=3, stream=True
77+
):
78+
parts.append(part)
79+
assert len(parts) > 1
80+
81+
82+
async def test_completions_stream_finishes_local_session():
83+
# A query that should be fast
84+
parts = []
85+
async for part in await openai.Completion.acreate(
86+
prompt="test", model="ada", request_timeout=3, stream=True
87+
):
88+
parts.append(part)
89+
assert len(parts) > 1

0 commit comments

Comments
 (0)