diff --git a/CHANGES/10795.feature.rst b/CHANGES/10795.feature.rst new file mode 100644 index 00000000000..f4bc30e53fa --- /dev/null +++ b/CHANGES/10795.feature.rst @@ -0,0 +1,2 @@ +Added :py:mod:`orjson` support as the default JSON encoder for :py:class:`~aiohttp.ClientSession` and :py:class:`~aiohttp.JsonPayload` +-- by :user:`fatelei` diff --git a/aiohttp/client.py b/aiohttp/client.py index 6a8c667491f..60216d56471 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -4,7 +4,6 @@ import base64 import dataclasses import hashlib -import json import os import sys import traceback @@ -107,7 +106,16 @@ from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse from .tracing import Trace, TraceConfig -from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, Query, StrOrURL +from .typedefs import ( + DEFAULT_JSON_BYTES_ENCODER, + DEFAULT_JSON_ENCODER, + JSONBytesEncoder, + JSONEncoder, + LooseCookies, + LooseHeaders, + Query, + StrOrURL, +) __all__ = ( # client_exceptions @@ -277,7 +285,8 @@ def __init__( proxy_auth: Optional[BasicAuth] = None, skip_auto_headers: Optional[Iterable[str]] = None, auth: Optional[BasicAuth] = None, - json_serialize: JSONEncoder = json.dumps, + json_serialize: JSONEncoder = DEFAULT_JSON_ENCODER, + json_serialize_bytes: JSONBytesEncoder = DEFAULT_JSON_BYTES_ENCODER, request_class: Type[ClientRequest] = ClientRequest, response_class: Type[ClientResponse] = ClientResponse, ws_response_class: Type[ClientWebSocketResponse] = ClientWebSocketResponse, @@ -357,6 +366,7 @@ def __init__( self._default_auth = auth self._version = version self._json_serialize = json_serialize + self._json_serialize_bytes = json_serialize_bytes self._raise_for_status = raise_for_status self._auto_decompress = auto_decompress self._trust_env = trust_env @@ -484,7 +494,11 @@ async def _request( "data and json parameters can not be used at the same time" ) elif json is not None: - data = payload.JsonPayload(json, dumps=self._json_serialize) + data = payload.JsonPayload( + json, + dumps=self._json_serialize, + dumps_bytes=self._json_serialize_bytes, + ) redirects = 0 history: List[ClientResponse] = [] @@ -1316,6 +1330,11 @@ def json_serialize(self) -> JSONEncoder: """Json serializer callable""" return self._json_serialize + @property + def json_serialize_bytes(self) -> JSONBytesEncoder: + """Json bytes serializer callable""" + return self._json_serialize_bytes + @property def connector_owner(self) -> bool: """Should connector be closed on session closing""" diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 2ee8b8cb908..03e3e127f28 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -1,7 +1,6 @@ import asyncio import enum import io -import json import mimetypes import os import sys @@ -36,7 +35,13 @@ sentinel, ) from .streams import StreamReader -from .typedefs import JSONEncoder, _CIMultiDict +from .typedefs import ( + DEFAULT_JSON_BYTES_ENCODER, + DEFAULT_JSON_ENCODER, + JSONBytesEncoder, + JSONEncoder, + _CIMultiDict, +) __all__ = ( "PAYLOAD_REGISTRY", @@ -939,15 +944,17 @@ def __init__( value: Any, encoding: str = "utf-8", content_type: str = "application/json", - dumps: JSONEncoder = json.dumps, - *args: Any, + dumps: JSONEncoder = DEFAULT_JSON_ENCODER, + *, + dumps_bytes: JSONBytesEncoder = DEFAULT_JSON_BYTES_ENCODER, **kwargs: Any, ) -> None: + # Prefer bytes serializer to avoid extra encode/decode + body = dumps_bytes(value) super().__init__( - dumps(value).encode(encoding), + body, content_type=content_type, encoding=encoding, - *args, **kwargs, ) diff --git a/aiohttp/typedefs.py b/aiohttp/typedefs.py index cc8c0825b4e..6ee54cc94af 100644 --- a/aiohttp/typedefs.py +++ b/aiohttp/typedefs.py @@ -17,8 +17,33 @@ Query = _Query -DEFAULT_JSON_ENCODER = json.dumps -DEFAULT_JSON_DECODER = json.loads +# Try to use orjson for better performance, fallback to standard json +try: + import orjson + + def _orjson_dumps(obj: Any) -> str: + """orjson encoder that returns str (like json.dumps).""" + return orjson.dumps(obj).decode("utf-8") + + def _orjson_dumps_bytes(obj: Any) -> bytes: + """orjson encoder that returns bytes directly (fast path).""" + return orjson.dumps(obj) + + def _orjson_loads(s: str) -> Any: + """orjson decoder that accepts str (like json.loads).""" + return orjson.loads(s) + + DEFAULT_JSON_ENCODER = _orjson_dumps + DEFAULT_JSON_DECODER = _orjson_loads + DEFAULT_JSON_BYTES_ENCODER = _orjson_dumps_bytes +except ImportError: + DEFAULT_JSON_ENCODER = json.dumps + DEFAULT_JSON_DECODER = json.loads + + def _json_dumps_bytes_fallback(obj: Any) -> bytes: + return json.dumps(obj).encode("utf-8") + + DEFAULT_JSON_BYTES_ENCODER = _json_dumps_bytes_fallback if TYPE_CHECKING: _CIMultiDict = CIMultiDict[str] @@ -37,6 +62,7 @@ Byteish = Union[bytes, bytearray, memoryview] JSONEncoder = Callable[[Any], str] JSONDecoder = Callable[[str], Any] +JSONBytesEncoder = Callable[[Any], bytes] LooseHeaders = Union[ Mapping[str, str], Mapping[istr, str], diff --git a/docs/client_quickstart.rst b/docs/client_quickstart.rst index 48f123b94bd..2c9cfd6f817 100644 --- a/docs/client_quickstart.rst +++ b/docs/client_quickstart.rst @@ -210,16 +210,21 @@ serialization. But it is possible to use different ``serializer``. :class:`ClientSession` accepts ``json_serialize`` parameter:: - import ujson + import orjson async with aiohttp.ClientSession( - json_serialize=ujson.dumps) as session: + json_serialize=orjson.dumps) as session: await session.post(url, json={'test': 'object'}) .. note:: - ``ujson`` library is faster than standard :mod:`json` but slightly - incompatible. + ``orjson`` library is much faster than standard :mod:`json` and is now + the default when available. You can install it with the ``speedups`` extra: + ``pip install aiohttp[speedups]`` or separately with ``pip install orjson``. + ``ujson`` was previously recommended but is now deprecated in favor of + ``orjson`` due to security and maintenance concerns. + If ``orjson`` is not available, aiohttp will fall back to the standard + :mod:`json` module. JSON Response Content ===================== diff --git a/requirements/runtime-deps.in b/requirements/runtime-deps.in index f849b448cf6..49dbd36efbd 100644 --- a/requirements/runtime-deps.in +++ b/requirements/runtime-deps.in @@ -8,6 +8,7 @@ Brotli; platform_python_implementation == 'CPython' brotlicffi; platform_python_implementation != 'CPython' frozenlist >= 1.1.1 multidict >=4.5, < 7.0 +orjson >= 3.8.0 ; platform_python_implementation == "CPython" propcache >= 0.2.0 yarl >= 1.17.0, < 2.0 zstandard; platform_python_implementation == 'CPython' and python_version < "3.14" diff --git a/setup.cfg b/setup.cfg index 7f783324897..62150dd1b01 100644 --- a/setup.cfg +++ b/setup.cfg @@ -58,6 +58,7 @@ install_requires = multidict >=4.5, < 7.0 propcache >= 0.2.0 yarl >= 1.17.0, < 2.0 + orjson >= 3.8.0 [options.exclude_package_data] * = diff --git a/tests/test_orjson_integration.py b/tests/test_orjson_integration.py new file mode 100644 index 00000000000..0a7639d3ec8 --- /dev/null +++ b/tests/test_orjson_integration.py @@ -0,0 +1,281 @@ +"""Tests for orjson integration in aiohttp.""" + +import json +import sys +from typing import Any, Optional +from unittest.mock import Mock, patch + +import pytest + +from aiohttp import JsonPayload, typedefs, web +from aiohttp.client import ClientSession +from aiohttp.test_utils import AiohttpClient + + +# Mock orjson functions for testing +def mock_orjson_dumps(data: Any, *, option: Optional[int] = None) -> bytes: + """Mock orjson.dumps function that returns bytes.""" + # Add some unique marker to distinguish from json.dumps + result = {"_orjson_used": True, "data": data} + return json.dumps(result).encode("utf-8") + + +def mock_orjson_loads(data: bytes) -> Any: + """Mock orjson.loads function that accepts bytes.""" + decoded = json.loads(data.decode("utf-8")) + if decoded.get("_orjson_used"): + return decoded["data"] + return decoded + + +class TestOrjsonIntegration: + """Test orjson integration for JSON serialization.""" + + def test_default_json_encoder_uses_orjson_when_available(self) -> None: + """Test that DEFAULT_JSON_ENCODER uses orjson when available.""" + with patch.dict(sys.modules, {"orjson": Mock()}): + # Mock orjson module + orjson_mock = sys.modules["orjson"] + orjson_mock.dumps = mock_orjson_dumps + orjson_mock.OPT_NAIVE_UTC = 1 # Mock option constant + orjson_mock.OPT_OMIT_MICROSECONDS = 2 + + # Reload typedefs to pick up orjson + import importlib + + importlib.reload(typedefs) + + # Test that orjson is used + test_data = {"test": "data", "number": 42} + result = typedefs.DEFAULT_JSON_ENCODER(test_data) + + # Should return string (our mock converts bytes to string) + assert isinstance(result, str) + parsed_result = json.loads(result) + assert parsed_result["_orjson_used"] is True + assert parsed_result["data"] == test_data + + def test_default_json_encoder_fallback_to_json_dumps(self) -> None: + """Test that DEFAULT_JSON_ENCODER falls back to json.dumps when orjson unavailable.""" + # Ensure orjson is not available + with patch.dict(sys.modules, {"orjson": None}): + import importlib + + importlib.reload(typedefs) + + # Test that json.dumps is used + test_data = {"test": "data", "number": 42} + result = typedefs.DEFAULT_JSON_ENCODER(test_data) + + # Should return same as json.dumps + assert result == json.dumps(test_data) + # Should not have orjson marker + assert "_orjson_used" not in result + + def test_json_payload_uses_default_encoder_by_default(self) -> None: + """Test that JsonPayload uses DEFAULT_JSON_ENCODER by default.""" + with patch.dict(sys.modules, {"orjson": Mock()}): + orjson_mock = sys.modules["orjson"] + orjson_mock.dumps = mock_orjson_dumps + orjson_mock.OPT_NAIVE_UTC = 1 + orjson_mock.OPT_OMIT_MICROSECONDS = 2 + + import importlib + + importlib.reload(typedefs) + + test_data = {"test": "payload_data"} + payload = JsonPayload(test_data) + + # The payload should use orjson internally + assert payload.content_type == "application/json" + + # Check the actual serialized data contains orjson marker + data = payload._value + assert isinstance(data, bytes) + parsed = json.loads(data.decode("utf-8")) + assert parsed["_orjson_used"] is True + assert parsed["data"] == test_data + + def test_json_payload_can_override_encoder(self) -> None: + """Test that JsonPayload can still use custom encoder when provided.""" + + def custom_encoder(obj: Any) -> str: + return json.dumps({"custom": True, "data": obj}) + + test_data = {"test": "custom_data"} + payload = JsonPayload(test_data, dumps=custom_encoder) + + # Should use custom encoder, not orjson + data = payload._value + assert isinstance(data, bytes) + parsed = json.loads(data.decode("utf-8")) + assert parsed["custom"] is True + assert parsed["data"] == test_data + assert "_orjson_used" not in parsed + + async def test_client_session_uses_default_encoder_by_default(self) -> None: + """Test that ClientSession uses DEFAULT_JSON_ENCODER by default.""" + with patch.dict(sys.modules, {"orjson": Mock()}): + orjson_mock = sys.modules["orjson"] + orjson_mock.dumps = mock_orjson_dumps + orjson_mock.OPT_NAIVE_UTC = 1 + orjson_mock.OPT_OMIT_MICROSECONDS = 2 + + import importlib + + importlib.reload(typedefs) + + async with ClientSession() as session: + # Verify the session uses DEFAULT_JSON_ENCODER + assert session._json_serialize is typedefs.DEFAULT_JSON_ENCODER + + async def test_client_session_can_override_encoder(self) -> None: + """Test that ClientSession can use custom JSON encoder when provided.""" + + def custom_encoder(obj: Any) -> str: + return json.dumps({"session_custom": True, "data": obj}) + + async with ClientSession(json_serialize=custom_encoder) as session: + # Verify the session uses custom encoder + assert session._json_serialize is custom_encoder + + # Test actual usage + result = session._json_serialize({"test": "data"}) + parsed = json.loads(result) + assert parsed["session_custom"] is True + assert parsed["data"] == {"test": "data"} + + async def test_client_session_json_request_with_orjson( + self, aiohttp_client: AiohttpClient + ) -> None: + """Test client session JSON requests work with orjson.""" + with patch.dict(sys.modules, {"orjson": Mock()}): + orjson_mock = sys.modules["orjson"] + orjson_mock.dumps = mock_orjson_dumps + orjson_mock.loads = mock_orjson_loads + orjson_mock.OPT_NAIVE_UTC = 1 + orjson_mock.OPT_OMIT_MICROSECONDS = 2 + + import importlib + + importlib.reload(typedefs) + + received_data = None + + async def handler(request: web.Request) -> web.Response: + nonlocal received_data + received_data = await request.json() + return web.Response(text="OK") + + app = web.Application() + app.router.add_post("/", handler) + client = await aiohttp_client(app) + + test_data = {"test": "json_request", "number": 123} + + # Make request with JSON data + async with client.post("/", json=test_data) as resp: + assert resp.status == 200 + + # Verify the data was received correctly + # Since we use orjson mock, the received data should be the original + assert received_data == test_data + + async def test_json_payload_with_fallback_when_orjson_unavailable(self) -> None: + """Test JsonPayload works with json.dumps when orjson is not available.""" + # Ensure orjson is not available + with patch.dict(sys.modules, {"orjson": None}): + import importlib + + importlib.reload(typedefs) + + test_data = {"fallback": "test", "works": True} + payload = JsonPayload(test_data) + + # Should still work with json.dumps + assert payload.content_type == "application/json" + + # Check the serialized data + data = payload._value + assert isinstance(data, bytes) + parsed = json.loads(data.decode("utf-8")) + assert parsed == test_data + assert "_orjson_used" not in str(data) + + def test_orjson_handles_datetime_serialization(self) -> None: + """Test that orjson integration handles datetime objects properly.""" + from datetime import datetime + + with patch.dict(sys.modules, {"orjson": Mock()}): + orjson_mock = sys.modules["orjson"] + + def orjson_dumps_with_datetime( + data: Any, *, option: Optional[int] = None + ) -> bytes: + """Mock orjson.dumps that handles datetime.""" + if isinstance(data, dict): + # Convert datetime to string for JSON serialization + serializable_data = {} + for k, v in data.items(): + if isinstance(v, datetime): + serializable_data[k] = v.isoformat() + else: + serializable_data[k] = v + result = {"_orjson_used": True, "data": serializable_data} + else: + result = {"_orjson_used": True, "data": data} + return json.dumps(result).encode("utf-8") + + orjson_mock.dumps = orjson_dumps_with_datetime + orjson_mock.OPT_NAIVE_UTC = 1 + orjson_mock.OPT_OMIT_MICROSECONDS = 2 + + import importlib + + importlib.reload(typedefs) + + test_data = { + "timestamp": datetime(2023, 1, 1, 12, 0, 0), + "message": "datetime test", + } + + payload = JsonPayload(test_data) + data = payload._value + parsed = json.loads(data.decode("utf-8")) + + assert parsed["_orjson_used"] is True + assert parsed["data"]["message"] == "datetime test" + assert parsed["data"]["timestamp"] == "2023-01-01T12:00:00" + + def test_encoding_with_orjson(self) -> None: + """Test that JsonPayload respects encoding parameter with orjson.""" + with patch.dict(sys.modules, {"orjson": Mock()}): + orjson_mock = sys.modules["orjson"] + orjson_mock.dumps = mock_orjson_dumps + orjson_mock.OPT_NAIVE_UTC = 1 + orjson_mock.OPT_OMIT_MICROSECONDS = 2 + + import importlib + + importlib.reload(typedefs) + + test_data = {"unicode": "café", "emoji": "🚀"} + + # Test with default UTF-8 encoding + payload_utf8 = JsonPayload(test_data) + assert payload_utf8.encoding == "utf-8" + + # Test with custom encoding + payload_latin1 = JsonPayload(test_data, encoding="latin-1") + assert payload_latin1.encoding == "latin-1" + + +@pytest.fixture(autouse=True) +def reset_typedefs(): + """Reset typedefs module after each test.""" + yield + # Reload typedefs to reset to original state + import importlib + + importlib.reload(typedefs)