Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions httpx/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import sys

# Python 3.10 added ``contextlib.aclosing``; provide an equivalent shim
# for older interpreters.
if sys.version_info >= (3, 10):
    from contextlib import aclosing
else:
    from contextlib import asynccontextmanager
    from typing import AsyncIterator, Awaitable, Protocol, TypeVar

    class _SupportsAclose(Protocol):
        # Anything exposing an awaitable ``aclose()``: async generators,
        # async file objects, byte streams, ...
        def aclose(self) -> Awaitable[object]: ...

    _SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose)

    @asynccontextmanager
    async def aclosing(thing: _SupportsAcloseT) -> AsyncIterator[_SupportsAcloseT]:
        """Await ``thing.aclose()`` when the ``async with`` block exits,
        even if the body raised.

        Yields the wrapped object itself. Annotated with the bound
        TypeVar (the original yielded ``Any``, discarding the type and
        diverging from the stdlib/typeshed signature).
        """
        try:
            yield thing
        finally:
            await thing.aclose()


# ``TypeIs`` (PEP 742) landed in ``typing`` in Python 3.13; on older
# interpreters fall back to the ``typing_extensions`` backport.
if sys.version_info >= (3, 13):
    from typing import TypeIs
else:
    from typing_extensions import TypeIs


# Public surface of this compatibility module.
__all__ = ["aclosing", "TypeIs"]
15 changes: 13 additions & 2 deletions httpx/_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
RequestFiles,
ResponseContent,
SyncByteStream,
is_async_readable_file,
)
from ._utils import peek_filelike_length, primitive_value_to_str
from ._utils import peek_filelike_length, primitive_value_to_str, to_bytes

__all__ = ["ByteStream"]

