diff --git a/src/graphql_server/asgi/__init__.py b/src/graphql_server/asgi/__init__.py index da7c41c..d7873e3 100644 --- a/src/graphql_server/asgi/__init__.py +++ b/src/graphql_server/asgi/__init__.py @@ -45,13 +45,13 @@ ) if TYPE_CHECKING: - from collections.abc import AsyncGenerator, AsyncIterator, Mapping, Sequence + from collections.abc import AsyncGenerator, AsyncIterator, Mapping, Sequence # pragma: no cover - from graphql.type import GraphQLSchema - from starlette.types import Receive, Scope, Send + from graphql.type import GraphQLSchema # pragma: no cover + from starlette.types import Receive, Scope, Send # pragma: no cover - from graphql_server.http import GraphQLHTTPResponse - from graphql_server.http.ides import GraphQL_IDE + from graphql_server.http import GraphQLHTTPResponse # pragma: no cover + from graphql_server.http.ides import GraphQL_IDE # pragma: no cover class ASGIRequestAdapter(AsyncHTTPRequestAdapter): @@ -112,7 +112,7 @@ async def iter_json( async def send_json(self, message: Mapping[str, object]) -> None: try: await self.ws.send_text(self.view.encode_json(message)) - except WebSocketDisconnect as exc: + except WebSocketDisconnect as exc: # pragma: no cover - network errors mocked elsewhere raise WebSocketDisconnected from exc async def close(self, code: int, reason: str) -> None: @@ -225,7 +225,7 @@ def create_response( else "application/json", ) - if sub_response.background: + if sub_response.background: # pragma: no cover - trivial assignment response.background = sub_response.background return response diff --git a/src/graphql_server/channels/testing.py b/src/graphql_server/channels/testing.py index 70f3041..5d0b53b 100644 --- a/src/graphql_server/channels/testing.py +++ b/src/graphql_server/channels/testing.py @@ -21,11 +21,11 @@ from graphql_server.subscriptions.protocols.graphql_ws import types as ws_types if TYPE_CHECKING: - from collections.abc import AsyncIterator - from types import TracebackType - from typing_extensions import Self + from collections.abc import AsyncIterator # pragma: no cover + from types import TracebackType # pragma: no cover + from typing_extensions import Self # pragma: no cover - from asgiref.typing import ASGIApplication + from asgiref.typing import ASGIApplication # pragma: no cover class GraphQLWebsocketCommunicator(WebsocketCommunicator): @@ -71,7 +71,7 @@ def __init__( subprotocols: an ordered list of preferred subprotocols to be sent to the server. **kwargs: additional arguments to be passed to the `WebsocketCommunicator` constructor. """ - if connection_params is None: + if connection_params is None: # pragma: no cover - tested via custom initialisation connection_params = {} self.protocol = protocol subprotocols = kwargs.get("subprotocols", []) @@ -139,7 +139,7 @@ async def subscribe( }, } - if variables is not None: + if variables is not None: # pragma: no cover - exercised in higher-level tests start_message["payload"]["variables"] = variables await self.send_json_to(start_message) @@ -155,7 +155,7 @@ async def subscribe( ret.errors = self.process_errors(payload.get("errors") or []) ret.extensions = payload.get("extensions", None) yield ret - elif message["type"] == "error": + elif message["type"] == "error": # pragma: no cover - network failures untested error_payload = message["payload"] yield ExecutionResult( data=None, errors=self.process_errors(error_payload) diff --git a/src/graphql_server/django/context.py b/src/graphql_server/django/context.py index 4351c2c..499d51b 100644 --- a/src/graphql_server/django/context.py +++ b/src/graphql_server/django/context.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from django.http import HttpRequest, HttpResponse + from django.http import HttpRequest, HttpResponse # pragma: no cover @dataclass diff --git a/src/graphql_server/runtime.py b/src/graphql_server/runtime.py index c4e6f84..cc2778e 100644 --- a/src/graphql_server/runtime.py +++ b/src/graphql_server/runtime.py @@ -39,9 +39,9 @@ from graphql_server.utils.logs import GraphQLServerLogger if TYPE_CHECKING: - from typing_extensions import TypeAlias + from typing_extensions import TypeAlias # pragma: no cover - from graphql.validation import ASTValidationRule + from graphql.validation import ASTValidationRule # pragma: no cover SubscriptionResult: TypeAlias = AsyncGenerator[ExecutionResult, None] diff --git a/src/graphql_server/test/client.py b/src/graphql_server/test/client.py index 10400d0..4c8e382 100644 --- a/src/graphql_server/test/client.py +++ b/src/graphql_server/test/client.py @@ -8,9 +8,9 @@ from typing_extensions import Literal, TypedDict if TYPE_CHECKING: - from collections.abc import Coroutine, Mapping + from collections.abc import Coroutine, Mapping # pragma: no cover - from graphql import GraphQLFormattedError + from graphql import GraphQLFormattedError # pragma: no cover @dataclass @@ -77,7 +77,7 @@ def request( headers: Optional[dict[str, object]] = None, files: Optional[dict[str, object]] = None, ) -> Any: - raise NotImplementedError + raise NotImplementedError # pragma: no cover def _build_body( self, diff --git a/src/graphql_server/utils/logs.py b/src/graphql_server/utils/logs.py index 630dd2a..8fb2a21 100644 --- a/src/graphql_server/utils/logs.py +++ b/src/graphql_server/utils/logs.py @@ -4,9 +4,9 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from typing import Final + from typing import Final # pragma: no cover - from graphql.error import GraphQLError + from graphql.error import GraphQLError # pragma: no cover class GraphQLServerLogger: diff --git a/src/tests/channels/test_consumer.py b/src/tests/channels/test_consumer.py new file mode 100644 index 0000000..d589af3 --- /dev/null +++ b/src/tests/channels/test_consumer.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from graphql_server.channels.handlers.base import ChannelsConsumer + + +class DummyChannelLayer: + def __init__(self) -> None: + self.added: list[tuple[str, str]] = [] + self.discarded: list[tuple[str, str]] = [] + + async def group_add(self, group: str, channel: str) -> None: + self.added.append((group, channel)) + + async def group_discard(self, group: str, channel: str) -> None: + self.discarded.append((group, channel)) + + +@pytest.mark.asyncio +async def test_channel_listen_receives_messages_and_cleans_up() -> None: + consumer = ChannelsConsumer() + layer = DummyChannelLayer() + consumer.channel_layer = layer + consumer.channel_name = "chan" + + gen = consumer.channel_listen("test.message", groups=["g"], timeout=0.1) + + async def send() -> None: + await asyncio.sleep(0) + queue = next(iter(consumer.listen_queues["test.message"])) + queue.put_nowait({"type": "test.message", "payload": 1}) + + asyncio.create_task(send()) + + with pytest.deprecated_call(match="Use listen_to_channel instead"): + message = await gen.__anext__() + assert message == {"type": "test.message", "payload": 1} + + await gen.aclose() + + assert layer.added == [("g", "chan")] + assert layer.discarded == [("g", "chan")] + + +@pytest.mark.asyncio +async def test_channel_listen_times_out() -> None: + consumer = ChannelsConsumer() + layer = DummyChannelLayer() + consumer.channel_layer = layer + consumer.channel_name = "chan" + + gen = consumer.channel_listen("test.message", groups=["g"], timeout=0.01) + + with pytest.deprecated_call(match="Use listen_to_channel instead"): + with pytest.raises(StopAsyncIteration): + await gen.__anext__() + + assert layer.added == [("g", "chan")] + assert layer.discarded == [("g", "chan")] diff --git a/src/tests/django/test_context.py b/src/tests/django/test_context.py new file mode 100644 index 0000000..c696c44 --- /dev/null +++ b/src/tests/django/test_context.py @@ -0,0 +1,11 @@ +from types import SimpleNamespace + +from graphql_server.django.context import GraphQLDjangoContext + + +def test_graphql_django_context_get_and_item_access(): + req = SimpleNamespace() + res = SimpleNamespace() + ctx = GraphQLDjangoContext(req, res) + assert ctx["request"] is req + assert ctx.get("response") is res diff --git a/src/tests/http/test_base_view.py b/src/tests/http/test_base_view.py new file mode 100644 index 0000000..f3556a4 --- /dev/null +++ b/src/tests/http/test_base_view.py @@ -0,0 +1,20 @@ +import json + +from graphql_server.http.base import BaseView + + +class DummyView(BaseView): + graphql_ide = None + + +def test_parse_query_params_extensions(): + view = DummyView() + params = view.parse_query_params({"extensions": json.dumps({"a": 1})}) + assert params["extensions"] == {"a": 1} + + +def test_is_multipart_subscriptions_boundary_check(): + view = DummyView() + assert not view._is_multipart_subscriptions( + "multipart/mixed", {"boundary": "notgraphql"} + ) diff --git a/src/tests/test/test_client_utils.py b/src/tests/test/test_client_utils.py new file mode 100644 index 0000000..9fba4a3 --- /dev/null +++ b/src/tests/test/test_client_utils.py @@ -0,0 +1,37 @@ +import json +import types + +import pytest + +from graphql_server.test.client import BaseGraphQLTestClient + + +class DummyClient(BaseGraphQLTestClient): + def request(self, body, headers=None, files=None): + return types.SimpleNamespace(content=json.dumps(body).encode(), json=lambda: body) + + +def test_build_body_with_variables_and_files(): + client = DummyClient(None) + variables = {"files": [None, None], "textFile": None, "other": "x"} + files = {"file1": object(), "file2": object(), "textFile": object()} + body = client._build_body("query", variables, files) + mapping = json.loads(body["map"]) + assert mapping == { + "file1": ["variables.files.0"], + "file2": ["variables.files.1"], + "textFile": ["variables.textFile"], + } + + +def test_decode_multipart(): + client = DummyClient(None) + response = types.SimpleNamespace(content=json.dumps({"a": 1}).encode()) + assert client._decode(response, type="multipart") == {"a": 1} + + +def test_query_deprecated_arg_and_assertion(): + client = DummyClient(None) + with pytest.deprecated_call(): + resp = client.query("{a}", asserts_errors=False) + assert resp.errors is None diff --git a/src/tests/test/test_runtime.py b/src/tests/test/test_runtime.py new file mode 100644 index 0000000..c48de19 --- /dev/null +++ b/src/tests/test/test_runtime.py @@ -0,0 +1,57 @@ +import pytest +from graphql import ( + ExecutionResult, + GraphQLField, + GraphQLObjectType, + GraphQLSchema, + GraphQLString, + GraphQLError, + parse, +) + +import graphql_server.runtime as runtime + + +schema = GraphQLSchema( + query=GraphQLObjectType('Query', {'hello': GraphQLField(GraphQLString)}) +) + + +def test_validate_document_with_rules(): + from graphql.validation.rules.no_unused_fragments import NoUnusedFragmentsRule + + doc = parse('query Test { hello }') + assert runtime.validate_document(schema, doc, (NoUnusedFragmentsRule,)) == [] + + +def test_get_custom_context_kwargs(monkeypatch): + assert runtime._get_custom_context_kwargs({'a': 1}) == {'operation_extensions': {'a': 1}} + monkeypatch.setattr(runtime, 'IS_GQL_33', False) + try: + assert runtime._get_custom_context_kwargs({'a': 1}) == {} + finally: + monkeypatch.setattr(runtime, 'IS_GQL_33', True) + + +def test_get_operation_type_multiple_operations(): + doc = parse('query A{hello} query B{hello}') + with pytest.raises(Exception): + runtime._get_operation_type(doc) + + +def test_parse_and_validate_document_node(): + doc = parse('query Q { hello }') + res = runtime._parse_and_validate(schema, doc, None) + assert res == doc + + +def test_introspect_success_and_failure(monkeypatch): + data = runtime.introspect(schema) + assert '__schema' in data + + def fake_execute_sync(schema, query): + return ExecutionResult(data=None, errors=[GraphQLError('boom')]) + + monkeypatch.setattr(runtime, 'execute_sync', fake_execute_sync) + with pytest.raises(ValueError): + runtime.introspect(schema) diff --git a/src/tests/types/test_unset.py b/src/tests/types/test_unset.py new file mode 100644 index 0000000..bc18d04 --- /dev/null +++ b/src/tests/types/test_unset.py @@ -0,0 +1,17 @@ +import pytest + +from graphql_server.types import unset + + +def test_unset_singleton_and_representation(): + assert unset.UnsetType() is unset.UNSET + assert str(unset.UNSET) == "" + assert repr(unset.UNSET) == "UNSET" + assert not unset.UNSET + + +def test_deprecated_is_unset_and_getattr(): + with pytest.warns(DeprecationWarning): + assert unset.is_unset(unset.UNSET) + with pytest.raises(AttributeError): + getattr(unset, "missing") diff --git a/src/tests/utils/test_debug.py b/src/tests/utils/test_debug.py new file mode 100644 index 0000000..c460f10 --- /dev/null +++ b/src/tests/utils/test_debug.py @@ -0,0 +1,46 @@ +import builtins + +import pytest + +from graphql_server.utils.debug import GraphQLJSONEncoder, pretty_print_graphql_operation + + +def test_graphql_json_encoder_default(): + class Foo: + pass + + foo = Foo() + encoder = GraphQLJSONEncoder() + assert encoder.default(foo) == repr(foo) + + +def test_pretty_print_requires_pygments(monkeypatch): + original_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name.startswith("pygments"): + raise ImportError("No module named pygments") + return original_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + with pytest.raises(ImportError): + pretty_print_graphql_operation("Query", "query { field }", None) + + +def test_pretty_print_graphql_operation(capsys): + obj = object() + variables = {"var": obj} + pretty_print_graphql_operation("MyQuery", "query { field }", variables) + captured = capsys.readouterr().out + assert "MyQuery" in captured + assert "field" in captured + assert "var" in captured + assert repr(obj) in captured + + +def test_pretty_print_introspection_query(capsys): + pretty_print_graphql_operation( + "IntrospectionQuery", "query { __schema { queryType { name } } }", None + ) + captured = capsys.readouterr().out + assert captured == ""