Skip to content
This repository was archived by the owner on Jun 13, 2025. It is now read-only.

Commit 8ed56ee

Browse files
committed
add mypy to the views.py
1 parent 0c1790c commit 8ed56ee

File tree

1 file changed

+28
-24
lines changed

1 file changed

+28
-24
lines changed

graphql_api/views.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from ariadne.validation import cost_validator
1313
from ariadne_django.views import GraphQLAsyncView
1414
from django.conf import settings
15+
from django.core.handlers.wsgi import WSGIRequest
1516
from django.http import (
17+
HttpResponse,
1618
HttpResponseBadRequest,
1719
HttpResponseNotAllowed,
1820
JsonResponse,
@@ -84,13 +86,13 @@ class QueryMetricsExtension(Extension):
8486
8587
"""
8688

87-
def __init__(self):
88-
self.start_timestamp = None
89-
self.end_timestamp = None
90-
self.operation_type = None
91-
self.operation_name = None
89+
def __init__(self) -> None:
90+
self.start_timestamp: float = 0
91+
self.end_timestamp: float = 0
92+
self.operation_type: str | None = None
93+
self.operation_name: str | None = None
9294

93-
def set_type_and_name(self, query):
95+
def set_type_and_name(self, query: str) -> None:
9496
operation_type = "unknown_type" # default value
9597
operation_name = "unknown_name" # default value
9698
try:
@@ -119,7 +121,7 @@ def set_type_and_name(self, query):
119121
extra=dict(query_slice=query_slice),
120122
)
121123

122-
def request_started(self, context):
124+
def request_started(self, context: dict[str, Any]) -> None:
123125
"""
124126
Extension hook executed at request's start.
125127
"""
@@ -133,7 +135,7 @@ def request_started(self, context):
133135
),
134136
)
135137

136-
def request_finished(self, context):
138+
def request_finished(self, context: dict[str, Any]) -> None:
137139
"""
138140
Extension hook executed at request's end.
139141
"""
@@ -143,7 +145,7 @@ def request_finished(self, context):
143145
operation_type=self.operation_type, operation_name=self.operation_name
144146
).observe(latency)
145147

146-
def has_errors(self, errors, context):
148+
def has_errors(self, errors: list[dict[str, Any]], context: dict[str, Any]) -> None:
147149
"""
148150
Extension hook executed when GraphQL encountered errors.
149151
"""
@@ -163,10 +165,10 @@ class RequestFinalizer:
163165
"bundle_analysis_base_report_db_path",
164166
]
165167

166-
def __init__(self, request):
168+
def __init__(self, request: WSGIRequest) -> None:
167169
self.request = request
168170

169-
def _remove_temp_files(self):
171+
def _remove_temp_files(self) -> None:
170172
"""
171173
Some requests cause temporary files to be created in /tmp (eg BundleAnalysis)
172174
This cleanup step clears all contents of the /tmp directory after each request
@@ -184,10 +186,10 @@ def _remove_temp_files(self):
184186
extra={"file_path": file_path, "exc": e},
185187
)
186188

187-
def __enter__(self):
189+
def __enter__(self) -> None:
188190
pass
189191

190-
def __exit__(self, exc_type, exc_value, exc_traceback):
192+
def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
191193
self._remove_temp_files()
192194

193195

@@ -203,7 +205,7 @@ def get_validation_rules(
203205
data: dict,
204206
) -> Optional[Collection]:
205207
return [
206-
create_required_variables_rule(variables=data.get("variables")),
208+
create_required_variables_rule(variables=data.get("variables", {})),
207209
create_max_aliases_rule(max_aliases=settings.GRAPHQL_MAX_ALIASES),
208210
create_max_depth_rule(max_depth=settings.GRAPHQL_MAX_DEPTH),
209211
cost_validator(
@@ -215,20 +217,22 @@ def get_validation_rules(
215217

216218
validation_rules = get_validation_rules # type: ignore
217219

218-
def get_clean_query(self, request_body):
220+
def get_clean_query(self, request_body: dict[str, Any]) -> str | None:
219221
# clean up graphql query to remove new lines and extra spaces
220222
if "query" in request_body and isinstance(request_body["query"], str):
221223
clean_query = request_body["query"].replace("\n", " ")
222224
clean_query = clean_query.replace(" ", "").strip()
223225
return clean_query
224226

225-
async def get(self, *args, **kwargs):
227+
async def get(self, *args: Any, **kwargs: Any) -> HttpResponse:
226228
if settings.GRAPHQL_PLAYGROUND:
227229
return await super().get(*args, **kwargs)
228230
# No GraphqlPlayground if no settings.DEBUG
229231
return HttpResponseNotAllowed(["POST"])
230232

231-
async def post(self, request, *args, **kwargs):
233+
async def post(
234+
self, request: WSGIRequest, *args: Any, **kwargs: Any
235+
) -> HttpResponse:
232236
await self._get_user(request)
233237
# get request body information for logging
234238
req_body = json.loads(request.body.decode("utf-8")) if request.body else {}
@@ -322,7 +326,7 @@ async def post(self, request, *args, **kwargs):
322326
pass
323327
return response
324328

325-
def context_value(self, request, *_):
329+
def context_value(self, request: WSGIRequest, *_args: Any) -> dict[str, Any]:
326330
request_body = json.loads(request.body.decode("utf-8")) if request.body else {}
327331
self.request = request
328332

@@ -333,7 +337,7 @@ def context_value(self, request, *_):
333337
"clean_query": self.get_clean_query(request_body) if request_body else "",
334338
}
335339

336-
def error_formatter(self, error, debug=False):
340+
def error_formatter(self, error: Any, debug: bool = False) -> dict[str, Any]:
337341
user = self.request.user
338342
is_anonymous = user.is_anonymous if user else True
339343
# the only way to check for a malformed query
@@ -348,7 +352,7 @@ def error_formatter(self, error, debug=False):
348352
if isinstance(original_error, BaseException) or isinstance(
349353
original_error, ServiceException
350354
):
351-
formatted["message"] = original_error.message
355+
formatted["message"] = original_error.message # type: ignore
352356
formatted["type"] = type(original_error).__name__
353357
else:
354358
# otherwise it's not supposed to happen, so we log it
@@ -357,13 +361,13 @@ def error_formatter(self, error, debug=False):
357361
return formatted
358362

359363
@sync_to_async
360-
def _get_user(self, request):
364+
def _get_user(self, request: WSGIRequest) -> None:
361365
# force eager evaluation of `request.user` (a lazy object)
362366
# while we're in a sync context
363367
if request.user:
364368
request.user.pk
365369

366-
def _check_ratelimit(self, request):
370+
def _check_ratelimit(self, request: WSGIRequest) -> bool:
367371
if not settings.GRAPHQL_RATE_LIMIT_ENABLED:
368372
return False
369373

@@ -405,7 +409,7 @@ def _check_ratelimit(self, request):
405409
redis.incr(key)
406410
return False
407411

408-
def get_client_ip(self, request):
412+
def get_client_ip(self, request: WSGIRequest) -> str:
409413
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
410414
if x_forwarded_for:
411415
ip = x_forwarded_for.split(",")[0]
@@ -417,7 +421,7 @@ def get_client_ip(self, request):
417421
BaseAriadneView = AsyncGraphqlView.as_view()
418422

419423

420-
async def ariadne_view(request, service):
424+
async def ariadne_view(request: WSGIRequest, service: str) -> HttpResponse:
421425
response = BaseAriadneView(request, service)
422426
if iscoroutine(response):
423427
response = await response

0 commit comments

Comments
 (0)