diff --git a/AGENTS.md b/AGENTS.md index 3bbba5e..170bbd5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -17,8 +17,18 @@ Run the full test suite with: uv run pytest ``` -You can check the tests coverage by adding the `--cov` flag from `pytest-cov` when running `pytest`. +You can check the tests coverage by adding the `--cov` and `--cov-report=term-missing` (to report the missing lines in the command output) flags from `pytest-cov` when running `pytest`. ```bash -uv run pytest --cov +uv run pytest --cov --cov-report=term-missing +``` + +## Before commiting + +All files must be properly formatted before creating a PR, so we can merge it upstream. + +You can run `ruff` for formatting automatically all the files of the project: + +```bash +uv run ruff format ``` diff --git a/src/graphql_server/asgi/__init__.py b/src/graphql_server/asgi/__init__.py index d7873e3..b4e9e0d 100644 --- a/src/graphql_server/asgi/__init__.py +++ b/src/graphql_server/asgi/__init__.py @@ -45,7 +45,12 @@ ) if TYPE_CHECKING: - from collections.abc import AsyncGenerator, AsyncIterator, Mapping, Sequence # pragma: no cover + from collections.abc import ( # pragma: no cover + AsyncGenerator, + AsyncIterator, + Mapping, + Sequence, + ) from graphql.type import GraphQLSchema # pragma: no cover from starlette.types import Receive, Scope, Send # pragma: no cover @@ -112,7 +117,9 @@ async def iter_json( async def send_json(self, message: Mapping[str, object]) -> None: try: await self.ws.send_text(self.view.encode_json(message)) - except WebSocketDisconnect as exc: # pragma: no cover - network errors mocked elsewhere + except ( + WebSocketDisconnect + ) as exc: # pragma: no cover - network errors mocked elsewhere raise WebSocketDisconnected from exc async def close(self, code: int, reason: str) -> None: diff --git a/src/graphql_server/chalice/views.py b/src/graphql_server/chalice/views.py index 399adb5..690d9ca 100644 --- a/src/graphql_server/chalice/views.py +++ b/src/graphql_server/chalice/views.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from io import BytesIO from typing import TYPE_CHECKING, Any, Optional, Union, cast from chalice.app import Request, Response @@ -23,6 +24,8 @@ class ChaliceHTTPRequestAdapter(SyncHTTPRequestAdapter): def __init__(self, request: Request) -> None: self.request = request + self._post_data: Optional[dict[str, Union[str, bytes]]] = None + self._files: Optional[dict[str, Any]] = None @property def query_params(self) -> QueryParams: @@ -42,11 +45,49 @@ def headers(self) -> Mapping[str, str]: @property def post_data(self) -> Mapping[str, Union[str, bytes]]: - raise NotImplementedError + if self._post_data is None: + self._parse_body() + return self._post_data or {} @property def files(self) -> Mapping[str, Any]: - raise NotImplementedError + if self._files is None: + self._parse_body() + return self._files or {} + + def _parse_body(self) -> None: + self._post_data = {} + self._files = {} + + content_type = self.content_type or "" + + if "multipart/form-data" in content_type: + import cgi + + fp = BytesIO(self.request.raw_body) + environ = { + "REQUEST_METHOD": "POST", + "CONTENT_TYPE": content_type, + "CONTENT_LENGTH": str(len(self.request.raw_body)), + } + fs = cgi.FieldStorage(fp=fp, environ=environ, keep_blank_values=True) + for key in fs.keys(): + field = fs[key] + if isinstance(field, list): + field = field[0] + if getattr(field, "filename", None): + data = field.file.read() + self._files[key] = BytesIO(data) + else: + self._post_data[key] = field.value + elif "application/x-www-form-urlencoded" in content_type: + from urllib.parse import parse_qs + + data = parse_qs(self.request.raw_body.decode()) + self._post_data = {k: v[0] for k, v in data.items()} + else: + self._post_data = {} + self._files = {} @property def content_type(self) -> Optional[str]: @@ -65,9 +106,11 @@ def __init__( graphiql: Optional[bool] = None, graphql_ide: Optional[GraphQL_IDE] = "graphiql", allow_queries_via_get: bool = True, + multipart_uploads_enabled: bool = False, ) -> None: self.allow_queries_via_get = allow_queries_via_get self.schema = schema + self.multipart_uploads_enabled = multipart_uploads_enabled if graphiql is not None: warnings.warn( "The `graphiql` argument is deprecated in favor of `graphql_ide`", @@ -95,7 +138,6 @@ def get_sub_response(self, request: Request) -> TemporalResponse: @staticmethod def error_response( message: str, - error_code: str, http_status_code: int, headers: Optional[dict[str, str | list[str]]] = None, ) -> Response: @@ -110,9 +152,7 @@ def error_response( Returns: An errors response. """ - body = {"Code": error_code, "Message": message} - - return Response(body=body, status_code=http_status_code, headers=headers) + return Response(body=message, status_code=http_status_code, headers=headers) def get_context(self, request: Request, response: TemporalResponse) -> Context: return {"request": request, "response": response} # type: ignore @@ -138,18 +178,7 @@ def execute_request(self, request: Request) -> Response: try: return self.run(request=request) except HTTPException as e: - error_code_map = { - 400: "BadRequestError", - 401: "UnauthorizedError", - 403: "ForbiddenError", - 404: "NotFoundError", - 409: "ConflictError", - 429: "TooManyRequestsError", - 500: "ChaliceViewError", - } - return self.error_response( - error_code=error_code_map.get(e.status_code, "ChaliceViewError"), message=e.reason, http_status_code=e.status_code, ) diff --git a/src/graphql_server/channels/testing.py b/src/graphql_server/channels/testing.py index 5d0b53b..b817f80 100644 --- a/src/graphql_server/channels/testing.py +++ b/src/graphql_server/channels/testing.py @@ -71,7 +71,9 @@ def __init__( subprotocols: an ordered list of preferred subprotocols to be sent to the server. **kwargs: additional arguments to be passed to the `WebsocketCommunicator` constructor. """ - if connection_params is None: # pragma: no cover - tested via custom initialisation + if ( + connection_params is None + ): # pragma: no cover - tested via custom initialisation connection_params = {} self.protocol = protocol subprotocols = kwargs.get("subprotocols", []) @@ -139,7 +141,9 @@ async def subscribe( }, } - if variables is not None: # pragma: no cover - exercised in higher-level tests + if ( + variables is not None + ): # pragma: no cover - exercised in higher-level tests start_message["payload"]["variables"] = variables await self.send_json_to(start_message) @@ -155,7 +159,9 @@ async def subscribe( ret.errors = self.process_errors(payload.get("errors") or []) ret.extensions = payload.get("extensions", None) yield ret - elif message["type"] == "error": # pragma: no cover - network failures untested + elif ( + message["type"] == "error" + ): # pragma: no cover - network failures untested error_payload = message["payload"] yield ExecutionResult( data=None, errors=self.process_errors(error_payload) diff --git a/src/tests/http/clients/chalice.py b/src/tests/http/clients/chalice.py index 2b1904b..a3dcadd 100644 --- a/src/tests/http/clients/chalice.py +++ b/src/tests/http/clients/chalice.py @@ -3,10 +3,11 @@ import urllib.parse from io import BytesIO from json import dumps -from typing import Any, Optional, Union +from typing import Any, Optional from typing_extensions import Literal from graphql import ExecutionResult +from urllib3 import encode_multipart_formdata from chalice.app import Chalice from chalice.app import Request as ChaliceRequest @@ -60,11 +61,14 @@ def __init__( graphiql=graphiql, graphql_ide=graphql_ide, allow_queries_via_get=allow_queries_via_get, + multipart_uploads_enabled=multipart_uploads_enabled, ) view.result_override = result_override @self.app.route( - "/graphql", methods=["GET", "POST"], content_types=["application/json"] + "/graphql", + methods=["GET", "POST"], + content_types=["application/json", "multipart/form-data"], ) def handle_graphql(): assert self.app.current_request is not None @@ -90,27 +94,27 @@ async def _graphql_request( extensions=extensions, ) - data: Union[dict[str, object], str, None] = None - - if body and files: - body.update({name: (file, name) for name, file in files.items()}) - url = "/graphql" + headers = self._get_headers(method=method, headers=headers, files=files) if method == "get": body_encoded = urllib.parse.urlencode(body or {}) url = f"{url}?{body_encoded}" - else: - if body: - data = body if files else dumps(body) - kwargs["body"] = data + elif body: + if files: + fields = {"operations": body["operations"], "map": body["map"]} + for filename, file in files.items(): + fields[filename] = (filename, file.read(), "text/plain") + data, content_type = encode_multipart_formdata(fields) + headers.update( + {"Content-Type": content_type, "Content-Length": f"{len(data)}"} + ) + kwargs["body"] = data + else: + kwargs["body"] = dumps(body) with Client(self.app) as client: - response = getattr(client.http, method)( - url, - headers=self._get_headers(method=method, headers=headers, files=files), - **kwargs, - ) + response = getattr(client.http, method)(url, headers=headers, **kwargs) return Response( status_code=response.status_code, diff --git a/src/tests/http/test_query.py b/src/tests/http/test_query.py index 949a566..ff7299d 100644 --- a/src/tests/http/test_query.py +++ b/src/tests/http/test_query.py @@ -289,15 +289,7 @@ async def test_invalid_operation_selection(http_client: HttpClient, operation_na ) assert response.status_code == 400 - - if isinstance(http_client, ChaliceHttpClient): - # Our Chalice integration purposely wraps errors messages with a JSON object - assert response.json == { - "Code": "BadRequestError", - "Message": f'Unknown operation named "{operation_name}".', - } - else: - assert response.data == f'Unknown operation named "{operation_name}".'.encode() + assert response.data == f'Unknown operation named "{operation_name}".'.encode() async def test_operation_selection_without_operations(http_client: HttpClient): @@ -308,12 +300,4 @@ async def test_operation_selection_without_operations(http_client: HttpClient): ) assert response.status_code == 400 - - if isinstance(http_client, ChaliceHttpClient): - # Our Chalice integration purposely wraps errors messages with a JSON object - assert response.json == { - "Code": "BadRequestError", - "Message": "Can't get GraphQL operation type", - } - else: - assert response.data == b"Can't get GraphQL operation type" + assert response.data == b"Can't get GraphQL operation type" diff --git a/src/tests/http/test_upload.py b/src/tests/http/test_upload.py index 1bd2fd1..a3fe1c8 100644 --- a/src/tests/http/test_upload.py +++ b/src/tests/http/test_upload.py @@ -1,4 +1,3 @@ -import contextlib import json from io import BytesIO @@ -10,23 +9,11 @@ @pytest.fixture def http_client(http_client_class: type[HttpClient]) -> HttpClient: - with contextlib.suppress(ImportError): - from .clients.chalice import ChaliceHttpClient - - if http_client_class is ChaliceHttpClient: - pytest.xfail(reason="Chalice does not support uploads") - return http_client_class() @pytest.fixture def enabled_http_client(http_client_class: type[HttpClient]) -> HttpClient: - with contextlib.suppress(ImportError): - from .clients.chalice import ChaliceHttpClient - - if http_client_class is ChaliceHttpClient: - pytest.xfail(reason="Chalice does not support uploads") - return http_client_class(multipart_uploads_enabled=True) diff --git a/src/tests/test/test_client_utils.py b/src/tests/test/test_client_utils.py index 9fba4a3..115ecf2 100644 --- a/src/tests/test/test_client_utils.py +++ b/src/tests/test/test_client_utils.py @@ -8,7 +8,9 @@ class DummyClient(BaseGraphQLTestClient): def request(self, body, headers=None, files=None): - return types.SimpleNamespace(content=json.dumps(body).encode(), json=lambda: body) + return types.SimpleNamespace( + content=json.dumps(body).encode(), json=lambda: body + ) def test_build_body_with_variables_and_files(): diff --git a/src/tests/test/test_runtime.py b/src/tests/test/test_runtime.py index c48de19..2685253 100644 --- a/src/tests/test/test_runtime.py +++ b/src/tests/test/test_runtime.py @@ -1,57 +1,58 @@ import pytest from graphql import ( ExecutionResult, + GraphQLError, GraphQLField, GraphQLObjectType, GraphQLSchema, GraphQLString, - GraphQLError, parse, ) -import graphql_server.runtime as runtime - +from graphql_server import runtime schema = GraphQLSchema( - query=GraphQLObjectType('Query', {'hello': GraphQLField(GraphQLString)}) + query=GraphQLObjectType("Query", {"hello": GraphQLField(GraphQLString)}) ) def test_validate_document_with_rules(): from graphql.validation.rules.no_unused_fragments import NoUnusedFragmentsRule - doc = parse('query Test { hello }') + doc = parse("query Test { hello }") assert runtime.validate_document(schema, doc, (NoUnusedFragmentsRule,)) == [] def test_get_custom_context_kwargs(monkeypatch): - assert runtime._get_custom_context_kwargs({'a': 1}) == {'operation_extensions': {'a': 1}} - monkeypatch.setattr(runtime, 'IS_GQL_33', False) + assert runtime._get_custom_context_kwargs({"a": 1}) == { + "operation_extensions": {"a": 1} + } + monkeypatch.setattr(runtime, "IS_GQL_33", False) try: - assert runtime._get_custom_context_kwargs({'a': 1}) == {} + assert runtime._get_custom_context_kwargs({"a": 1}) == {} finally: - monkeypatch.setattr(runtime, 'IS_GQL_33', True) + monkeypatch.setattr(runtime, "IS_GQL_33", True) def test_get_operation_type_multiple_operations(): - doc = parse('query A{hello} query B{hello}') + doc = parse("query A{hello} query B{hello}") with pytest.raises(Exception): runtime._get_operation_type(doc) def test_parse_and_validate_document_node(): - doc = parse('query Q { hello }') + doc = parse("query Q { hello }") res = runtime._parse_and_validate(schema, doc, None) assert res == doc def test_introspect_success_and_failure(monkeypatch): data = runtime.introspect(schema) - assert '__schema' in data + assert "__schema" in data def fake_execute_sync(schema, query): - return ExecutionResult(data=None, errors=[GraphQLError('boom')]) + return ExecutionResult(data=None, errors=[GraphQLError("boom")]) - monkeypatch.setattr(runtime, 'execute_sync', fake_execute_sync) + monkeypatch.setattr(runtime, "execute_sync", fake_execute_sync) with pytest.raises(ValueError): runtime.introspect(schema) diff --git a/src/tests/types/test_unset.py b/src/tests/types/test_unset.py index bc18d04..b24e4b1 100644 --- a/src/tests/types/test_unset.py +++ b/src/tests/types/test_unset.py @@ -14,4 +14,4 @@ def test_deprecated_is_unset_and_getattr(): with pytest.warns(DeprecationWarning): assert unset.is_unset(unset.UNSET) with pytest.raises(AttributeError): - getattr(unset, "missing") + unset.missing diff --git a/src/tests/utils/test_debug.py b/src/tests/utils/test_debug.py index c460f10..b0feb08 100644 --- a/src/tests/utils/test_debug.py +++ b/src/tests/utils/test_debug.py @@ -2,7 +2,10 @@ import pytest -from graphql_server.utils.debug import GraphQLJSONEncoder, pretty_print_graphql_operation +from graphql_server.utils.debug import ( + GraphQLJSONEncoder, + pretty_print_graphql_operation, +) def test_graphql_json_encoder_default(): diff --git a/src/tests/websockets/test_graphql_transport_ws.py b/src/tests/websockets/test_graphql_transport_ws.py index 4bf9816..e7e99dc 100644 --- a/src/tests/websockets/test_graphql_transport_ws.py +++ b/src/tests/websockets/test_graphql_transport_ws.py @@ -423,7 +423,7 @@ async def test_duplicated_operation_ids(ws: WebSocketClient): async def test_reused_operation_ids(ws: WebSocketClient): """Test that an operation id can be reused after it has been - previously used for a completed operation. + previously used for a completed operation. """ # Use sub1 as an id for an operation await ws.send_message(