diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index c3ad08f14..14567a26e 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -14,11 +14,11 @@ jobs: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: "actions/checkout@v4" - - uses: "actions/setup-python@v4" + - uses: "actions/setup-python@v5" with: python-version: "${{ matrix.python-version }}" allow-prereleases: true diff --git a/CHANGELOG.md b/CHANGELOG.md index b43d7f01a..49ca238b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## [Unreleased] + +- Drop Python 3.8 support +- Explicitly close all async generators to ensure predictable behavior + ## Version 1.0.9 (April 24th, 2025) - Resolve https://github.com/advisories/GHSA-vqfr-h8mv-ghfj with h11 dependency update. (#1008) diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 96e973d0c..f95b7186d 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -4,12 +4,14 @@ import sys import types import typing +from collections.abc import AsyncGenerator from .._backends.auto import AutoBackend from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol from .._models import Origin, Proxy, Request, Response from .._synchronization import AsyncEvent, AsyncShieldCancellation, AsyncThreadLock +from .._utils import safe_async_iterate from .connection import AsyncHTTPConnection from .interfaces import AsyncConnectionInterface, AsyncRequestInterface @@ -398,13 +400,10 @@ def __init__( self._pool = pool self._closed = False - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - try: - async for part in self._stream: - yield part - except BaseException as exc: - await self.aclose() - raise exc from None + async def __aiter__(self) -> AsyncGenerator[bytes]: + async with safe_async_iterate(self._stream) as iterator: + async for chunk in iterator: + yield chunk async def aclose(self) -> None: if not self._closed: diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index e6d6d7098..20eb66188 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -6,6 +6,7 @@ import time import types import typing +from collections.abc import AsyncGenerator import h11 @@ -20,6 +21,7 @@ from .._models import Origin, Request, Response from .._synchronization import AsyncLock, AsyncShieldCancellation from .._trace import Trace +from .._utils import safe_async_iterate from .interfaces import AsyncConnectionInterface logger = logging.getLogger("httpcore.http11") @@ -154,9 +156,10 @@ async def _send_request_body(self, request: Request) -> None: timeout = timeouts.get("write", None) assert isinstance(request.stream, typing.AsyncIterable) - async for chunk in request.stream: - event = h11.Data(data=chunk) - await self._send_event(event, timeout=timeout) + async with safe_async_iterate(request.stream) as iterator: + async for chunk in iterator: + event = h11.Data(data=chunk) + await self._send_event(event, timeout=timeout) await self._send_event(h11.EndOfMessage(), timeout=timeout) @@ -193,9 +196,7 @@ async def _receive_response_headers( return http_version, event.status_code, event.reason, headers, trailing_data - async def _receive_response_body( - self, request: Request - ) -> typing.AsyncIterator[bytes]: + async def _receive_response_body(self, request: Request) -> AsyncGenerator[bytes]: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) @@ -327,12 +328,15 @@ def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None: self._request = request self._closed = False - async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async def __aiter__(self) -> AsyncGenerator[bytes]: kwargs = {"request": self._request} try: async with Trace("receive_response_body", logger, self._request, kwargs): - async for chunk in self._connection._receive_response_body(**kwargs): - yield chunk + async with safe_async_iterate( + self._connection._receive_response_body(**kwargs) + ) as iterator: + async for chunk in iterator: + yield chunk except BaseException as exc: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index dbd0beeb4..4e9d8fb5b 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -5,6 +5,7 @@ import time import types import typing +from collections.abc import AsyncGenerator import h2.config import h2.connection @@ -21,6 +22,7 @@ from .._models import Origin, Request, Response from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation from .._trace import Trace +from .._utils import safe_async_iterate from .interfaces import AsyncConnectionInterface logger = logging.getLogger("httpcore.http2") @@ -258,8 +260,10 @@ async def _send_request_body(self, request: Request, stream_id: int) -> None: return assert isinstance(request.stream, typing.AsyncIterable) - async for data in request.stream: - await self._send_stream_data(request, stream_id, data) + async with safe_async_iterate(request.stream) as iterator: + async for chunk in iterator: + await self._send_stream_data(request, stream_id, chunk) + await self._send_end_stream(request, stream_id) async def _send_stream_data( @@ -308,7 +312,7 @@ async def _receive_response( async def _receive_response_body( self, request: Request, stream_id: int - ) -> typing.AsyncIterator[bytes]: + ) -> AsyncGenerator[bytes]: """ Iterator that returns the bytes of the response body for a given stream ID. """ @@ -568,14 +572,17 @@ def __init__( self._stream_id = stream_id self._closed = False - async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async def __aiter__(self) -> AsyncGenerator[bytes]: kwargs = {"request": self._request, "stream_id": self._stream_id} try: async with Trace("receive_response_body", logger, self._request, kwargs): - async for chunk in self._connection._receive_response_body( - request=self._request, stream_id=self._stream_id - ): - yield chunk + async with safe_async_iterate( + self._connection._receive_response_body( + request=self._request, stream_id=self._stream_id + ) + ) as iterator: + async for chunk in iterator: + yield chunk except BaseException as exc: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) diff --git a/httpcore/_async/interfaces.py b/httpcore/_async/interfaces.py index 361583bed..92859b6a6 100644 --- a/httpcore/_async/interfaces.py +++ b/httpcore/_async/interfaces.py @@ -2,6 +2,7 @@ import contextlib import typing +from collections.abc import AsyncGenerator from .._models import ( URL, @@ -58,7 +59,7 @@ async def stream( headers: HeaderTypes = None, content: bytes | typing.AsyncIterator[bytes] | None = None, extensions: Extensions | None = None, - ) -> typing.AsyncIterator[Response]: + ) -> AsyncGenerator[Response]: # Strict type checking on our parameters. method = enforce_bytes(method, name="method") url = enforce_url(url, name="url") diff --git a/httpcore/_models.py b/httpcore/_models.py index 8a65f1334..cc5403f79 100644 --- a/httpcore/_models.py +++ b/httpcore/_models.py @@ -4,6 +4,9 @@ import ssl import typing import urllib.parse +from collections.abc import AsyncGenerator + +from ._utils import safe_async_iterate # Functions for typechecking... @@ -151,7 +154,7 @@ def __init__(self, content: bytes) -> None: def __iter__(self) -> typing.Iterator[bytes]: yield self._content - async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async def __aiter__(self) -> AsyncGenerator[bytes]: yield self._content def __repr__(self) -> str: @@ -463,10 +466,11 @@ async def aread(self) -> bytes: "You should use 'response.read()' instead." ) if not hasattr(self, "_content"): - self._content = b"".join([part async for part in self.aiter_stream()]) + async with safe_async_iterate(self.aiter_stream()) as parts: + self._content = b"".join([part async for part in parts]) return self._content - async def aiter_stream(self) -> typing.AsyncIterator[bytes]: + async def aiter_stream(self) -> AsyncGenerator[bytes]: if not isinstance(self.stream, typing.AsyncIterable): # pragma: nocover raise RuntimeError( "Attempted to stream an synchronous response using 'async for ... in " @@ -479,8 +483,9 @@ async def aiter_stream(self) -> typing.AsyncIterator[bytes]: "more than once." ) self._stream_consumed = True - async for chunk in self.stream: - yield chunk + async with safe_async_iterate(self.stream) as iterator: + async for chunk in iterator: + yield chunk async def aclose(self) -> None: if not isinstance(self.stream, typing.AsyncIterable): # pragma: nocover diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 9ccfa53e5..32c57ebc8 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -4,12 +4,14 @@ import sys import types import typing +from collections.abc import Generator from .._backends.sync import SyncBackend from .._backends.base import SOCKET_OPTION, NetworkBackend from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol from .._models import Origin, Proxy, Request, Response from .._synchronization import Event, ShieldCancellation, ThreadLock +from .._utils import safe_iterate from .connection import HTTPConnection from .interfaces import ConnectionInterface, RequestInterface @@ -398,13 +400,10 @@ def __init__( self._pool = pool self._closed = False - def __iter__(self) -> typing.Iterator[bytes]: - try: - for part in self._stream: - yield part - except BaseException as exc: - self.close() - raise exc from None + def __iter__(self) -> Generator[bytes]: + with safe_iterate(self._stream) as iterator: + for chunk in iterator: + yield chunk def close(self) -> None: if not self._closed: diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index ebd3a9748..a6763c5ff 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -6,6 +6,7 @@ import time import types import typing +from collections.abc import Generator import h11 @@ -20,6 +21,7 @@ from .._models import Origin, Request, Response from .._synchronization import Lock, ShieldCancellation from .._trace import Trace +from .._utils import safe_iterate from .interfaces import ConnectionInterface logger = logging.getLogger("httpcore.http11") @@ -154,9 +156,10 @@ def _send_request_body(self, request: Request) -> None: timeout = timeouts.get("write", None) assert isinstance(request.stream, typing.Iterable) - for chunk in request.stream: - event = h11.Data(data=chunk) - self._send_event(event, timeout=timeout) + with safe_iterate(request.stream) as iterator: + for chunk in iterator: + event = h11.Data(data=chunk) + self._send_event(event, timeout=timeout) self._send_event(h11.EndOfMessage(), timeout=timeout) @@ -193,9 +196,7 @@ def _receive_response_headers( return http_version, event.status_code, event.reason, headers, trailing_data - def _receive_response_body( - self, request: Request - ) -> typing.Iterator[bytes]: + def _receive_response_body(self, request: Request) -> Generator[bytes]: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("read", None) @@ -327,12 +328,15 @@ def __init__(self, connection: HTTP11Connection, request: Request) -> None: self._request = request self._closed = False - def __iter__(self) -> typing.Iterator[bytes]: + def __iter__(self) -> Generator[bytes]: kwargs = {"request": self._request} try: with Trace("receive_response_body", logger, self._request, kwargs): - for chunk in self._connection._receive_response_body(**kwargs): - yield chunk + with safe_iterate( + self._connection._receive_response_body(**kwargs) + ) as iterator: + for chunk in iterator: + yield chunk except BaseException as exc: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index ddcc18900..238f805ec 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -5,6 +5,7 @@ import time import types import typing +from collections.abc import Generator import h2.config import h2.connection @@ -21,6 +22,7 @@ from .._models import Origin, Request, Response from .._synchronization import Lock, Semaphore, ShieldCancellation from .._trace import Trace +from .._utils import safe_iterate from .interfaces import ConnectionInterface logger = logging.getLogger("httpcore.http2") @@ -258,8 +260,10 @@ def _send_request_body(self, request: Request, stream_id: int) -> None: return assert isinstance(request.stream, typing.Iterable) - for data in request.stream: - self._send_stream_data(request, stream_id, data) + with safe_iterate(request.stream) as iterator: + for chunk in iterator: + self._send_stream_data(request, stream_id, chunk) + self._send_end_stream(request, stream_id) def _send_stream_data( @@ -308,7 +312,7 @@ def _receive_response( def _receive_response_body( self, request: Request, stream_id: int - ) -> typing.Iterator[bytes]: + ) -> Generator[bytes]: """ Iterator that returns the bytes of the response body for a given stream ID. """ @@ -568,14 +572,17 @@ def __init__( self._stream_id = stream_id self._closed = False - def __iter__(self) -> typing.Iterator[bytes]: + def __iter__(self) -> Generator[bytes]: kwargs = {"request": self._request, "stream_id": self._stream_id} try: with Trace("receive_response_body", logger, self._request, kwargs): - for chunk in self._connection._receive_response_body( - request=self._request, stream_id=self._stream_id - ): - yield chunk + with safe_iterate( + self._connection._receive_response_body( + request=self._request, stream_id=self._stream_id + ) + ) as iterator: + for chunk in iterator: + yield chunk except BaseException as exc: # If we get an exception while streaming the response, # we want to close the response (and possibly the connection) diff --git a/httpcore/_sync/interfaces.py b/httpcore/_sync/interfaces.py index e673d4cc1..130cd532a 100644 --- a/httpcore/_sync/interfaces.py +++ b/httpcore/_sync/interfaces.py @@ -2,6 +2,7 @@ import contextlib import typing +from collections.abc import Generator from .._models import ( URL, @@ -58,7 +59,7 @@ def stream( headers: HeaderTypes = None, content: bytes | typing.Iterator[bytes] | None = None, extensions: Extensions | None = None, - ) -> typing.Iterator[Response]: + ) -> Generator[Response]: # Strict type checking on our parameters. method = enforce_bytes(method, name="method") url = enforce_url(url, name="url") diff --git a/httpcore/_utils.py b/httpcore/_utils.py index c44ff93cb..6951457df 100644 --- a/httpcore/_utils.py +++ b/httpcore/_utils.py @@ -3,6 +3,19 @@ import select import socket import sys +import typing +from collections.abc import ( + AsyncGenerator, + AsyncIterable, + AsyncIterator, + Generator, + Iterable, + Iterator, +) +from contextlib import asynccontextmanager, contextmanager +from inspect import isasyncgen + +T = typing.TypeVar("T") def is_socket_readable(sock: socket.socket | None) -> bool: @@ -35,3 +48,32 @@ def is_socket_readable(sock: socket.socket | None) -> bool: p = select.poll() p.register(sock_fd, select.POLLIN) return bool(p.poll(0)) + + +@asynccontextmanager +async def safe_async_iterate( + iterable_or_iterator: AsyncIterable[T] | AsyncIterator[T], / +) -> AsyncGenerator[AsyncIterator[T]]: + iterator = ( + iterable_or_iterator + if isinstance(iterable_or_iterator, AsyncIterator) + else iterable_or_iterator.__aiter__() + ) + try: + yield iterator + finally: + if isasyncgen(iterator): + await iterator.aclose() + + +@contextmanager +def safe_iterate( + iterable_or_iterator: Iterable[T] | Iterator[T], / +) -> Generator[Iterator[T], None, None]: + # This is boilerplate code, only needed to make unasync happy + iterator = ( + iterable_or_iterator + if isinstance(iterable_or_iterator, Iterator) + else iterable_or_iterator.__iter__() + ) + yield iterator diff --git a/pyproject.toml b/pyproject.toml index 1bdd99eb9..fe143b2e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "httpcore" dynamic = ["readme", "version"] description = "A minimal low-level HTTP client." license = "BSD-3-Clause" -requires-python = ">=3.8" +requires-python = ">=3.9" authors = [ { name = "Tom Christie", email = "tom@tomchristie.com" }, ] @@ -26,6 +26,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Internet :: WWW/HTTP", ] dependencies = [ @@ -96,7 +97,7 @@ filterwarnings = ["error"] [tool.coverage.run] omit = [ - "venv/*", + "venv/*", "httpcore/_sync/*" ] include = ["httpcore/*", "tests/*"] diff --git a/requirements.txt b/requirements.txt index 880330c6f..b87315b8f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ twine==6.1.0 # Tests & Linting coverage[toml]==7.5.4 ruff==0.5.0 -mypy==1.10.1 +mypy==1.16.1 trio-typing==0.10.0 pytest==8.2.2 pytest-httpbin==2.0.0 diff --git a/scripts/unasync.py b/scripts/unasync.py index 5a5627d71..4dc173576 100644 --- a/scripts/unasync.py +++ b/scripts/unasync.py @@ -5,33 +5,37 @@ from pprint import pprint SUBS = [ - ('from .._backends.auto import AutoBackend', 'from .._backends.sync import SyncBackend'), - ('import trio as concurrency', 'from tests import concurrency'), - ('AsyncIterator', 'Iterator'), - ('Async([A-Z][A-Za-z0-9_]*)', r'\2'), - ('async def', 'def'), - ('async with', 'with'), - ('async for', 'for'), - ('await ', ''), - ('handle_async_request', 'handle_request'), - ('aclose', 'close'), - ('aiter_stream', 'iter_stream'), - ('aread', 'read'), - ('asynccontextmanager', 'contextmanager'), - ('__aenter__', '__enter__'), - ('__aexit__', '__exit__'), - ('__aiter__', '__iter__'), - ('@pytest.mark.anyio', ''), - ('@pytest.mark.trio', ''), - ('AutoBackend', 'SyncBackend'), + ( + "from .._backends.auto import AutoBackend", + "from .._backends.sync import SyncBackend", + ), + ("import trio as concurrency", "from tests import concurrency"), + ("AsyncIterator", "Iterator"), + ("Async([A-Z][A-Za-z0-9_]*)", r"\2"), + ("async def", "def"), + ("async with", "with"), + ("async for", "for"), + ("await ", ""), + ("handle_async_request", "handle_request"), + ("aclose", "close"), + ("aiter_stream", "iter_stream"), + ("aread", "read"), + ("asynccontextmanager", "contextmanager"), + ("safe_async_iterate", "safe_iterate"), + ("__aenter__", "__enter__"), + ("__aexit__", "__exit__"), + ("__aiter__", "__iter__"), + ("@pytest.mark.anyio", ""), + ("@pytest.mark.trio", ""), + ("AutoBackend", "SyncBackend"), ] COMPILED_SUBS = [ - (re.compile(r'(^|\b)' + regex + r'($|\b)'), repl) - for regex, repl in SUBS + (re.compile(r"(^|\b)" + regex + r"($|\b)"), repl) for regex, repl in SUBS ] USED_SUBS = set() + def unasync_line(line): for index, (regex, repl) in enumerate(COMPILED_SUBS): old_line = line @@ -55,22 +59,22 @@ def unasync_file_check(in_path, out_path): for in_line, out_line in zip(in_file.readlines(), out_file.readlines()): expected = unasync_line(in_line) if out_line != expected: - print(f'unasync mismatch between {in_path!r} and {out_path!r}') - print(f'Async code: {in_line!r}') - print(f'Expected sync code: {expected!r}') - print(f'Actual sync code: {out_line!r}') + print(f"unasync mismatch between {in_path!r} and {out_path!r}") + print(f"Async code: {in_line!r}") + print(f"Expected sync code: {expected!r}") + print(f"Actual sync code: {out_line!r}") sys.exit(1) def unasync_dir(in_dir, out_dir, check_only=False): for dirpath, dirnames, filenames in os.walk(in_dir): for filename in filenames: - if not filename.endswith('.py'): + if not filename.endswith(".py"): continue rel_dir = os.path.relpath(dirpath, in_dir) in_path = os.path.normpath(os.path.join(in_dir, rel_dir, filename)) out_path = os.path.normpath(os.path.join(out_dir, rel_dir, filename)) - print(in_path, '->', out_path) + print(in_path, "->", out_path) if check_only: unasync_file_check(in_path, out_path) else: @@ -78,7 +82,7 @@ def unasync_dir(in_dir, out_dir, check_only=False): def main(): - check_only = '--check' in sys.argv + check_only = "--check" in sys.argv unasync_dir("httpcore/_async", "httpcore/_sync", check_only=check_only) unasync_dir("tests/_async", "tests/_sync", check_only=check_only) @@ -87,8 +91,8 @@ def main(): print("These patterns were not used:") pprint(unused_subs) - exit(1) - + exit(1) + -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/test_cancellations.py b/tests/test_cancellations.py index 033acef60..fe8d3c911 100644 --- a/tests/test_cancellations.py +++ b/tests/test_cancellations.py @@ -171,7 +171,6 @@ async def test_h11_timeout_during_response(): assert conn.is_closed() -@pytest.mark.xfail @pytest.mark.anyio async def test_h2_timeout_during_handshake(): """ @@ -186,7 +185,6 @@ async def test_h2_timeout_during_handshake(): assert conn.is_closed() -@pytest.mark.xfail @pytest.mark.anyio async def test_h2_timeout_during_request(): """ @@ -207,7 +205,6 @@ async def test_h2_timeout_during_request(): assert conn.is_idle() -@pytest.mark.xfail @pytest.mark.anyio async def test_h2_timeout_during_response(): """