1212from ariadne .validation import cost_validator
1313from ariadne_django .views import GraphQLAsyncView
1414from django .conf import settings
15+ from django .core .handlers .wsgi import WSGIRequest
1516from django .http import (
17+ HttpResponse ,
1618 HttpResponseBadRequest ,
1719 HttpResponseNotAllowed ,
1820 JsonResponse ,
@@ -84,13 +86,13 @@ class QueryMetricsExtension(Extension):
8486
8587 """
8688
87- def __init__ (self ):
88- self .start_timestamp = None
89- self .end_timestamp = None
90- self .operation_type = None
91- self .operation_name = None
89+ def __init__ (self ) -> None :
90+ self .start_timestamp : float = 0
91+ self .end_timestamp : float = 0
92+ self .operation_type : str | None = None
93+ self .operation_name : str | None = None
9294
93- def set_type_and_name (self , query ) :
95+ def set_type_and_name (self , query : str ) -> None :
9496 operation_type = "unknown_type" # default value
9597 operation_name = "unknown_name" # default value
9698 try :
@@ -119,7 +121,7 @@ def set_type_and_name(self, query):
119121 extra = dict (query_slice = query_slice ),
120122 )
121123
122- def request_started (self , context ) :
124+ def request_started (self , context : dict [ str , Any ]) -> None :
123125 """
124126 Extension hook executed at request's start.
125127 """
@@ -133,7 +135,7 @@ def request_started(self, context):
133135 ),
134136 )
135137
136- def request_finished (self , context ) :
138+ def request_finished (self , context : dict [ str , Any ]) -> None :
137139 """
138140 Extension hook executed at request's end.
139141 """
@@ -143,7 +145,7 @@ def request_finished(self, context):
143145 operation_type = self .operation_type , operation_name = self .operation_name
144146 ).observe (latency )
145147
146- def has_errors (self , errors , context ) :
148+ def has_errors (self , errors : list [ dict [ str , Any ]], context : dict [ str , Any ]) -> None :
147149 """
148150 Extension hook executed when GraphQL encountered errors.
149151 """
@@ -163,10 +165,10 @@ class RequestFinalizer:
163165 "bundle_analysis_base_report_db_path" ,
164166 ]
165167
166- def __init__ (self , request ) :
168+ def __init__ (self , request : WSGIRequest ) -> None :
167169 self .request = request
168170
169- def _remove_temp_files (self ):
171+ def _remove_temp_files (self ) -> None :
170172 """
171173 Some requests cause temporary files to be created in /tmp (eg BundleAnalysis)
172174 This cleanup step clears all contents of the /tmp directory after each request
@@ -184,10 +186,10 @@ def _remove_temp_files(self):
184186 extra = {"file_path" : file_path , "exc" : e },
185187 )
186188
187- def __enter__ (self ):
189+ def __enter__ (self ) -> None :
188190 pass
189191
190- def __exit__ (self , exc_type , exc_value , exc_traceback ) :
192+ def __exit__ (self , exc_type : Any , exc_value : Any , exc_traceback : Any ) -> None :
191193 self ._remove_temp_files ()
192194
193195
@@ -203,7 +205,7 @@ def get_validation_rules(
203205 data : dict ,
204206 ) -> Optional [Collection ]:
205207 return [
206- create_required_variables_rule (variables = data .get ("variables" )),
208+ create_required_variables_rule (variables = data .get ("variables" , {} )),
207209 create_max_aliases_rule (max_aliases = settings .GRAPHQL_MAX_ALIASES ),
208210 create_max_depth_rule (max_depth = settings .GRAPHQL_MAX_DEPTH ),
209211 cost_validator (
@@ -215,20 +217,22 @@ def get_validation_rules(
215217
216218 validation_rules = get_validation_rules # type: ignore
217219
218- def get_clean_query (self , request_body ) :
220+ def get_clean_query (self , request_body : dict [ str , Any ]) -> str | None :
219221 # clean up graphql query to remove new lines and extra spaces
220222 if "query" in request_body and isinstance (request_body ["query" ], str ):
221223 clean_query = request_body ["query" ].replace ("\n " , " " )
222224 clean_query = clean_query .replace (" " , "" ).strip ()
223225 return clean_query
224226
225- async def get (self , * args , ** kwargs ) :
227+ async def get (self , * args : Any , ** kwargs : Any ) -> HttpResponse :
226228 if settings .GRAPHQL_PLAYGROUND :
227229 return await super ().get (* args , ** kwargs )
228230 # No GraphqlPlayground if no settings.DEBUG
229231 return HttpResponseNotAllowed (["POST" ])
230232
231- async def post (self , request , * args , ** kwargs ):
233+ async def post (
234+ self , request : WSGIRequest , * args : Any , ** kwargs : Any
235+ ) -> HttpResponse :
232236 await self ._get_user (request )
233237 # get request body information for logging
234238 req_body = json .loads (request .body .decode ("utf-8" )) if request .body else {}
@@ -277,7 +281,20 @@ async def post(self, request, *args, **kwargs):
277281 )
278282
279283 content = response .content .decode ("utf-8" )
280- data = json .loads (content )
284+ try :
285+ data = json .loads (content )
286+ except json .JSONDecodeError :
287+ log .error (
288+ "Failed to decode JSON response" ,
289+ extra = {"content" : content , "request_body" : req_body },
290+ )
291+ return JsonResponse (
292+ data = {
293+ "status" : 400 ,
294+ "detail" : "Invalid JSON response received." ,
295+ },
296+ status = 400 ,
297+ )
281298
282299 if "errors" in data :
283300 inc_counter (
@@ -309,7 +326,7 @@ async def post(self, request, *args, **kwargs):
309326 pass
310327 return response
311328
312- def context_value (self , request , * _ ) :
329+ def context_value (self , request : WSGIRequest , * _args : Any ) -> dict [ str , Any ] :
313330 request_body = json .loads (request .body .decode ("utf-8" )) if request .body else {}
314331 self .request = request
315332
@@ -320,7 +337,7 @@ def context_value(self, request, *_):
320337 "clean_query" : self .get_clean_query (request_body ) if request_body else "" ,
321338 }
322339
323- def error_formatter (self , error , debug = False ):
340+ def error_formatter (self , error : Any , debug : bool = False ) -> dict [ str , Any ] :
324341 user = self .request .user
325342 is_anonymous = user .is_anonymous if user else True
326343 # the only way to check for a malformed query
@@ -335,7 +352,7 @@ def error_formatter(self, error, debug=False):
335352 if isinstance (original_error , BaseException ) or isinstance (
336353 original_error , ServiceException
337354 ):
338- formatted ["message" ] = original_error .message
355+ formatted ["message" ] = original_error .message # type: ignore
339356 formatted ["type" ] = type (original_error ).__name__
340357 else :
341358 # otherwise it's not supposed to happen, so we log it
@@ -344,13 +361,13 @@ def error_formatter(self, error, debug=False):
344361 return formatted
345362
346363 @sync_to_async
347- def _get_user (self , request ) :
364+ def _get_user (self , request : WSGIRequest ) -> None :
348365 # force eager evaluation of `request.user` (a lazy object)
349366 # while we're in a sync context
350367 if request .user :
351368 request .user .pk
352369
353- def _check_ratelimit (self , request ) :
370+ def _check_ratelimit (self , request : WSGIRequest ) -> bool :
354371 if not settings .GRAPHQL_RATE_LIMIT_ENABLED :
355372 return False
356373
@@ -392,7 +409,7 @@ def _check_ratelimit(self, request):
392409 redis .incr (key )
393410 return False
394411
395- def get_client_ip (self , request ) :
412+ def get_client_ip (self , request : WSGIRequest ) -> str :
396413 x_forwarded_for = request .META .get ("HTTP_X_FORWARDED_FOR" )
397414 if x_forwarded_for :
398415 ip = x_forwarded_for .split ("," )[0 ]
@@ -404,7 +421,7 @@ def get_client_ip(self, request):
404421BaseAriadneView = AsyncGraphqlView .as_view ()
405422
406423
407- async def ariadne_view (request , service ) :
424+ async def ariadne_view (request : WSGIRequest , service : str ) -> HttpResponse :
408425 response = BaseAriadneView (request , service )
409426 if iscoroutine (response ):
410427 response = await response
0 commit comments