Skip to content

Commit 7716534

Browse files
authored
Enable Chalice multipart uploads (#134)
* Enable Chalice multipart uploads * Improved errors on chalice * Applied formatting * Improve AGENTS
1 parent 308e2d7 commit 7716534

File tree

12 files changed

+122
-89
lines changed

12 files changed

+122
-89
lines changed

AGENTS.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,18 @@ Run the full test suite with:
1717
uv run pytest
1818
```
1919

20-
You can check the tests coverage by adding the `--cov` flag from `pytest-cov` when running `pytest`.
20+
You can check the tests coverage by adding the `--cov` and `--cov-report=term-missing` (to report the missing lines in the command output) flags from `pytest-cov` when running `pytest`.
2121

2222
```bash
23-
uv run pytest --cov
23+
uv run pytest --cov --cov-report=term-missing
24+
```
25+
26+
## Before commiting
27+
28+
All files must be properly formatted before creating a PR, so we can merge it upstream.
29+
30+
You can run `ruff` for formatting automatically all the files of the project:
31+
32+
```bash
33+
uv run ruff format
2434
```

src/graphql_server/asgi/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,12 @@
4545
)
4646

4747
if TYPE_CHECKING:
48-
from collections.abc import AsyncGenerator, AsyncIterator, Mapping, Sequence # pragma: no cover
48+
from collections.abc import ( # pragma: no cover
49+
AsyncGenerator,
50+
AsyncIterator,
51+
Mapping,
52+
Sequence,
53+
)
4954

5055
from graphql.type import GraphQLSchema # pragma: no cover
5156
from starlette.types import Receive, Scope, Send # pragma: no cover
@@ -112,7 +117,9 @@ async def iter_json(
112117
async def send_json(self, message: Mapping[str, object]) -> None:
113118
try:
114119
await self.ws.send_text(self.view.encode_json(message))
115-
except WebSocketDisconnect as exc: # pragma: no cover - network errors mocked elsewhere
120+
except (
121+
WebSocketDisconnect
122+
) as exc: # pragma: no cover - network errors mocked elsewhere
116123
raise WebSocketDisconnected from exc
117124

118125
async def close(self, code: int, reason: str) -> None:

src/graphql_server/chalice/views.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4+
from io import BytesIO
45
from typing import TYPE_CHECKING, Any, Optional, Union, cast
56

67
from chalice.app import Request, Response
@@ -23,6 +24,8 @@
2324
class ChaliceHTTPRequestAdapter(SyncHTTPRequestAdapter):
2425
def __init__(self, request: Request) -> None:
2526
self.request = request
27+
self._post_data: Optional[dict[str, Union[str, bytes]]] = None
28+
self._files: Optional[dict[str, Any]] = None
2629

2730
@property
2831
def query_params(self) -> QueryParams:
@@ -42,11 +45,49 @@ def headers(self) -> Mapping[str, str]:
4245

4346
@property
4447
def post_data(self) -> Mapping[str, Union[str, bytes]]:
45-
raise NotImplementedError
48+
if self._post_data is None:
49+
self._parse_body()
50+
return self._post_data or {}
4651

4752
@property
4853
def files(self) -> Mapping[str, Any]:
49-
raise NotImplementedError
54+
if self._files is None:
55+
self._parse_body()
56+
return self._files or {}
57+
58+
def _parse_body(self) -> None:
59+
self._post_data = {}
60+
self._files = {}
61+
62+
content_type = self.content_type or ""
63+
64+
if "multipart/form-data" in content_type:
65+
import cgi
66+
67+
fp = BytesIO(self.request.raw_body)
68+
environ = {
69+
"REQUEST_METHOD": "POST",
70+
"CONTENT_TYPE": content_type,
71+
"CONTENT_LENGTH": str(len(self.request.raw_body)),
72+
}
73+
fs = cgi.FieldStorage(fp=fp, environ=environ, keep_blank_values=True)
74+
for key in fs.keys():
75+
field = fs[key]
76+
if isinstance(field, list):
77+
field = field[0]
78+
if getattr(field, "filename", None):
79+
data = field.file.read()
80+
self._files[key] = BytesIO(data)
81+
else:
82+
self._post_data[key] = field.value
83+
elif "application/x-www-form-urlencoded" in content_type:
84+
from urllib.parse import parse_qs
85+
86+
data = parse_qs(self.request.raw_body.decode())
87+
self._post_data = {k: v[0] for k, v in data.items()}
88+
else:
89+
self._post_data = {}
90+
self._files = {}
5091

5192
@property
5293
def content_type(self) -> Optional[str]:
@@ -65,9 +106,11 @@ def __init__(
65106
graphiql: Optional[bool] = None,
66107
graphql_ide: Optional[GraphQL_IDE] = "graphiql",
67108
allow_queries_via_get: bool = True,
109+
multipart_uploads_enabled: bool = False,
68110
) -> None:
69111
self.allow_queries_via_get = allow_queries_via_get
70112
self.schema = schema
113+
self.multipart_uploads_enabled = multipart_uploads_enabled
71114
if graphiql is not None:
72115
warnings.warn(
73116
"The `graphiql` argument is deprecated in favor of `graphql_ide`",
@@ -95,7 +138,6 @@ def get_sub_response(self, request: Request) -> TemporalResponse:
95138
@staticmethod
96139
def error_response(
97140
message: str,
98-
error_code: str,
99141
http_status_code: int,
100142
headers: Optional[dict[str, str | list[str]]] = None,
101143
) -> Response:
@@ -110,9 +152,7 @@ def error_response(
110152
Returns:
111153
An errors response.
112154
"""
113-
body = {"Code": error_code, "Message": message}
114-
115-
return Response(body=body, status_code=http_status_code, headers=headers)
155+
return Response(body=message, status_code=http_status_code, headers=headers)
116156

117157
def get_context(self, request: Request, response: TemporalResponse) -> Context:
118158
return {"request": request, "response": response} # type: ignore
@@ -138,18 +178,7 @@ def execute_request(self, request: Request) -> Response:
138178
try:
139179
return self.run(request=request)
140180
except HTTPException as e:
141-
error_code_map = {
142-
400: "BadRequestError",
143-
401: "UnauthorizedError",
144-
403: "ForbiddenError",
145-
404: "NotFoundError",
146-
409: "ConflictError",
147-
429: "TooManyRequestsError",
148-
500: "ChaliceViewError",
149-
}
150-
151181
return self.error_response(
152-
error_code=error_code_map.get(e.status_code, "ChaliceViewError"),
153182
message=e.reason,
154183
http_status_code=e.status_code,
155184
)

src/graphql_server/channels/testing.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ def __init__(
7171
subprotocols: an ordered list of preferred subprotocols to be sent to the server.
7272
**kwargs: additional arguments to be passed to the `WebsocketCommunicator` constructor.
7373
"""
74-
if connection_params is None: # pragma: no cover - tested via custom initialisation
74+
if (
75+
connection_params is None
76+
): # pragma: no cover - tested via custom initialisation
7577
connection_params = {}
7678
self.protocol = protocol
7779
subprotocols = kwargs.get("subprotocols", [])
@@ -139,7 +141,9 @@ async def subscribe(
139141
},
140142
}
141143

142-
if variables is not None: # pragma: no cover - exercised in higher-level tests
144+
if (
145+
variables is not None
146+
): # pragma: no cover - exercised in higher-level tests
143147
start_message["payload"]["variables"] = variables
144148

145149
await self.send_json_to(start_message)
@@ -155,7 +159,9 @@ async def subscribe(
155159
ret.errors = self.process_errors(payload.get("errors") or [])
156160
ret.extensions = payload.get("extensions", None)
157161
yield ret
158-
elif message["type"] == "error": # pragma: no cover - network failures untested
162+
elif (
163+
message["type"] == "error"
164+
): # pragma: no cover - network failures untested
159165
error_payload = message["payload"]
160166
yield ExecutionResult(
161167
data=None, errors=self.process_errors(error_payload)

src/tests/http/clients/chalice.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import urllib.parse
44
from io import BytesIO
55
from json import dumps
6-
from typing import Any, Optional, Union
6+
from typing import Any, Optional
77
from typing_extensions import Literal
88

99
from graphql import ExecutionResult
10+
from urllib3 import encode_multipart_formdata
1011

1112
from chalice.app import Chalice
1213
from chalice.app import Request as ChaliceRequest
@@ -60,11 +61,14 @@ def __init__(
6061
graphiql=graphiql,
6162
graphql_ide=graphql_ide,
6263
allow_queries_via_get=allow_queries_via_get,
64+
multipart_uploads_enabled=multipart_uploads_enabled,
6365
)
6466
view.result_override = result_override
6567

6668
@self.app.route(
67-
"/graphql", methods=["GET", "POST"], content_types=["application/json"]
69+
"/graphql",
70+
methods=["GET", "POST"],
71+
content_types=["application/json", "multipart/form-data"],
6872
)
6973
def handle_graphql():
7074
assert self.app.current_request is not None
@@ -90,27 +94,27 @@ async def _graphql_request(
9094
extensions=extensions,
9195
)
9296

93-
data: Union[dict[str, object], str, None] = None
94-
95-
if body and files:
96-
body.update({name: (file, name) for name, file in files.items()})
97-
9897
url = "/graphql"
98+
headers = self._get_headers(method=method, headers=headers, files=files)
9999

100100
if method == "get":
101101
body_encoded = urllib.parse.urlencode(body or {})
102102
url = f"{url}?{body_encoded}"
103-
else:
104-
if body:
105-
data = body if files else dumps(body)
106-
kwargs["body"] = data
103+
elif body:
104+
if files:
105+
fields = {"operations": body["operations"], "map": body["map"]}
106+
for filename, file in files.items():
107+
fields[filename] = (filename, file.read(), "text/plain")
108+
data, content_type = encode_multipart_formdata(fields)
109+
headers.update(
110+
{"Content-Type": content_type, "Content-Length": f"{len(data)}"}
111+
)
112+
kwargs["body"] = data
113+
else:
114+
kwargs["body"] = dumps(body)
107115

108116
with Client(self.app) as client:
109-
response = getattr(client.http, method)(
110-
url,
111-
headers=self._get_headers(method=method, headers=headers, files=files),
112-
**kwargs,
113-
)
117+
response = getattr(client.http, method)(url, headers=headers, **kwargs)
114118

115119
return Response(
116120
status_code=response.status_code,

src/tests/http/test_query.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -289,15 +289,7 @@ async def test_invalid_operation_selection(http_client: HttpClient, operation_na
289289
)
290290

291291
assert response.status_code == 400
292-
293-
if isinstance(http_client, ChaliceHttpClient):
294-
# Our Chalice integration purposely wraps errors messages with a JSON object
295-
assert response.json == {
296-
"Code": "BadRequestError",
297-
"Message": f'Unknown operation named "{operation_name}".',
298-
}
299-
else:
300-
assert response.data == f'Unknown operation named "{operation_name}".'.encode()
292+
assert response.data == f'Unknown operation named "{operation_name}".'.encode()
301293

302294

303295
async def test_operation_selection_without_operations(http_client: HttpClient):
@@ -308,12 +300,4 @@ async def test_operation_selection_without_operations(http_client: HttpClient):
308300
)
309301

310302
assert response.status_code == 400
311-
312-
if isinstance(http_client, ChaliceHttpClient):
313-
# Our Chalice integration purposely wraps errors messages with a JSON object
314-
assert response.json == {
315-
"Code": "BadRequestError",
316-
"Message": "Can't get GraphQL operation type",
317-
}
318-
else:
319-
assert response.data == b"Can't get GraphQL operation type"
303+
assert response.data == b"Can't get GraphQL operation type"

src/tests/http/test_upload.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import contextlib
21
import json
32
from io import BytesIO
43

@@ -10,23 +9,11 @@
109

1110
@pytest.fixture
1211
def http_client(http_client_class: type[HttpClient]) -> HttpClient:
13-
with contextlib.suppress(ImportError):
14-
from .clients.chalice import ChaliceHttpClient
15-
16-
if http_client_class is ChaliceHttpClient:
17-
pytest.xfail(reason="Chalice does not support uploads")
18-
1912
return http_client_class()
2013

2114

2215
@pytest.fixture
2316
def enabled_http_client(http_client_class: type[HttpClient]) -> HttpClient:
24-
with contextlib.suppress(ImportError):
25-
from .clients.chalice import ChaliceHttpClient
26-
27-
if http_client_class is ChaliceHttpClient:
28-
pytest.xfail(reason="Chalice does not support uploads")
29-
3017
return http_client_class(multipart_uploads_enabled=True)
3118

3219

src/tests/test/test_client_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
class DummyClient(BaseGraphQLTestClient):
1010
def request(self, body, headers=None, files=None):
11-
return types.SimpleNamespace(content=json.dumps(body).encode(), json=lambda: body)
11+
return types.SimpleNamespace(
12+
content=json.dumps(body).encode(), json=lambda: body
13+
)
1214

1315

1416
def test_build_body_with_variables_and_files():

0 commit comments

Comments
 (0)