1212from ariadne .validation import cost_validator
1313from ariadne_django .views import GraphQLAsyncView
1414from django .conf import settings
15+ from django .core .handlers .wsgi import WSGIRequest
1516from django .http import (
17+ HttpResponse ,
1618 HttpResponseBadRequest ,
1719 HttpResponseNotAllowed ,
1820 JsonResponse ,
@@ -84,13 +86,13 @@ class QueryMetricsExtension(Extension):
8486
8587 """
8688
87- def __init__ (self ):
88- self .start_timestamp = None
89- self .end_timestamp = None
90- self .operation_type = None
91- self .operation_name = None
89+ def __init__ (self ) -> None :
90+ self .start_timestamp : float = 0
91+ self .end_timestamp : float = 0
92+ self .operation_type : str | None = None
93+ self .operation_name : str | None = None
9294
93- def set_type_and_name (self , query ) :
95+ def set_type_and_name (self , query : str ) -> None :
9496 operation_type = "unknown_type" # default value
9597 operation_name = "unknown_name" # default value
9698 try :
@@ -119,7 +121,7 @@ def set_type_and_name(self, query):
119121 extra = dict (query_slice = query_slice ),
120122 )
121123
122- def request_started (self , context ) :
124+ def request_started (self , context : dict [ str , Any ]) -> None :
123125 """
124126 Extension hook executed at request's start.
125127 """
@@ -133,7 +135,7 @@ def request_started(self, context):
133135 ),
134136 )
135137
136- def request_finished (self , context ) :
138+ def request_finished (self , context : dict [ str , Any ]) -> None :
137139 """
138140 Extension hook executed at request's end.
139141 """
@@ -143,7 +145,7 @@ def request_finished(self, context):
143145 operation_type = self .operation_type , operation_name = self .operation_name
144146 ).observe (latency )
145147
146- def has_errors (self , errors , context ) :
148+ def has_errors (self , errors : list [ dict [ str , Any ]], context : dict [ str , Any ]) -> None :
147149 """
148150 Extension hook executed when GraphQL encountered errors.
149151 """
@@ -163,10 +165,10 @@ class RequestFinalizer:
163165 "bundle_analysis_base_report_db_path" ,
164166 ]
165167
166- def __init__ (self , request ) :
168+ def __init__ (self , request : WSGIRequest ) -> None :
167169 self .request = request
168170
169- def _remove_temp_files (self ):
171+ def _remove_temp_files (self ) -> None :
170172 """
171173 Some requests cause temporary files to be created in /tmp (eg BundleAnalysis)
172174 This cleanup step clears all contents of the /tmp directory after each request
@@ -184,10 +186,10 @@ def _remove_temp_files(self):
184186 extra = {"file_path" : file_path , "exc" : e },
185187 )
186188
187- def __enter__ (self ):
189+ def __enter__ (self ) -> None :
188190 pass
189191
190- def __exit__ (self , exc_type , exc_value , exc_traceback ) :
192+ def __exit__ (self , exc_type : Any , exc_value : Any , exc_traceback : Any ) -> None :
191193 self ._remove_temp_files ()
192194
193195
@@ -203,7 +205,7 @@ def get_validation_rules(
203205 data : dict ,
204206 ) -> Optional [Collection ]:
205207 return [
206- create_required_variables_rule (variables = data .get ("variables" )),
208+ create_required_variables_rule (variables = data .get ("variables" , {} )),
207209 create_max_aliases_rule (max_aliases = settings .GRAPHQL_MAX_ALIASES ),
208210 create_max_depth_rule (max_depth = settings .GRAPHQL_MAX_DEPTH ),
209211 cost_validator (
@@ -215,20 +217,22 @@ def get_validation_rules(
215217
216218 validation_rules = get_validation_rules # type: ignore
217219
218- def get_clean_query (self , request_body ) :
220+ def get_clean_query (self , request_body : dict [ str , Any ]) -> str | None :
219221 # clean up graphql query to remove new lines and extra spaces
220222 if "query" in request_body and isinstance (request_body ["query" ], str ):
221223 clean_query = request_body ["query" ].replace ("\n " , " " )
222224 clean_query = clean_query .replace (" " , "" ).strip ()
223225 return clean_query
224226
225- async def get (self , * args , ** kwargs ) :
227+ async def get (self , * args : Any , ** kwargs : Any ) -> HttpResponse :
226228 if settings .GRAPHQL_PLAYGROUND :
227229 return await super ().get (* args , ** kwargs )
228230 # No GraphqlPlayground if no settings.DEBUG
229231 return HttpResponseNotAllowed (["POST" ])
230232
231- async def post (self , request , * args , ** kwargs ):
233+ async def post (
234+ self , request : WSGIRequest , * args : Any , ** kwargs : Any
235+ ) -> HttpResponse :
232236 await self ._get_user (request )
233237 # get request body information for logging
234238 req_body = json .loads (request .body .decode ("utf-8" )) if request .body else {}
@@ -322,7 +326,7 @@ async def post(self, request, *args, **kwargs):
322326 pass
323327 return response
324328
325- def context_value (self , request , * _ ) :
329+ def context_value (self , request : WSGIRequest , * _args : Any ) -> dict [ str , Any ] :
326330 request_body = json .loads (request .body .decode ("utf-8" )) if request .body else {}
327331 self .request = request
328332
@@ -333,7 +337,7 @@ def context_value(self, request, *_):
333337 "clean_query" : self .get_clean_query (request_body ) if request_body else "" ,
334338 }
335339
336- def error_formatter (self , error , debug = False ):
340+ def error_formatter (self , error : Any , debug : bool = False ) -> dict [ str , Any ] :
337341 user = self .request .user
338342 is_anonymous = user .is_anonymous if user else True
339343 # the only way to check for a malformed query
@@ -348,7 +352,7 @@ def error_formatter(self, error, debug=False):
348352 if isinstance (original_error , BaseException ) or isinstance (
349353 original_error , ServiceException
350354 ):
351- formatted ["message" ] = original_error .message
355+ formatted ["message" ] = original_error .message # type: ignore
352356 formatted ["type" ] = type (original_error ).__name__
353357 else :
354358 # otherwise it's not supposed to happen, so we log it
@@ -357,13 +361,13 @@ def error_formatter(self, error, debug=False):
357361 return formatted
358362
359363 @sync_to_async
360- def _get_user (self , request ) :
364+ def _get_user (self , request : WSGIRequest ) -> None :
361365 # force eager evaluation of `request.user` (a lazy object)
362366 # while we're in a sync context
363367 if request .user :
364368 request .user .pk
365369
366- def _check_ratelimit (self , request ) :
370+ def _check_ratelimit (self , request : WSGIRequest ) -> bool :
367371 if not settings .GRAPHQL_RATE_LIMIT_ENABLED :
368372 return False
369373
@@ -405,7 +409,7 @@ def _check_ratelimit(self, request):
405409 redis .incr (key )
406410 return False
407411
408- def get_client_ip (self , request ) :
412+ def get_client_ip (self , request : WSGIRequest ) -> str :
409413 x_forwarded_for = request .META .get ("HTTP_X_FORWARDED_FOR" )
410414 if x_forwarded_for :
411415 ip = x_forwarded_for .split ("," )[0 ]
@@ -417,7 +421,7 @@ def get_client_ip(self, request):
417421BaseAriadneView = AsyncGraphqlView .as_view ()
418422
419423
420- async def ariadne_view (request , service ) :
424+ async def ariadne_view (request : WSGIRequest , service : str ) -> HttpResponse :
421425 response = BaseAriadneView (request , service )
422426 if iscoroutine (response ):
423427 response = await response
0 commit comments