Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [Unreleased]

- Explicitly close all async generators to ensure predictable behavior

## Version 1.0.9 (April 24th, 2025)

- Resolve https://github.com/advisories/GHSA-vqfr-h8mv-ghfj with h11 dependency update. (#1008)
Expand Down
22 changes: 12 additions & 10 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@
import sys
import types
import typing
from collections.abc import AsyncGenerator

from .._backends.auto import AutoBackend
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Proxy, Request, Response
from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock
from .._utils import aclosing
from .connection import AsyncHTTPConnection
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface

if typing.TYPE_CHECKING:
from .http2 import HTTP2ConnectionByteStream
from .http11 import HTTP11ConnectionByteStream


class AsyncPoolRequest:
def __init__(self, request: Request) -> None:
Expand Down Expand Up @@ -389,7 +395,7 @@ def __repr__(self) -> str:
class PoolByteStream:
def __init__(
self,
stream: typing.AsyncIterable[bytes],
stream: HTTP11ConnectionByteStream | HTTP2ConnectionByteStream,
pool_request: AsyncPoolRequest,
pool: AsyncConnectionPool,
) -> None:
Expand All @@ -398,20 +404,16 @@ def __init__(
self._pool = pool
self._closed = False

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
try:
async for part in self._stream:
yield part
except BaseException as exc:
await self.aclose()
raise exc from None
async def __aiter__(self) -> AsyncGenerator[bytes]:
async with aclosing(self._stream.__aiter__()) as iterator:
async for chunk in iterator:
yield chunk

async def aclose(self) -> None:
if not self._closed:
self._closed = True
with AsyncShieldCancellation():
if hasattr(self._stream, "aclose"):
await self._stream.aclose()
await self._stream.aclose()

with self._pool._optional_thread_lock:
self._pool._requests.remove(self._pool_request)
Expand Down
15 changes: 9 additions & 6 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import types
import typing
from collections.abc import AsyncGenerator

import h11

Expand All @@ -20,6 +21,7 @@
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncShieldCancellation
from .._trace import Trace
from .._utils import aclosing
from .interfaces import AsyncConnectionInterface

logger = logging.getLogger("httpcore.http11")
Expand Down Expand Up @@ -193,9 +195,7 @@ async def _receive_response_headers(

return http_version, event.status_code, event.reason, headers, trailing_data

async def _receive_response_body(
self, request: Request
) -> typing.AsyncIterator[bytes]:
async def _receive_response_body(self, request: Request) -> AsyncGenerator[bytes]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

Expand Down Expand Up @@ -327,12 +327,15 @@ def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None:
self._request = request
self._closed = False

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncGenerator[bytes]:
kwargs = {"request": self._request}
try:
async with Trace("receive_response_body", logger, self._request, kwargs):
async for chunk in self._connection._receive_response_body(**kwargs):
yield chunk
async with aclosing(
self._connection._receive_response_body(**kwargs)
) as body:
async for chunk in body:
yield chunk
except BaseException as exc:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
Expand Down
5 changes: 3 additions & 2 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
import types
import typing
from collections.abc import AsyncGenerator

import h2.config
import h2.connection
Expand Down Expand Up @@ -308,7 +309,7 @@ async def _receive_response(

async def _receive_response_body(
self, request: Request, stream_id: int
) -> typing.AsyncIterator[bytes]:
) -> AsyncGenerator[bytes]:
"""
Iterator that returns the bytes of the response body for a given stream ID.
"""
Expand Down Expand Up @@ -568,7 +569,7 @@ def __init__(
self._stream_id = stream_id
self._closed = False

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncGenerator[bytes]:
kwargs = {"request": self._request, "stream_id": self._stream_id}
try:
async with Trace("receive_response_body", logger, self._request, kwargs):
Expand Down
5 changes: 3 additions & 2 deletions httpcore/_async/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
import typing
from collections.abc import AsyncGenerator

from .._models import (
URL,
Expand Down Expand Up @@ -56,9 +57,9 @@ async def stream(
url: URL | bytes | str,
*,
headers: HeaderTypes = None,
content: bytes | typing.AsyncIterator[bytes] | None = None,
content: bytes | AsyncGenerator[bytes] | None = None,
extensions: Extensions | None = None,
) -> typing.AsyncIterator[Response]:
) -> AsyncGenerator[Response]:
# Strict type checking on our parameters.
method = enforce_bytes(method, name="method")
url = enforce_url(url, name="url")
Expand Down
15 changes: 10 additions & 5 deletions httpcore/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import ssl
import typing
import urllib.parse
from collections.abc import AsyncGenerator

from ._utils import aclosing

# Functions for typechecking...

Expand Down Expand Up @@ -151,7 +154,7 @@ def __init__(self, content: bytes) -> None:
def __iter__(self) -> typing.Iterator[bytes]:
yield self._content

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncGenerator[bytes]:
yield self._content

def __repr__(self) -> str:
Expand Down Expand Up @@ -463,10 +466,11 @@ async def aread(self) -> bytes:
"You should use 'response.read()' instead."
)
if not hasattr(self, "_content"):
self._content = b"".join([part async for part in self.aiter_stream()])
async with aclosing(self.aiter_stream()) as parts:
self._content = b"".join([part async for part in parts])
return self._content

async def aiter_stream(self) -> typing.AsyncIterator[bytes]:
async def aiter_stream(self) -> AsyncGenerator[bytes]:
if not isinstance(self.stream, typing.AsyncIterable): # pragma: nocover
raise RuntimeError(
"Attempted to stream an synchronous response using 'async for ... in "
Expand All @@ -479,8 +483,9 @@ async def aiter_stream(self) -> typing.AsyncIterator[bytes]:
"more than once."
)
self._stream_consumed = True
async for chunk in self.stream:
yield chunk
async with aclosing(self.stream) as parts:
async for chunk in parts:
yield chunk

async def aclose(self) -> None:
if not isinstance(self.stream, typing.AsyncIterable): # pragma: nocover
Expand Down
15 changes: 15 additions & 0 deletions httpcore/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,21 @@
import socket
import sys

if sys.version_info >= (3, 10):
from contextlib import aclosing as aclosing
else:
from contextlib import AbstractAsyncContextManager

class aclosing(AbstractAsyncContextManager):
def __init__(self, thing):
self.thing = thing

async def __aenter__(self):
return self.thing

async def __aexit__(self, *exc_info):
await self.thing.aclose()


def is_socket_readable(sock: socket.socket | None) -> bool:
"""
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ twine==6.1.0
# Tests & Linting
coverage[toml]==7.5.4
ruff==0.5.0
mypy==1.10.1
mypy==1.14.1
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe un-downgrade mypy, since Python 3.8 support was dropped.

trio-typing==0.10.0
pytest==8.2.2
pytest-httpbin==2.0.0
Expand Down
65 changes: 34 additions & 31 deletions scripts/unasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,36 @@
from pprint import pprint

SUBS = [
('from .._backends.auto import AutoBackend', 'from .._backends.sync import SyncBackend'),
('import trio as concurrency', 'from tests import concurrency'),
('AsyncIterator', 'Iterator'),
('Async([A-Z][A-Za-z0-9_]*)', r'\2'),
('async def', 'def'),
('async with', 'with'),
('async for', 'for'),
('await ', ''),
('handle_async_request', 'handle_request'),
('aclose', 'close'),
('aiter_stream', 'iter_stream'),
('aread', 'read'),
('asynccontextmanager', 'contextmanager'),
('__aenter__', '__enter__'),
('__aexit__', '__exit__'),
('__aiter__', '__iter__'),
('@pytest.mark.anyio', ''),
('@pytest.mark.trio', ''),
('AutoBackend', 'SyncBackend'),
(
"from .._backends.auto import AutoBackend",
"from .._backends.sync import SyncBackend",
),
("import trio as concurrency", "from tests import concurrency"),
("AsyncIterator", "Iterator"),
("Async([A-Z][A-Za-z0-9_]*)", r"\2"),
("async def", "def"),
("async with", "with"),
("async for", "for"),
("await ", ""),
("handle_async_request", "handle_request"),
("aclose", "close"),
("aiter_stream", "iter_stream"),
("aread", "read"),
("asynccontextmanager", "contextmanager"),
("__aenter__", "__enter__"),
("__aexit__", "__exit__"),
("__aiter__", "__iter__"),
("@pytest.mark.anyio", ""),
("@pytest.mark.trio", ""),
("AutoBackend", "SyncBackend"),
]
COMPILED_SUBS = [
(re.compile(r'(^|\b)' + regex + r'($|\b)'), repl)
for regex, repl in SUBS
(re.compile(r"(^|\b)" + regex + r"($|\b)"), repl) for regex, repl in SUBS
]

USED_SUBS = set()


def unasync_line(line):
for index, (regex, repl) in enumerate(COMPILED_SUBS):
old_line = line
Expand All @@ -55,30 +58,30 @@ def unasync_file_check(in_path, out_path):
for in_line, out_line in zip(in_file.readlines(), out_file.readlines()):
expected = unasync_line(in_line)
if out_line != expected:
print(f'unasync mismatch between {in_path!r} and {out_path!r}')
print(f'Async code: {in_line!r}')
print(f'Expected sync code: {expected!r}')
print(f'Actual sync code: {out_line!r}')
print(f"unasync mismatch between {in_path!r} and {out_path!r}")
print(f"Async code: {in_line!r}")
print(f"Expected sync code: {expected!r}")
print(f"Actual sync code: {out_line!r}")
sys.exit(1)


def unasync_dir(in_dir, out_dir, check_only=False):
for dirpath, dirnames, filenames in os.walk(in_dir):
for filename in filenames:
if not filename.endswith('.py'):
if not filename.endswith(".py"):
continue
rel_dir = os.path.relpath(dirpath, in_dir)
in_path = os.path.normpath(os.path.join(in_dir, rel_dir, filename))
out_path = os.path.normpath(os.path.join(out_dir, rel_dir, filename))
print(in_path, '->', out_path)
print(in_path, "->", out_path)
if check_only:
unasync_file_check(in_path, out_path)
else:
unasync_file(in_path, out_path)


def main():
check_only = '--check' in sys.argv
check_only = "--check" in sys.argv
unasync_dir("httpcore/_async", "httpcore/_sync", check_only=check_only)
unasync_dir("tests/_async", "tests/_sync", check_only=check_only)

Expand All @@ -87,8 +90,8 @@ def main():

print("These patterns were not used:")
pprint(unused_subs)
exit(1)
exit(1)


if __name__ == '__main__':
if __name__ == "__main__":
main()
3 changes: 0 additions & 3 deletions tests/test_cancellations.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ async def test_h11_timeout_during_response():
assert conn.is_closed()


@pytest.mark.xfail
@pytest.mark.anyio
async def test_h2_timeout_during_handshake():
"""
Expand All @@ -186,7 +185,6 @@ async def test_h2_timeout_during_handshake():
assert conn.is_closed()


@pytest.mark.xfail
@pytest.mark.anyio
async def test_h2_timeout_during_request():
"""
Expand All @@ -207,7 +205,6 @@ async def test_h2_timeout_during_request():
assert conn.is_idle()


@pytest.mark.xfail
@pytest.mark.anyio
async def test_h2_timeout_during_response():
"""
Expand Down
3 changes: 3 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
for chunk in self._chunks:
yield chunk

async def aclose(self) -> None:
pass


@pytest.mark.trio
async def test_response_async_read():
Expand Down
Loading