Skip to content

Commit 18e0b2d

Browse files
committed
refactor(csrf): improve type hints in CSRF handling
1 parent 701f117 commit 18e0b2d

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

src/flask_wtf/csrf.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import hmac
33
import logging
44
import os
5+
from typing import Any, Callable, Optional, Set, Union, TypeVar
56
from urllib.parse import urlparse
67

7-
from flask import Blueprint
8+
from flask import Blueprint, Flask
89
from flask import current_app
910
from flask import g
1011
from flask import request
@@ -19,8 +20,10 @@
1920
__all__ = ("generate_csrf", "validate_csrf", "CSRFProtect")
2021
logger = logging.getLogger(__name__)
2122

23+
F = TypeVar('F', bound=Callable[..., object])
2224

23-
def generate_csrf(secret_key=None, token_key=None):
25+
26+
def generate_csrf(secret_key: Optional[str] = None, token_key: Optional[str] = None) -> str:
2427
"""Generate a CSRF token. The token is cached for a request, so multiple
2528
calls to this function will generate the same token.
2629
@@ -63,7 +66,7 @@ def generate_csrf(secret_key=None, token_key=None):
6366
return g.get(field_name)
6467

6568

66-
def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
69+
def validate_csrf(data: str, secret_key: Optional[str] = None, time_limit: Optional[int] = None, token_key: Optional[str] = None) -> None:
6770
"""Check if the given data is a valid CSRF token. This compares the given
6871
signed token to the one stored in the session.
6972
@@ -116,8 +119,8 @@ def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
116119

117120

118121
def _get_config(
119-
value, config_name, default=None, required=True, message="CSRF is not configured."
120-
):
122+
value: Any, config_name: str, default: Any = None, required: bool = True, message: str = "CSRF is not configured."
123+
) -> Any:
121124
"""Find config value based on provided value, Flask config, and default
122125
value.
123126
@@ -139,16 +142,16 @@ def _get_config(
139142

140143

141144
class _FlaskFormCSRF(CSRF):
142-
def setup_form(self, form):
145+
def setup_form(self, form: Any) -> None:
143146
self.meta = form.meta
144147
return super().setup_form(form)
145148

146-
def generate_csrf_token(self, csrf_token_field):
149+
def generate_csrf_token(self, csrf_token_field: Any) -> str:
147150
return generate_csrf(
148151
secret_key=self.meta.csrf_secret, token_key=self.meta.csrf_field_name
149152
)
150153

151-
def validate_csrf_token(self, form, field):
154+
def validate_csrf_token(self, form: Any, field: Any) -> None:
152155
if g.get("csrf_valid", False):
153156
# already validated by CSRFProtect
154157
return
@@ -180,14 +183,14 @@ class CSRFProtect:
180183
See the :ref:`csrf` documentation.
181184
"""
182185

183-
def __init__(self, app=None):
184-
self._exempt_views = set()
185-
self._exempt_blueprints = set()
186+
def __init__(self, app: Optional[Flask] = None) -> None:
187+
self._exempt_views: Set[str] = set()
188+
self._exempt_blueprints: Set[Blueprint] = set()
186189

187190
if app:
188191
self.init_app(app)
189192

190-
def init_app(self, app):
193+
def init_app(self, app: Flask) -> None:
191194
app.extensions["csrf"] = self
192195

193196
app.config.setdefault("WTF_CSRF_ENABLED", True)
@@ -204,7 +207,7 @@ def init_app(self, app):
204207
app.context_processor(lambda: {"csrf_token": generate_csrf})
205208

206209
@app.before_request
207-
def csrf_protect():
210+
def csrf_protect() -> None:
208211
if not app.config["WTF_CSRF_ENABLED"]:
209212
return
210213

@@ -253,7 +256,7 @@ def _get_csrf_token(self):
253256

254257
return None
255258

256-
def protect(self):
259+
def protect(self) -> None:
257260
if request.method not in current_app.config["WTF_CSRF_METHODS"]:
258261
return
259262

@@ -274,7 +277,7 @@ def protect(self):
274277

275278
g.csrf_valid = True # mark this request as CSRF valid
276279

277-
def exempt(self, view):
280+
def exempt(self, view: Union[F, Blueprint, str]) -> Union[F, Blueprint, str]:
278281
"""Mark a view or blueprint to be excluded from CSRF protection.
279282
280283
::
@@ -303,7 +306,7 @@ def some_view():
303306
self._exempt_views.add(view_location)
304307
return view
305308

306-
def _error_response(self, reason):
309+
def _error_response(self, reason: str) -> None:
307310
raise CSRFError(reason)
308311

309312

@@ -318,7 +321,7 @@ class CSRFError(BadRequest):
318321
description = "CSRF validation failed."
319322

320323

321-
def same_origin(current_uri, compare_uri):
324+
def same_origin(current_uri: str, compare_uri: str) -> bool:
322325
current = urlparse(current_uri)
323326
compare = urlparse(compare_uri)
324327

0 commit comments

Comments
 (0)