Expand Down Expand Up @@ -83,6 +84,11 @@ async def __aiter__(self) -> AsyncIterator[bytes]:
while chunk:
yield chunk
chunk = await self._stream.aread(self.CHUNK_SIZE)
elif is_async_readable_file(self._stream):
chunk = await self._stream.read(self.CHUNK_SIZE)
while chunk:
yield to_bytes(chunk)
chunk = await self._stream.read(self.CHUNK_SIZE)
else:
# Otherwise iterate.
async for part in self._stream:
Expand Down Expand Up @@ -127,7 +133,12 @@ def encode_content(
return headers, IteratorByteStream(content) # type: ignore

elif isinstance(content, AsyncIterable):
headers = {"Transfer-Encoding": "chunked"}
if is_async_readable_file(content) and (
content_length_or_none := peek_filelike_length(content)
):
headers = {"Content-Length": str(content_length_or_none)}
else:
headers = {"Transfer-Encoding": "chunked"}
return headers, AsyncIteratorByteStream(content)

raise TypeError(f"Unexpected type for 'content', {type(content)!r}")
Expand Down
42 changes: 40 additions & 2 deletions httpx/_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import typing
from pathlib import Path

from ._compat import aclosing
from ._types import (
AsyncByteStream,
FileContent,
FileTypes,
RequestData,
RequestFiles,
SyncByteStream,
is_async_readable_file,
)
from ._utils import (
peek_filelike_length,
Expand Down Expand Up @@ -201,6 +203,11 @@ def render_headers(self) -> bytes:
return self._headers

def render_data(self) -> typing.Iterator[bytes]:
if is_async_readable_file(self.file):
raise TypeError(
"Invalid type for file. AsyncReadableFile is not supported."
)

if isinstance(self.file, (str, bytes)):
yield to_bytes(self.file)
return
Expand All @@ -216,10 +223,27 @@ def render_data(self) -> typing.Iterator[bytes]:
yield to_bytes(chunk)
chunk = self.file.read(self.CHUNK_SIZE)

async def arender_data(self) -> typing.AsyncGenerator[bytes, None]:
    """Asynchronously yield this field's body as ``bytes`` chunks.

    Synchronous file-likes (and str/bytes content) are delegated to
    ``render_data()``; async readable files are rewound to the start and
    read in ``CHUNK_SIZE`` pieces, coercing each chunk with ``to_bytes``
    (text-mode files yield ``str``).
    """
    # NOTE: ``typing.AsyncGenerator`` requires both the yield and send type
    # parameters before Python 3.13 -- ``AsyncGenerator[bytes]`` alone raises
    # ``TypeError`` at definition time on 3.12 and earlier.
    if not is_async_readable_file(self.file):
        for chunk in self.render_data():
            yield chunk
        return
    await self.file.seek(0)
    chunk = await self.file.read(self.CHUNK_SIZE)
    while chunk:
        yield to_bytes(chunk)
        chunk = await self.file.read(self.CHUNK_SIZE)

def render(self) -> typing.Iterator[bytes]:
    """Yield the encoded header block, then every chunk of the body."""
    yield self.render_headers()
    for part in self.render_data():
        yield part

async def arender(self) -> typing.AsyncGenerator[bytes, None]:
    """Asynchronously yield headers followed by the field body.

    The body generator is wrapped in ``aclosing`` so it is closed even
    when the consumer abandons iteration early.
    """
    # Two type parameters required: ``AsyncGenerator[bytes]`` is a
    # ``TypeError`` before Python 3.13.
    yield self.render_headers()
    async with aclosing(self.arender_data()) as data:
        async for chunk in data:
            yield chunk


class MultipartStream(SyncByteStream, AsyncByteStream):
"""
Expand Down Expand Up @@ -262,6 +286,19 @@ def iter_chunks(self) -> typing.Iterator[bytes]:
yield b"\r\n"
yield b"--%s--\r\n" % self.boundary

async def aiter_chunks(self) -> typing.AsyncGenerator[bytes, None]:
    """Asynchronously yield the complete multipart body, field by field.

    File fields render via their async path (so async file objects are
    read without blocking, closed via ``aclosing``); data fields render
    synchronously. Each field is framed by the boundary, and the body is
    terminated with the closing boundary marker.
    """
    # ``typing.AsyncGenerator`` needs yield *and* send type parameters
    # before Python 3.13; ``[bytes]`` alone raises ``TypeError``.
    for field in self.fields:
        yield b"--%s\r\n" % self.boundary
        if isinstance(field, FileField):
            async with aclosing(field.arender()) as data:
                async for chunk in data:
                    yield chunk
        else:
            for chunk in field.render():
                yield chunk
        yield b"\r\n"
    yield b"--%s--\r\n" % self.boundary

def get_content_length(self) -> int | None:
"""
Return the length of the multipart encoded content, or `None` if
Expand Down Expand Up @@ -296,5 +333,6 @@ def __iter__(self) -> typing.Iterator[bytes]:
yield chunk

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
    """Async-iterate the multipart body.

    Delegates to ``aiter_chunks()`` wrapped in ``aclosing`` so the chunk
    generator is closed even if the consumer stops iterating early.
    (This span contained both the removed sync-delegating body and the
    new async one from the diff; only the new implementation is kept.)
    """
    async with aclosing(self.aiter_chunks()) as data:
        async for chunk in data:
            yield chunk
31 changes: 30 additions & 1 deletion httpx/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
Type definitions for type checking purposes.
"""

import inspect
from http.cookiejar import CookieJar
from typing import (
IO,
TYPE_CHECKING,
Any,
AnyStr,
AsyncIterable,
AsyncIterator,
Callable,
Expand All @@ -16,11 +18,14 @@
List,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
Union,
)

from ._compat import TypeIs

if TYPE_CHECKING: # pragma: no cover
from ._auth import Auth # noqa: F401
from ._config import Proxy, Timeout # noqa: F401
Expand Down Expand Up @@ -71,7 +76,18 @@

RequestData = Mapping[str, Any]

FileContent = Union[IO[bytes], bytes, str]

class AsyncReadableFile(Protocol):
    """Structural type for async file-like objects (anyio/trio/aiofiles
    wrapped files and similar).

    ``__aiter__`` is a plain (non-async) method returning the async
    iterator, matching ``collections.abc.AsyncIterable`` and real async
    file implementations -- the previous ``async def __aiter__`` form is
    the long-removed pre-3.5.2 coroutine pattern and would reject those
    classes under static checking.
    """

    def __aiter__(self) -> AsyncIterator[AnyStr]: ...

    # Coroutine read; returns str for text-mode files, bytes for binary.
    async def read(self, size: int = -1) -> AnyStr: ...

    # Synchronous, like the underlying OS-level file descriptor accessor.
    def fileno(self) -> int: ...

    # ``whence`` is an int flag (os.SEEK_SET etc.), never None.
    async def seek(self, offset: int, whence: int = ...) -> int: ...


FileContent = Union[IO[bytes], bytes, str, AsyncReadableFile]
FileTypes = Union[
# file (or bytes)
FileContent,
Expand Down Expand Up @@ -112,3 +128,16 @@ async def __aiter__(self) -> AsyncIterator[bytes]:

async def aclose(self) -> None:
pass


def is_async_readable_file(fp: Any) -> TypeIs[AsyncReadableFile]:
    """Runtime duck-type check for ``AsyncReadableFile``.

    True only for async-iterable objects exposing coroutine ``read`` and
    ``seek`` methods and a *synchronous* callable ``fileno``.
    """
    if not isinstance(fp, AsyncIterable):
        return False
    if not inspect.iscoroutinefunction(getattr(fp, "read", None)):
        return False
    fileno = getattr(fp, "fileno", None)
    if not callable(fileno) or inspect.iscoroutinefunction(fileno):
        return False
    return inspect.iscoroutinefunction(getattr(fp, "seek", None))
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,5 @@ markers = [
]

[tool.coverage.run]
omit = ["venv/*"]
omit = ["venv/*", "httpx/_compat.py"]
include = ["httpx/*", "tests/*"]
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,5 @@ trio==0.31.0
trio-typing==0.10.0
trustme==1.2.1
uvicorn==0.35.0
aiofiles==25.1.0
types-aiofiles==25.1.0.20251011
73 changes: 73 additions & 0 deletions tests/test_content.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import io
import typing

import aiofiles
import anyio
import pytest
import trio

import httpx
from httpx._content import AsyncIteratorByteStream
from httpx._types import AsyncReadableFile, is_async_readable_file

method = "POST"
url = "https://www.example.com"
Expand Down Expand Up @@ -516,3 +521,71 @@ def test_allow_nan_false():
ValueError, match="Out of range float values are not JSON compliant"
):
httpx.Response(200, json=data_with_inf)


@pytest.mark.parametrize(
    "client_method,content_seed,mode",
    [
        ("put", "a🥳", "rt"),
        ("post", "a🥳", "rt"),
        ("put", "a🥳", "rb"),
        ("post", "a🥳", "rb"),
    ],
    ids=["put_text", "post_text", "put_binary", "post_binary"],
)
@pytest.mark.anyio
async def test_chunked_async_file_content(
    tmp_path, anyio_backend, monkeypatch, client_method, server, content_seed, mode
):
    """Upload an async file object as request content and verify it is
    streamed chunk-by-chunk with a Content-Length header set up front.

    Covers the native async file for the active backend (anyio or trio)
    in text and binary modes, plus aiofiles on non-trio backends.  The
    seed contains a multi-byte emoji so text-mode decoding and the
    str->bytes conversion of chunks are both exercised.
    """
    total_chunks = 3
    # Each read() returns CHUNK_SIZE units: characters in text mode, bytes
    # in binary mode.  The file holds total_chunks * seed_size * CHUNK_SIZE
    # units, so that many full reads are expected, plus one final empty
    # read signalling EOF.
    seed_size = len(content_seed.encode()) if "b" in mode else len(content_seed)
    read_calls_expected = total_chunks * seed_size + 1
    content = "".join(
        [content_seed * AsyncIteratorByteStream.CHUNK_SIZE] * total_chunks
    )
    content_bytes = content.encode()
    to_upload = tmp_path / "upload.txt"
    to_upload.write_bytes(content_bytes)
    # NOTE(review): assumes the `server` fixture echoes the request body
    # back at this path -- confirm against the fixture definition.
    url = server.url.copy_with(path="/echo_body")

    async def checks(client: httpx.AsyncClient, async_file: AsyncReadableFile) -> None:
        # Wrap read()/fileno() so we can count how often httpx calls them:
        # read() once per chunk plus the EOF read; fileno() exactly once
        # (presumably from the Content-Length size probe).
        read_called = 0
        fileno_called = 0
        original_read = async_file.read
        original_fileno = async_file.fileno

        async def mock_read(*args, **kwargs):
            nonlocal read_called
            read_called += 1
            return await original_read(*args, **kwargs)

        def mock_fileno(*args):
            nonlocal fileno_called
            fileno_called += 1
            return original_fileno(*args)

        monkeypatch.setattr(async_file, "read", mock_read)
        monkeypatch.setattr(async_file, "fileno", mock_fileno)
        response = await getattr(client, client_method)(url=url, content=async_file)
        assert response.status_code == 200
        # The echoed body must match the raw bytes on disk, and the request
        # must have been sized up front (Content-Length, not chunked).
        assert response.content == content_bytes
        assert response.request.headers["Content-Length"] == str(len(content_bytes))
        assert read_called == read_calls_expected
        assert fileno_called == 1

    # First pass: the native async file type for the active backend.
    async with (
        await anyio.open_file(to_upload, mode=mode)
        if anyio_backend != "trio"
        else await trio.open_file(to_upload, mode=mode) as async_file,
        httpx.AsyncClient() as client,
    ):
        assert is_async_readable_file(async_file)
        await checks(client, async_file)

    # Second pass: aiofiles files, skipped on trio.
    if anyio_backend != "trio":  # aiofiles doesn't work with trio
        async with (
            aiofiles.open(to_upload, mode=mode) as aio_file,
            httpx.AsyncClient() as client,
        ):
            assert is_async_readable_file(aio_file)
            await checks(client, aio_file)
Loading