diff --git a/CHANGES/11766.feature.rst b/CHANGES/11766.feature.rst new file mode 100644 index 00000000000..c4980bf7310 --- /dev/null +++ b/CHANGES/11766.feature.rst @@ -0,0 +1,4 @@ +Added ``RequestKey`` and ``ResponseKey`` classes, +which enable static type checking for request & response +context storages similarly to ``AppKey`` +-- by :user:`gsoldatov`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 46547b871de..ec2b86d1495 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -150,6 +150,7 @@ Gennady Andreyev Georges Dubus Greg Holt Gregory Haynes +Grigoriy Soldatov Gus Goulart Gustavo Carneiro Günther Jena diff --git a/aiohttp/client.py b/aiohttp/client.py index 026006023ce..fca569e3ec4 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -380,7 +380,7 @@ def __init__( def __init_subclass__(cls: type["ClientSession"]) -> None: raise TypeError( - f"Inheritance class {cls.__name__} from ClientSession " "is forbidden" + f"Inheritance class {cls.__name__} from ClientSession is forbidden" ) def __del__(self, _warnings: Any = warnings) -> None: diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 7368cf1b170..f32fb7d80b6 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -834,8 +834,11 @@ def set_exception( @functools.total_ordering -class AppKey(Generic[_T]): - """Keys for static typing support in Application.""" +class BaseKey(Generic[_T]): + """Base for concrete context storage key classes. + + Each storage is provided with its own sub-class for the sake of some additional type safety. + """ __slots__ = ("_name", "_t", "__orig_class__") @@ -861,9 +864,9 @@ def __init__(self, name: str, t: type[_T] | None = None): self._t = t def __lt__(self, other: object) -> bool: - if isinstance(other, AppKey): + if isinstance(other, BaseKey): return self._name < other._name - return True # Order AppKey above other types. + return True # Order BaseKey above other types. def __repr__(self) -> str: t = self._t @@ -881,7 +884,25 @@ def __repr__(self) -> str: t_repr = f"{t.__module__}.{t.__qualname__}" else: t_repr = repr(t) # type: ignore[unreachable] - return f"" + return f"<{self.__class__.__name__}({self._name}, type={t_repr})>" + + +class AppKey(BaseKey[_T]): + """Keys for static typing support in Application.""" + + pass + + +class RequestKey(BaseKey[_T]): + """Keys for static typing support in Request.""" + + pass + + +class ResponseKey(BaseKey[_T]): + """Keys for static typing support in Response.""" + + pass @final @@ -893,7 +914,7 @@ def __init__(self, maps: Iterable[Mapping[str | AppKey[Any], Any]]) -> None: def __init_subclass__(cls) -> None: raise TypeError( - f"Inheritance class {cls.__name__} from ChainMapProxy " "is forbidden" + f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden" ) @overload # type: ignore[override] diff --git a/aiohttp/web.py b/aiohttp/web.py index 1322360cbed..b116b5913d1 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -11,7 +11,7 @@ from typing import Any, cast from .abc import AbstractAccessLogger -from .helpers import AppKey +from .helpers import AppKey, RequestKey, ResponseKey from .log import access_logger from .typedefs import PathLike from .web_app import Application, CleanupError @@ -203,11 +203,13 @@ "BaseRequest", "FileField", "Request", + "RequestKey", # web_response "ContentCoding", "Response", "StreamResponse", "json_response", + "ResponseKey", # web_routedef "AbstractRouteDef", "RouteDef", diff --git a/aiohttp/web_app.py b/aiohttp/web_app.py index ddd30efb72f..3a34311f845 100644 --- a/aiohttp/web_app.py +++ b/aiohttp/web_app.py @@ -130,7 +130,7 @@ def __init__( def __init_subclass__(cls: type["Application"]) -> None: raise TypeError( - f"Inheritance class {cls.__name__} from web.Application " "is forbidden" + f"Inheritance class {cls.__name__} from web.Application is forbidden" ) # MutableMapping API diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 0d58fee567b..7e576350ac0 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -716,7 +716,7 @@ async def finish_response( self.log_exception("Missing return statement on request handler") # type: ignore[unreachable] else: self.log_exception( - "Web-handler should return a response instance, " f"got {resp!r}" + f"Web-handler should return a response instance, got {resp!r}" ) exc = HTTPInternalServerError() resp = Response( diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 84a15753c4b..96a8977def8 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -10,7 +10,7 @@ from collections.abc import Iterator, Mapping, MutableMapping from re import Pattern from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Final, Optional, cast +from typing import TYPE_CHECKING, Any, Final, Optional, TypeVar, cast, overload from urllib.parse import parse_qsl from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy @@ -26,6 +26,7 @@ ChainMapProxy, ETag, HeadersMixin, + RequestKey, frozen_dataclass_decorator, is_expected_content_type, parse_http_date, @@ -65,6 +66,9 @@ from .web_urldispatcher import UrlMappingMatchInfo +_T = TypeVar("_T") + + @frozen_dataclass_decorator class FileField: name: str @@ -101,7 +105,7 @@ class FileField: ############################################################ -class BaseRequest(MutableMapping[str, Any], HeadersMixin): +class BaseRequest(MutableMapping[str | RequestKey[Any], Any], HeadersMixin): POST_METHODS = { hdrs.METH_PATCH, hdrs.METH_POST, @@ -123,7 +127,7 @@ def __init__( loop: asyncio.AbstractEventLoop, *, client_max_size: int = 1024**2, - state: dict[str, Any] | None = None, + state: dict[RequestKey[Any] | str, Any] | None = None, scheme: str | None = None, host: str | None = None, remote: str | None = None, @@ -253,19 +257,31 @@ def rel_url(self) -> URL: # MutableMapping API - def __getitem__(self, key: str) -> Any: + @overload # type: ignore[override] + def __getitem__(self, key: RequestKey[_T]) -> _T: ... + + @overload + def __getitem__(self, key: str) -> Any: ... + + def __getitem__(self, key: str | RequestKey[_T]) -> Any: return self._state[key] - def __setitem__(self, key: str, value: Any) -> None: + @overload # type: ignore[override] + def __setitem__(self, key: RequestKey[_T], value: _T) -> None: ... + + @overload + def __setitem__(self, key: str, value: Any) -> None: ... + + def __setitem__(self, key: str | RequestKey[_T], value: Any) -> None: self._state[key] = value - def __delitem__(self, key: str) -> None: + def __delitem__(self, key: str | RequestKey[_T]) -> None: del self._state[key] def __len__(self) -> int: return len(self._state) - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[str | RequestKey[Any]]: return iter(self._state) ######## diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 2d09f82c225..7b506716f8b 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -8,7 +8,7 @@ from collections.abc import Iterator, MutableMapping from concurrent.futures import Executor from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast, overload from multidict import CIMultiDict, istr @@ -21,6 +21,7 @@ CookieMixin, ETag, HeadersMixin, + ResponseKey, must_be_empty_body, parse_http_date, populate_with_cookies, @@ -43,6 +44,9 @@ from .web_request import BaseRequest +_T = TypeVar("_T") + + # TODO(py311): Convert to StrEnum for wider use class ContentCoding(enum.Enum): # The content codings that we have support for. @@ -61,7 +65,9 @@ class ContentCoding(enum.Enum): ############################################################ -class StreamResponse(MutableMapping[str, Any], HeadersMixin, CookieMixin): +class StreamResponse( + MutableMapping[str | ResponseKey[Any], Any], HeadersMixin, CookieMixin +): _body: None | bytes | bytearray | Payload _length_check = True @@ -93,7 +99,7 @@ def __init__( the headers when creating a new response object. It is not intended to be used by external code. """ - self._state: dict[str, Any] = {} + self._state: dict[str | ResponseKey[Any], Any] = {} if _real_headers is not None: self._headers = _real_headers @@ -483,19 +489,31 @@ def __repr__(self) -> str: info = "not prepared" return f"<{self.__class__.__name__} {self.reason} {info}>" - def __getitem__(self, key: str) -> Any: + @overload # type: ignore[override] + def __getitem__(self, key: ResponseKey[_T]) -> _T: ... + + @overload + def __getitem__(self, key: str) -> Any: ... + + def __getitem__(self, key: str | ResponseKey[_T]) -> Any: return self._state[key] - def __setitem__(self, key: str, value: Any) -> None: + @overload # type: ignore[override] + def __setitem__(self, key: ResponseKey[_T], value: _T) -> None: ... + + @overload + def __setitem__(self, key: str, value: Any) -> None: ... + + def __setitem__(self, key: str | ResponseKey[_T], value: Any) -> None: self._state[key] = value - def __delitem__(self, key: str) -> None: + def __delitem__(self, key: str | ResponseKey[_T]) -> None: del self._state[key] def __len__(self) -> int: return len(self._state) - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[str | ResponseKey[Any]]: return iter(self._state) def __hash__(self) -> int: diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index 0b04c317f93..480914d4480 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -373,7 +373,7 @@ def __init__( ) -> None: if not isinstance(app, Application): raise TypeError( - "The first argument should be web.Application " f"instance, got {app!r}" + f"The first argument should be web.Application instance, got {app!r}" ) kwargs["access_log_class"] = access_log_class diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index e4b3f2f34fe..acca27cafad 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -171,7 +171,7 @@ def __init__( pass else: raise TypeError( - "Only async functions are allowed as web-handlers " f", got {handler!r}" + f"Only async functions are allowed as web-handlers, got {handler!r}" ) self._method = method diff --git a/docs/faq.rst b/docs/faq.rst index f5e8b9afb15..2166f1775a7 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -87,8 +87,15 @@ support the :class:`dict` interface. Therefore, data may be stored inside a request object. :: - async def handler(request): - request['unique_key'] = data + request_id_key = web.RequestKey("request_id_key", str) + + @web.middleware + async def request_id_middleware(request, handler): + request[request_id_key] = "some_request_id" + return await handler(request) + + async def handler(request): + request_id = request[request_id_key] See https://github.com/aio-libs/aiohttp_session code for an example. The ``aiohttp_session.get_session(request)`` method uses ``SESSION_KEY`` diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 74ace02c5ec..29f6b0f364e 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -6,6 +6,7 @@ aiohttp aiohttpdemo aiohttp’s aiopg +al alives api api’s @@ -120,6 +121,7 @@ env environ eof epoll +et etag ETag expirations @@ -167,6 +169,7 @@ iterable iterables javascript Jinja +jitter json keepalive keepalived @@ -294,6 +297,7 @@ runtime runtimes sa Satisfiable +scalability schemas sendfile serializable @@ -306,6 +310,7 @@ ssl SSLContext startup stateful +storages subapplication subclassed subclasses @@ -350,6 +355,7 @@ unicode unittest Unittest unix +unobvious unsets unstripped untyped diff --git a/docs/web_advanced.rst b/docs/web_advanced.rst index 5ca4f0a05bd..81fa384d55b 100644 --- a/docs/web_advanced.rst +++ b/docs/web_advanced.rst @@ -446,10 +446,13 @@ Request's storage ^^^^^^^^^^^^^^^^^ Variables that are only needed for the lifetime of a :class:`Request`, can be -stored in a :class:`Request`:: +stored in a :class:`Request`. Similarly to :class:`Application`, :class:`RequestKey` +instances or strings can be used as keys:: + + my_private_key = web.RequestKey("my_private_key", str) async def handler(request): - request['my_private_key'] = "data" + request[my_private_key] = "data" ... This is mostly useful for :ref:`aiohttp-web-middlewares` and @@ -464,9 +467,11 @@ also support :class:`collections.abc.MutableMapping` interface. This is useful when you want to share data with signals and middlewares once all the work in the handler is done:: + my_metric_key = web.ResponseKey("my_metric_key", int) + async def handler(request): [ do all the work ] - response['my_metric'] = 123 + response[my_metric_key] = 123 return response @@ -722,18 +727,20 @@ In contrast, when accessing the stream directly (not recommended in middleware): When working with raw stream data that needs to be shared between middleware and handlers:: + raw_body_key = web.RequestKey("raw_body_key", bytes) + async def stream_parsing_middleware( request: web.Request, handler: Callable[[web.Request], Awaitable[web.StreamResponse]] ) -> web.StreamResponse: # Read stream once and store the data raw_data = await request.content.read() - request['raw_body'] = raw_data + request[raw_body_key] = raw_data return await handler(request) async def handler(request: web.Request) -> web.Response: # Access the stored data instead of reading the stream again - raw_data = request.get('raw_body', b'') + raw_data = request.get(raw_body_key, b'') return web.Response(body=raw_data) Example diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 7347e5d9b06..cb0f280f90b 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -534,6 +534,21 @@ and :ref:`aiohttp-web-signals` handlers. request copy with changed *path*, *method* etc. +.. class:: RequestKey(name, t) + + This class should be used for the keys in :class:`Request` and + :class:`BaseRequest`. They provide a type-safe alternative to + `str` keys when checking your code with a type checker (e.g. mypy). + They also avoid name clashes with keys from different libraries etc. + + :param name: A name to help with debugging. This should be the same as + the variable name (much like how :class:`typing.TypeVar` + is used). + + :param t: The type that should be used for the value in the dict (e.g. + `str`, `Iterator[int]` etc.) + + .. _aiohttp-web-response: @@ -1357,6 +1372,24 @@ content type and *data* encoded by ``dumps`` parameter (:func:`json.dumps` by default). +.. class:: ResponseKey(name, t) + + This class should be used for the keys in :class:`Response`, + :class:`FileResponse` and :class:`StreamResponse`. They provide + a type-safe alternative to `str` keys when checking your code + with a type checker (e.g. mypy). They also avoid name clashes + with keys from different libraries etc. + + :param name: A name to help with debugging. This should be the same as + the variable name (much like how :class:`typing.TypeVar` + is used). + + :param t: The type that should be used for the value in the dict (e.g. + `str`, `Iterator[int]` etc.) + + + + .. _aiohttp-web-app-and-router: Application and Router @@ -1631,6 +1664,7 @@ Application and Router :param t: The type that should be used for the value in the dict (e.g. `str`, `Iterator[int]` etc.) + .. class:: Server A protocol factory compatible with diff --git a/tests/test_web_request.py b/tests/test_web_request.py index 2f59fc77bed..905ba0d0ba7 100644 --- a/tests/test_web_request.py +++ b/tests/test_web_request.py @@ -2,8 +2,9 @@ import datetime import socket import ssl +import sys import weakref -from collections.abc import MutableMapping +from collections.abc import Iterator, MutableMapping from typing import NoReturn from unittest import mock @@ -470,7 +471,73 @@ def test_request_iter() -> None: req = make_mocked_request("GET", "/") req["key"] = "value" req["key2"] = "value2" - assert set(req) == {"key", "key2"} + key3 = web.RequestKey("key3", str) + req[key3] = "value3" + assert set(req) == {"key", "key2", key3} + + +def test_requestkey() -> None: + req = make_mocked_request("GET", "/") + key = web.RequestKey("key", str) + req[key] = "value" + assert req[key] == "value" + assert len(req) == 1 + del req[key] + assert len(req) == 0 + + +def test_request_get_requestkey() -> None: + req = make_mocked_request("GET", "/") + key = web.RequestKey("key", int) + assert req.get(key, "foo") == "foo" + req[key] = 5 + assert req.get(key, "foo") == 5 + + +def test_requestkey_repr_concrete() -> None: + key = web.RequestKey("key", int) + assert repr(key) in ( + "", # pytest-xdist + "", + ) + key2 = web.RequestKey("key", web.Request) + assert repr(key2) in ( + # pytest-xdist: + "", + "", + ) + + +def test_requestkey_repr_nonconcrete() -> None: + key = web.RequestKey("key", Iterator[int]) + if sys.version_info < (3, 11): + assert repr(key) in ( + # pytest-xdist: + "", + "", + ) + else: + assert repr(key) in ( + # pytest-xdist: + "", + "", + ) + + +def test_requestkey_repr_annotated() -> None: + key = web.RequestKey[Iterator[int]]("key") + if sys.version_info < (3, 11): + assert repr(key) in ( + # pytest-xdist: + "", + "", + ) + else: + assert repr(key) in ( + # pytest-xdist: + "", + "", + ) def test___repr__() -> None: diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 57c8fbf9c83..32975f562d1 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -4,8 +4,9 @@ import io import json import re +import sys import weakref -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterator from concurrent.futures import ThreadPoolExecutor from unittest import mock @@ -111,11 +112,77 @@ def test_stream_response_len() -> None: assert len(resp) == 1 -def test_request_iter() -> None: +def test_response_iter() -> None: resp = web.StreamResponse() resp["key"] = "value" resp["key2"] = "value2" - assert set(resp) == {"key", "key2"} + key3 = web.ResponseKey("key3", str) + resp[key3] = "value3" + assert set(resp) == {"key", "key2", key3} + + +def test_responsekey() -> None: + resp = web.StreamResponse() + key = web.ResponseKey("key", str) + resp[key] = "value" + assert resp[key] == "value" + assert len(resp) == 1 + del resp[key] + assert len(resp) == 0 + + +def test_response_get_responsekey() -> None: + resp = web.StreamResponse() + key = web.ResponseKey("key", int) + assert resp.get(key, "foo") == "foo" + resp[key] = 5 + assert resp.get(key, "foo") == 5 + + +def test_responsekey_repr_concrete() -> None: + key = web.ResponseKey("key", int) + assert repr(key) in ( + "", # pytest-xdist + "", + ) + key2 = web.ResponseKey("key", web.Request) + assert repr(key2) in ( + # pytest-xdist: + "", + "", + ) + + +def test_responsekey_repr_nonconcrete() -> None: + key = web.ResponseKey("key", Iterator[int]) + if sys.version_info < (3, 11): + assert repr(key) in ( + # pytest-xdist: + "", + "", + ) + else: + assert repr(key) in ( + # pytest-xdist: + "", + "", + ) + + +def test_responsekey_repr_annotated() -> None: + key = web.ResponseKey[Iterator[int]]("key") + if sys.version_info < (3, 11): + assert repr(key) in ( + # pytest-xdist: + "", + "", + ) + else: + assert repr(key) in ( + # pytest-xdist: + "", + "", + ) def test_content_length() -> None: