Skip to content

Enable Chalice multipart uploads #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,18 @@ Run the full test suite with:
uv run pytest
```

You can check the tests coverage by adding the `--cov` flag from `pytest-cov` when running `pytest`.
You can check the tests coverage by adding the `--cov` and `--cov-report=term-missing` (to report the missing lines in the command output) flags from `pytest-cov` when running `pytest`.

```bash
uv run pytest --cov
uv run pytest --cov --cov-report=term-missing
```

## Before commiting

All files must be properly formatted before creating a PR, so we can merge it upstream.

You can run `ruff` for formatting automatically all the files of the project:

```bash
uv run ruff format
```
11 changes: 9 additions & 2 deletions src/graphql_server/asgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@
)

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, AsyncIterator, Mapping, Sequence # pragma: no cover
from collections.abc import ( # pragma: no cover
AsyncGenerator,
AsyncIterator,
Mapping,
Sequence,
)

from graphql.type import GraphQLSchema # pragma: no cover
from starlette.types import Receive, Scope, Send # pragma: no cover
Expand Down Expand Up @@ -112,7 +117,9 @@ async def iter_json(
async def send_json(self, message: Mapping[str, object]) -> None:
try:
await self.ws.send_text(self.view.encode_json(message))
except WebSocketDisconnect as exc: # pragma: no cover - network errors mocked elsewhere
except (
WebSocketDisconnect
) as exc: # pragma: no cover - network errors mocked elsewhere
raise WebSocketDisconnected from exc

async def close(self, code: int, reason: str) -> None:
Expand Down
63 changes: 46 additions & 17 deletions src/graphql_server/chalice/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import warnings
from io import BytesIO
from typing import TYPE_CHECKING, Any, Optional, Union, cast

from chalice.app import Request, Response
Expand All @@ -23,6 +24,8 @@
class ChaliceHTTPRequestAdapter(SyncHTTPRequestAdapter):
def __init__(self, request: Request) -> None:
self.request = request
self._post_data: Optional[dict[str, Union[str, bytes]]] = None
self._files: Optional[dict[str, Any]] = None

@property
def query_params(self) -> QueryParams:
Expand All @@ -42,11 +45,49 @@ def headers(self) -> Mapping[str, str]:

@property
def post_data(self) -> Mapping[str, Union[str, bytes]]:
raise NotImplementedError
if self._post_data is None:
self._parse_body()
return self._post_data or {}

@property
def files(self) -> Mapping[str, Any]:
raise NotImplementedError
if self._files is None:
self._parse_body()
return self._files or {}

def _parse_body(self) -> None:
self._post_data = {}
self._files = {}

content_type = self.content_type or ""

if "multipart/form-data" in content_type:
import cgi

fp = BytesIO(self.request.raw_body)
environ = {
"REQUEST_METHOD": "POST",
"CONTENT_TYPE": content_type,
"CONTENT_LENGTH": str(len(self.request.raw_body)),
}
fs = cgi.FieldStorage(fp=fp, environ=environ, keep_blank_values=True)
for key in fs.keys():
field = fs[key]
if isinstance(field, list):
field = field[0]
if getattr(field, "filename", None):
data = field.file.read()
self._files[key] = BytesIO(data)
else:
self._post_data[key] = field.value
elif "application/x-www-form-urlencoded" in content_type:
from urllib.parse import parse_qs

data = parse_qs(self.request.raw_body.decode())
self._post_data = {k: v[0] for k, v in data.items()}
else:
self._post_data = {}
self._files = {}

@property
def content_type(self) -> Optional[str]:
Expand All @@ -65,9 +106,11 @@ def __init__(
graphiql: Optional[bool] = None,
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
allow_queries_via_get: bool = True,
multipart_uploads_enabled: bool = False,
) -> None:
self.allow_queries_via_get = allow_queries_via_get
self.schema = schema
self.multipart_uploads_enabled = multipart_uploads_enabled
if graphiql is not None:
warnings.warn(
"The `graphiql` argument is deprecated in favor of `graphql_ide`",
Expand Down Expand Up @@ -95,7 +138,6 @@ def get_sub_response(self, request: Request) -> TemporalResponse:
@staticmethod
def error_response(
message: str,
error_code: str,
http_status_code: int,
headers: Optional[dict[str, str | list[str]]] = None,
) -> Response:
Expand All @@ -110,9 +152,7 @@ def error_response(
Returns:
An errors response.
"""
body = {"Code": error_code, "Message": message}

return Response(body=body, status_code=http_status_code, headers=headers)
return Response(body=message, status_code=http_status_code, headers=headers)

def get_context(self, request: Request, response: TemporalResponse) -> Context:
return {"request": request, "response": response} # type: ignore
Expand All @@ -138,18 +178,7 @@ def execute_request(self, request: Request) -> Response:
try:
return self.run(request=request)
except HTTPException as e:
error_code_map = {
400: "BadRequestError",
401: "UnauthorizedError",
403: "ForbiddenError",
404: "NotFoundError",
409: "ConflictError",
429: "TooManyRequestsError",
500: "ChaliceViewError",
}

return self.error_response(
error_code=error_code_map.get(e.status_code, "ChaliceViewError"),
message=e.reason,
http_status_code=e.status_code,
)
Expand Down
12 changes: 9 additions & 3 deletions src/graphql_server/channels/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def __init__(
subprotocols: an ordered list of preferred subprotocols to be sent to the server.
**kwargs: additional arguments to be passed to the `WebsocketCommunicator` constructor.
"""
if connection_params is None: # pragma: no cover - tested via custom initialisation
if (
connection_params is None
): # pragma: no cover - tested via custom initialisation
connection_params = {}
self.protocol = protocol
subprotocols = kwargs.get("subprotocols", [])
Expand Down Expand Up @@ -139,7 +141,9 @@ async def subscribe(
},
}

if variables is not None: # pragma: no cover - exercised in higher-level tests
if (
variables is not None
): # pragma: no cover - exercised in higher-level tests
start_message["payload"]["variables"] = variables

await self.send_json_to(start_message)
Expand All @@ -155,7 +159,9 @@ async def subscribe(
ret.errors = self.process_errors(payload.get("errors") or [])
ret.extensions = payload.get("extensions", None)
yield ret
elif message["type"] == "error": # pragma: no cover - network failures untested
elif (
message["type"] == "error"
): # pragma: no cover - network failures untested
error_payload = message["payload"]
yield ExecutionResult(
data=None, errors=self.process_errors(error_payload)
Expand Down
36 changes: 20 additions & 16 deletions src/tests/http/clients/chalice.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import urllib.parse
from io import BytesIO
from json import dumps
from typing import Any, Optional, Union
from typing import Any, Optional
from typing_extensions import Literal

from graphql import ExecutionResult
from urllib3 import encode_multipart_formdata

from chalice.app import Chalice
from chalice.app import Request as ChaliceRequest
Expand Down Expand Up @@ -60,11 +61,14 @@ def __init__(
graphiql=graphiql,
graphql_ide=graphql_ide,
allow_queries_via_get=allow_queries_via_get,
multipart_uploads_enabled=multipart_uploads_enabled,
)
view.result_override = result_override

@self.app.route(
"/graphql", methods=["GET", "POST"], content_types=["application/json"]
"/graphql",
methods=["GET", "POST"],
content_types=["application/json", "multipart/form-data"],
)
def handle_graphql():
assert self.app.current_request is not None
Expand All @@ -90,27 +94,27 @@ async def _graphql_request(
extensions=extensions,
)

data: Union[dict[str, object], str, None] = None

if body and files:
body.update({name: (file, name) for name, file in files.items()})

url = "/graphql"
headers = self._get_headers(method=method, headers=headers, files=files)

if method == "get":
body_encoded = urllib.parse.urlencode(body or {})
url = f"{url}?{body_encoded}"
else:
if body:
data = body if files else dumps(body)
kwargs["body"] = data
elif body:
if files:
fields = {"operations": body["operations"], "map": body["map"]}
for filename, file in files.items():
fields[filename] = (filename, file.read(), "text/plain")
data, content_type = encode_multipart_formdata(fields)
headers.update(
{"Content-Type": content_type, "Content-Length": f"{len(data)}"}
)
kwargs["body"] = data
else:
kwargs["body"] = dumps(body)

with Client(self.app) as client:
response = getattr(client.http, method)(
url,
headers=self._get_headers(method=method, headers=headers, files=files),
**kwargs,
)
response = getattr(client.http, method)(url, headers=headers, **kwargs)

return Response(
status_code=response.status_code,
Expand Down
20 changes: 2 additions & 18 deletions src/tests/http/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,15 +289,7 @@ async def test_invalid_operation_selection(http_client: HttpClient, operation_na
)

assert response.status_code == 400

if isinstance(http_client, ChaliceHttpClient):
# Our Chalice integration purposely wraps errors messages with a JSON object
assert response.json == {
"Code": "BadRequestError",
"Message": f'Unknown operation named "{operation_name}".',
}
else:
assert response.data == f'Unknown operation named "{operation_name}".'.encode()
assert response.data == f'Unknown operation named "{operation_name}".'.encode()


async def test_operation_selection_without_operations(http_client: HttpClient):
Expand All @@ -308,12 +300,4 @@ async def test_operation_selection_without_operations(http_client: HttpClient):
)

assert response.status_code == 400

if isinstance(http_client, ChaliceHttpClient):
# Our Chalice integration purposely wraps errors messages with a JSON object
assert response.json == {
"Code": "BadRequestError",
"Message": "Can't get GraphQL operation type",
}
else:
assert response.data == b"Can't get GraphQL operation type"
assert response.data == b"Can't get GraphQL operation type"
13 changes: 0 additions & 13 deletions src/tests/http/test_upload.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import contextlib
import json
from io import BytesIO

Expand All @@ -10,23 +9,11 @@

@pytest.fixture
def http_client(http_client_class: type[HttpClient]) -> HttpClient:
with contextlib.suppress(ImportError):
from .clients.chalice import ChaliceHttpClient

if http_client_class is ChaliceHttpClient:
pytest.xfail(reason="Chalice does not support uploads")

return http_client_class()


@pytest.fixture
def enabled_http_client(http_client_class: type[HttpClient]) -> HttpClient:
with contextlib.suppress(ImportError):
from .clients.chalice import ChaliceHttpClient

if http_client_class is ChaliceHttpClient:
pytest.xfail(reason="Chalice does not support uploads")

return http_client_class(multipart_uploads_enabled=True)


Expand Down
4 changes: 3 additions & 1 deletion src/tests/test/test_client_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

class DummyClient(BaseGraphQLTestClient):
def request(self, body, headers=None, files=None):
return types.SimpleNamespace(content=json.dumps(body).encode(), json=lambda: body)
return types.SimpleNamespace(
content=json.dumps(body).encode(), json=lambda: body
)


def test_build_body_with_variables_and_files():
Expand Down
Loading