Skip to content

Commit 5a36037

Browse files
Fix unclosed generator on trio
1 parent f0fd919 commit 5a36037

File tree

5 files changed

+76
-29
lines changed

5 files changed

+76
-29
lines changed

httpx/_client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,8 @@ def __init__(
142142
self._response = response
143143
self._timer = timer
144144

145-
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
146-
async for chunk in self._stream:
147-
yield chunk
145+
def __aiter__(self) -> typing.AsyncIterator[bytes]:
146+
return self._stream.__aiter__()
148147

149148
async def aclose(self) -> None:
150149
seconds = await self._timer.async_elapsed()

httpx/_models.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json as jsonlib
44
import typing
55
import urllib.request
6+
from contextlib import aclosing
67
from collections.abc import Mapping
78
from http.cookiejar import Cookie, CookieJar
89

@@ -911,7 +912,7 @@ async def aread(self) -> bytes:
911912

912913
async def aiter_bytes(
913914
self, chunk_size: typing.Optional[int] = None
914-
) -> typing.AsyncIterator[bytes]:
915+
) -> typing.AsyncGenerator[bytes, None]:
915916
"""
916917
A byte-iterator over the decoded response content.
917918
This allows us to handle gzip, deflate, and brotli encoded responses.
@@ -924,19 +925,20 @@ async def aiter_bytes(
924925
decoder = self._get_content_decoder()
925926
chunker = ByteChunker(chunk_size=chunk_size)
926927
with request_context(request=self._request):
927-
async for raw_bytes in self.aiter_raw():
928-
decoded = decoder.decode(raw_bytes)
928+
async with aclosing(self.aiter_raw()) as stream:
929+
async for raw_bytes in stream:
930+
decoded = decoder.decode(raw_bytes)
931+
for chunk in chunker.decode(decoded):
932+
yield chunk
933+
decoded = decoder.flush()
929934
for chunk in chunker.decode(decoded):
935+
yield chunk # pragma: no cover
936+
for chunk in chunker.flush():
930937
yield chunk
931-
decoded = decoder.flush()
932-
for chunk in chunker.decode(decoded):
933-
yield chunk # pragma: no cover
934-
for chunk in chunker.flush():
935-
yield chunk
936938

937939
async def aiter_text(
938940
self, chunk_size: typing.Optional[int] = None
939-
) -> typing.AsyncIterator[str]:
941+
) -> typing.AsyncGenerator[str, None]:
940942
"""
941943
A str-iterator over the decoded response content
942944
that handles both gzip, deflate, etc but also detects the content's
@@ -945,28 +947,30 @@ async def aiter_text(
945947
decoder = TextDecoder(encoding=self.encoding or "utf-8")
946948
chunker = TextChunker(chunk_size=chunk_size)
947949
with request_context(request=self._request):
948-
async for byte_content in self.aiter_bytes():
949-
text_content = decoder.decode(byte_content)
950+
async with aclosing(self.aiter_bytes()) as stream:
951+
async for byte_content in stream:
952+
text_content = decoder.decode(byte_content)
953+
for chunk in chunker.decode(text_content):
954+
yield chunk
955+
text_content = decoder.flush()
950956
for chunk in chunker.decode(text_content):
951957
yield chunk
952-
text_content = decoder.flush()
953-
for chunk in chunker.decode(text_content):
954-
yield chunk
955-
for chunk in chunker.flush():
956-
yield chunk
958+
for chunk in chunker.flush():
959+
yield chunk
957960

958-
async def aiter_lines(self) -> typing.AsyncIterator[str]:
961+
async def aiter_lines(self) -> typing.AsyncGenerator[str, None]:
959962
decoder = LineDecoder()
960963
with request_context(request=self._request):
961-
async for text in self.aiter_text():
962-
for line in decoder.decode(text):
964+
async with aclosing(self.aiter_text()) as stream:
965+
async for text in stream:
966+
for line in decoder.decode(text):
967+
yield line
968+
for line in decoder.flush():
963969
yield line
964-
for line in decoder.flush():
965-
yield line
966970

967971
async def aiter_raw(
968972
self, chunk_size: typing.Optional[int] = None
969-
) -> typing.AsyncIterator[bytes]:
973+
) -> typing.AsyncGenerator[bytes, None]:
970974
"""
971975
A byte-iterator over the raw response content.
972976
"""

httpx/_transports/default.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,14 @@ def close(self) -> None:
232232

233233
class AsyncResponseStream(AsyncByteStream):
234234
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]):
235-
self._httpcore_stream = httpcore_stream
235+
self._httpcore_stream = httpcore_stream.__aiter__()
236+
237+
def __aiter__(self) -> typing.AsyncIterator[bytes]:
238+
return self
236239

237-
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
240+
async def __anext__(self) -> bytes:
238241
with map_httpcore_exceptions():
239-
async for part in self._httpcore_stream:
240-
yield part
242+
return await self._httpcore_stream.__anext__()
241243

242244
async def aclose(self) -> None:
243245
if hasattr(self._httpcore_stream, "aclose"):

tests/client/test_async_client.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import typing
2+
from contextlib import aclosing
23
from datetime import timedelta
34

45
import pytest
@@ -76,6 +77,34 @@ async def test_stream_response(server):
7677
assert response.content == b"Hello, world!"
7778

7879

80+
@pytest.mark.anyio
81+
async def test_stream_iterator(server):
82+
body = b""
83+
84+
async with httpx.AsyncClient() as client:
85+
async with client.stream("GET", server.url) as response:
86+
async for chunk in response.aiter_bytes():
87+
body += chunk
88+
89+
assert response.status_code == 200
90+
assert body == b"Hello, world!"
91+
92+
93+
@pytest.mark.anyio
94+
async def test_stream_iterator_partial(server):
95+
body = ""
96+
97+
async with httpx.AsyncClient() as client:
98+
async with client.stream("GET", server.url) as response:
99+
async with aclosing(response.aiter_text(5)) as stream:
100+
async for chunk in stream:
101+
body += chunk
102+
break
103+
104+
assert response.status_code == 200
105+
assert body == "Hello"
106+
107+
79108
@pytest.mark.anyio
80109
async def test_access_content_stream_response(server):
81110
async with httpx.AsyncClient() as client:

tests/client/test_client.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,19 @@ def test_stream_iterator(server):
107107
assert body == b"Hello, world!"
108108

109109

110+
def test_stream_iterator_partial(server):
111+
body = ""
112+
113+
with httpx.Client() as client:
114+
with client.stream("GET", server.url) as response:
115+
for chunk in response.iter_text(5):
116+
body += chunk
117+
break
118+
119+
assert response.status_code == 200
120+
assert body == "Hello"
121+
122+
110123
def test_raw_iterator(server):
111124
body = b""
112125

0 commit comments

Comments
 (0)