Skip to content

Commit 5d1a726

Browse files
authored
Fix some typing issues (#177)
1 parent ef8f1f1 commit 5d1a726

File tree

2 files changed

+40
-30
lines changed

2 files changed

+40
-30
lines changed

openai/api_requestor.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -89,32 +89,33 @@ def _make_session() -> requests.Session:
8989
return s
9090

9191

92-
def parse_stream_helper(line: bytes):
92+
def parse_stream_helper(line: bytes) -> Optional[str]:
9393
if line:
9494
if line.strip() == b"data: [DONE]":
9595
# return here will cause GeneratorExit exception in urllib3
9696
# and it will close http connection with TCP Reset
9797
return None
98-
if hasattr(line, "decode"):
99-
line = line.decode("utf-8")
100-
if line.startswith("data: "):
101-
line = line[len("data: ") :]
102-
return line
98+
if line.startswith(b"data: "):
99+
line = line[len(b"data: ") :]
100+
return line.decode("utf-8")
103101
return None
104102

105103

106-
def parse_stream(rbody):
104+
def parse_stream(rbody: Iterator[bytes]) -> Iterator[str]:
107105
for line in rbody:
108106
_line = parse_stream_helper(line)
109107
if _line is not None:
110108
yield _line
111109

112110

113111
async def parse_stream_async(rbody: aiohttp.StreamReader):
114-
async for line, _ in rbody.iter_chunks():
115-
_line = parse_stream_helper(line)
116-
if _line is not None:
117-
yield _line
112+
async for chunk, _ in rbody.iter_chunks():
113+
# While the `ChunkTupleAsyncStreamIterator` iterator is meant to iterate over chunks (and thus lines) it seems
114+
# to still sometimes return multiple lines at a time, so let's split the chunk by lines again.
115+
for line in chunk.splitlines():
116+
_line = parse_stream_helper(line)
117+
if _line is not None:
118+
yield _line
118119

119120

120121
class APIRequestor:
@@ -296,20 +297,25 @@ async def arequest(
296297
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool, str]:
297298
ctx = aiohttp_session()
298299
session = await ctx.__aenter__()
299-
result = await self.arequest_raw(
300-
method.lower(),
301-
url,
302-
session,
303-
params=params,
304-
supplied_headers=headers,
305-
files=files,
306-
request_id=request_id,
307-
request_timeout=request_timeout,
308-
)
309-
resp, got_stream = await self._interpret_async_response(result, stream)
300+
try:
301+
result = await self.arequest_raw(
302+
method.lower(),
303+
url,
304+
session,
305+
params=params,
306+
supplied_headers=headers,
307+
files=files,
308+
request_id=request_id,
309+
request_timeout=request_timeout,
310+
)
311+
resp, got_stream = await self._interpret_async_response(result, stream)
312+
except Exception:
313+
await ctx.__aexit__(None, None, None)
314+
raise
310315
if got_stream:
311316

312317
async def wrap_resp():
318+
assert isinstance(resp, AsyncGenerator)
313319
try:
314320
async for r in resp:
315321
yield r
@@ -612,7 +618,10 @@ def _interpret_response(
612618
else:
613619
return (
614620
self._interpret_response_line(
615-
result.content, result.status_code, result.headers, stream=False
621+
result.content.decode("utf-8"),
622+
result.status_code,
623+
result.headers,
624+
stream=False,
616625
),
617626
False,
618627
)
@@ -635,13 +644,16 @@ async def _interpret_async_response(
635644
util.log_warn(e, body=result.content)
636645
return (
637646
self._interpret_response_line(
638-
await result.read(), result.status, result.headers, stream=False
647+
(await result.read()).decode("utf-8"),
648+
result.status,
649+
result.headers,
650+
stream=False,
639651
),
640652
False,
641653
)
642654

643655
def _interpret_response_line(
644-
self, rbody, rcode, rheaders, stream: bool
656+
self, rbody: str, rcode: int, rheaders, stream: bool
645657
) -> OpenAIResponse:
646658
# HTTP 204 response code does not have any content in the body.
647659
if rcode == 204:
@@ -655,13 +667,11 @@ def _interpret_response_line(
655667
headers=rheaders,
656668
)
657669
try:
658-
if hasattr(rbody, "decode"):
659-
rbody = rbody.decode("utf-8")
660670
data = json.loads(rbody)
661-
except (JSONDecodeError, UnicodeDecodeError):
671+
except (JSONDecodeError, UnicodeDecodeError) as e:
662672
raise error.APIError(
663673
f"HTTP code {rcode} from API ({rbody})", rbody, rcode, headers=rheaders
664-
)
674+
) from e
665675
resp = OpenAIResponse(data, rheaders)
666676
# In the future, we might add a "status" parameter to errors
667677
# to better handle the "error while streaming" case.

openai/tests/asyncio/test_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
import json
33

44
import pytest
5+
from aiohttp import ClientSession
56

67
import openai
78
from openai import error
8-
from aiohttp import ClientSession
99

1010

1111
pytestmark = [pytest.mark.asyncio]

0 commit comments

Comments
 (0)