Skip to content
This repository was archived by the owner on Jun 13, 2025. It is now read-only.

Commit 54aacf3

Browse files
authored
fix: handle internal server error when no GQL request body received (#1134)
1 parent e885187 commit 54aacf3

File tree

2 files changed

+62
-25
lines changed

2 files changed

+62
-25
lines changed

graphql_api/tests/test_views.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,23 @@ async def test_required_variable_missing(self):
309309
data = await self.do_query(schema, query=query, variables={})
310310

311311
assert data == {"detail": "Missing required variables: name", "status": 400}
312+
313+
async def test_empty_request_body(self):
314+
schema = generate_schema_with_required_variables()
315+
316+
request = RequestFactory().post(
317+
"/graphql/gh", "", content_type="application/json"
318+
)
319+
match = ResolverMatch(func=lambda: None, args=(), kwargs={"service": "github"})
320+
request.resolver_match = match
321+
request.user = None
322+
request.current_owner = None
323+
324+
view = AsyncGraphqlView.as_view(schema=schema)
325+
response = await view(request, service="gh")
326+
327+
assert response.status_code == 400
328+
assert json.loads(response.content) == {
329+
"status": 400,
330+
"detail": "Invalid JSON response received.",
331+
}

graphql_api/views.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from ariadne.validation import cost_validator
1313
from ariadne_django.views import GraphQLAsyncView
1414
from django.conf import settings
15+
from django.core.handlers.wsgi import WSGIRequest
1516
from django.http import (
17+
HttpResponse,
1618
HttpResponseBadRequest,
1719
HttpResponseNotAllowed,
1820
JsonResponse,
@@ -84,13 +86,13 @@ class QueryMetricsExtension(Extension):
8486
8587
"""
8688

87-
def __init__(self):
88-
self.start_timestamp = None
89-
self.end_timestamp = None
90-
self.operation_type = None
91-
self.operation_name = None
89+
def __init__(self) -> None:
90+
self.start_timestamp: float = 0
91+
self.end_timestamp: float = 0
92+
self.operation_type: str | None = None
93+
self.operation_name: str | None = None
9294

93-
def set_type_and_name(self, query):
95+
def set_type_and_name(self, query: str) -> None:
9496
operation_type = "unknown_type" # default value
9597
operation_name = "unknown_name" # default value
9698
try:
@@ -119,7 +121,7 @@ def set_type_and_name(self, query):
119121
extra=dict(query_slice=query_slice),
120122
)
121123

122-
def request_started(self, context):
124+
def request_started(self, context: dict[str, Any]) -> None:
123125
"""
124126
Extension hook executed at request's start.
125127
"""
@@ -133,7 +135,7 @@ def request_started(self, context):
133135
),
134136
)
135137

136-
def request_finished(self, context):
138+
def request_finished(self, context: dict[str, Any]) -> None:
137139
"""
138140
Extension hook executed at request's end.
139141
"""
@@ -143,7 +145,7 @@ def request_finished(self, context):
143145
operation_type=self.operation_type, operation_name=self.operation_name
144146
).observe(latency)
145147

146-
def has_errors(self, errors, context):
148+
def has_errors(self, errors: list[dict[str, Any]], context: dict[str, Any]) -> None:
147149
"""
148150
Extension hook executed when GraphQL encountered errors.
149151
"""
@@ -163,10 +165,10 @@ class RequestFinalizer:
163165
"bundle_analysis_base_report_db_path",
164166
]
165167

166-
def __init__(self, request):
168+
def __init__(self, request: WSGIRequest) -> None:
167169
self.request = request
168170

169-
def _remove_temp_files(self):
171+
def _remove_temp_files(self) -> None:
170172
"""
171173
Some requests cause temporary files to be created in /tmp (eg BundleAnalysis)
172174
This cleanup step clears all contents of the /tmp directory after each request
@@ -184,10 +186,10 @@ def _remove_temp_files(self):
184186
extra={"file_path": file_path, "exc": e},
185187
)
186188

187-
def __enter__(self):
189+
def __enter__(self) -> None:
188190
pass
189191

190-
def __exit__(self, exc_type, exc_value, exc_traceback):
192+
def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
191193
self._remove_temp_files()
192194

193195

@@ -203,7 +205,7 @@ def get_validation_rules(
203205
data: dict,
204206
) -> Optional[Collection]:
205207
return [
206-
create_required_variables_rule(variables=data.get("variables")),
208+
create_required_variables_rule(variables=data.get("variables", {})),
207209
create_max_aliases_rule(max_aliases=settings.GRAPHQL_MAX_ALIASES),
208210
create_max_depth_rule(max_depth=settings.GRAPHQL_MAX_DEPTH),
209211
cost_validator(
@@ -215,20 +217,22 @@ def get_validation_rules(
215217

216218
validation_rules = get_validation_rules # type: ignore
217219

218-
def get_clean_query(self, request_body):
220+
def get_clean_query(self, request_body: dict[str, Any]) -> str | None:
219221
# clean up graphql query to remove new lines and extra spaces
220222
if "query" in request_body and isinstance(request_body["query"], str):
221223
clean_query = request_body["query"].replace("\n", " ")
222224
clean_query = clean_query.replace(" ", "").strip()
223225
return clean_query
224226

225-
async def get(self, *args, **kwargs):
227+
async def get(self, *args: Any, **kwargs: Any) -> HttpResponse:
226228
if settings.GRAPHQL_PLAYGROUND:
227229
return await super().get(*args, **kwargs)
228230
# No GraphqlPlayground if no settings.DEBUG
229231
return HttpResponseNotAllowed(["POST"])
230232

231-
async def post(self, request, *args, **kwargs):
233+
async def post(
234+
self, request: WSGIRequest, *args: Any, **kwargs: Any
235+
) -> HttpResponse:
232236
await self._get_user(request)
233237
# get request body information for logging
234238
req_body = json.loads(request.body.decode("utf-8")) if request.body else {}
@@ -277,7 +281,20 @@ async def post(self, request, *args, **kwargs):
277281
)
278282

279283
content = response.content.decode("utf-8")
280-
data = json.loads(content)
284+
try:
285+
data = json.loads(content)
286+
except json.JSONDecodeError:
287+
log.error(
288+
"Failed to decode JSON response",
289+
extra={"content": content, "request_body": req_body},
290+
)
291+
return JsonResponse(
292+
data={
293+
"status": 400,
294+
"detail": "Invalid JSON response received.",
295+
},
296+
status=400,
297+
)
281298

282299
if "errors" in data:
283300
inc_counter(
@@ -309,7 +326,7 @@ async def post(self, request, *args, **kwargs):
309326
pass
310327
return response
311328

312-
def context_value(self, request, *_):
329+
def context_value(self, request: WSGIRequest, *_args: Any) -> dict[str, Any]:
313330
request_body = json.loads(request.body.decode("utf-8")) if request.body else {}
314331
self.request = request
315332

@@ -320,7 +337,7 @@ def context_value(self, request, *_):
320337
"clean_query": self.get_clean_query(request_body) if request_body else "",
321338
}
322339

323-
def error_formatter(self, error, debug=False):
340+
def error_formatter(self, error: Any, debug: bool = False) -> dict[str, Any]:
324341
user = self.request.user
325342
is_anonymous = user.is_anonymous if user else True
326343
# the only way to check for a malformed query
@@ -335,7 +352,7 @@ def error_formatter(self, error, debug=False):
335352
if isinstance(original_error, BaseException) or isinstance(
336353
original_error, ServiceException
337354
):
338-
formatted["message"] = original_error.message
355+
formatted["message"] = original_error.message # type: ignore
339356
formatted["type"] = type(original_error).__name__
340357
else:
341358
# otherwise it's not supposed to happen, so we log it
@@ -344,13 +361,13 @@ def error_formatter(self, error, debug=False):
344361
return formatted
345362

346363
@sync_to_async
347-
def _get_user(self, request):
364+
def _get_user(self, request: WSGIRequest) -> None:
348365
# force eager evaluation of `request.user` (a lazy object)
349366
# while we're in a sync context
350367
if request.user:
351368
request.user.pk
352369

353-
def _check_ratelimit(self, request):
370+
def _check_ratelimit(self, request: WSGIRequest) -> bool:
354371
if not settings.GRAPHQL_RATE_LIMIT_ENABLED:
355372
return False
356373

@@ -392,7 +409,7 @@ def _check_ratelimit(self, request):
392409
redis.incr(key)
393410
return False
394411

395-
def get_client_ip(self, request):
412+
def get_client_ip(self, request: WSGIRequest) -> str:
396413
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
397414
if x_forwarded_for:
398415
ip = x_forwarded_for.split(",")[0]
@@ -404,7 +421,7 @@ def get_client_ip(self, request):
404421
BaseAriadneView = AsyncGraphqlView.as_view()
405422

406423

407-
async def ariadne_view(request, service):
424+
async def ariadne_view(request: WSGIRequest, service: str) -> HttpResponse:
408425
response = BaseAriadneView(request, service)
409426
if iscoroutine(response):
410427
response = await response

0 commit comments

Comments
 (0)