22import hmac
33import logging
44import os
5+ from typing import Any , Callable , Optional , Set , Union , TypeVar
56from urllib .parse import urlparse
67
7- from flask import Blueprint
8+ from flask import Blueprint , Flask
89from flask import current_app
910from flask import g
1011from flask import request
1920__all__ = ("generate_csrf" , "validate_csrf" , "CSRFProtect" )
2021logger = logging .getLogger (__name__ )
2122
23+ F = TypeVar ('F' , bound = Callable [..., object ])
2224
23- def generate_csrf (secret_key = None , token_key = None ):
25+
26+ def generate_csrf (secret_key : Optional [str ] = None , token_key : Optional [str ] = None ) -> str :
2427 """Generate a CSRF token. The token is cached for a request, so multiple
2528 calls to this function will generate the same token.
2629
@@ -63,7 +66,7 @@ def generate_csrf(secret_key=None, token_key=None):
6366 return g .get (field_name )
6467
6568
66- def validate_csrf (data , secret_key = None , time_limit = None , token_key = None ):
69+ def validate_csrf (data : str , secret_key : Optional [ str ] = None , time_limit : Optional [ int ] = None , token_key : Optional [ str ] = None ) -> None :
6770 """Check if the given data is a valid CSRF token. This compares the given
6871 signed token to the one stored in the session.
6972
@@ -116,8 +119,8 @@ def validate_csrf(data, secret_key=None, time_limit=None, token_key=None):
116119
117120
118121def _get_config (
119- value , config_name , default = None , required = True , message = "CSRF is not configured."
120- ):
122+ value : Any , config_name : str , default : Any = None , required : bool = True , message : str = "CSRF is not configured."
123+ ) -> Any :
121124 """Find config value based on provided value, Flask config, and default
122125 value.
123126
@@ -139,16 +142,16 @@ def _get_config(
139142
140143
141144class _FlaskFormCSRF (CSRF ):
142- def setup_form (self , form ) :
145+ def setup_form (self , form : Any ) -> None :
143146 self .meta = form .meta
144147 return super ().setup_form (form )
145148
146- def generate_csrf_token (self , csrf_token_field ) :
149+ def generate_csrf_token (self , csrf_token_field : Any ) -> str :
147150 return generate_csrf (
148151 secret_key = self .meta .csrf_secret , token_key = self .meta .csrf_field_name
149152 )
150153
151- def validate_csrf_token (self , form , field ) :
154+ def validate_csrf_token (self , form : Any , field : Any ) -> None :
152155 if g .get ("csrf_valid" , False ):
153156 # already validated by CSRFProtect
154157 return
@@ -180,14 +183,14 @@ class CSRFProtect:
180183 See the :ref:`csrf` documentation.
181184 """
182185
183- def __init__ (self , app = None ):
184- self ._exempt_views = set ()
185- self ._exempt_blueprints = set ()
186+ def __init__ (self , app : Optional [ Flask ] = None ) -> None :
187+ self ._exempt_views : Set [ str ] = set ()
188+ self ._exempt_blueprints : Set [ Blueprint ] = set ()
186189
187190 if app :
188191 self .init_app (app )
189192
190- def init_app (self , app ) :
193+ def init_app (self , app : Flask ) -> None :
191194 app .extensions ["csrf" ] = self
192195
193196 app .config .setdefault ("WTF_CSRF_ENABLED" , True )
@@ -204,7 +207,7 @@ def init_app(self, app):
204207 app .context_processor (lambda : {"csrf_token" : generate_csrf })
205208
206209 @app .before_request
207- def csrf_protect ():
210+ def csrf_protect () -> None :
208211 if not app .config ["WTF_CSRF_ENABLED" ]:
209212 return
210213
@@ -253,7 +256,7 @@ def _get_csrf_token(self):
253256
254257 return None
255258
256- def protect (self ):
259+ def protect (self ) -> None :
257260 if request .method not in current_app .config ["WTF_CSRF_METHODS" ]:
258261 return
259262
@@ -274,7 +277,7 @@ def protect(self):
274277
275278 g .csrf_valid = True # mark this request as CSRF valid
276279
277- def exempt (self , view ) :
280+ def exempt (self , view : Union [ F , Blueprint , str ]) -> Union [ F , Blueprint , str ] :
278281 """Mark a view or blueprint to be excluded from CSRF protection.
279282
280283 ::
@@ -303,7 +306,7 @@ def some_view():
303306 self ._exempt_views .add (view_location )
304307 return view
305308
306- def _error_response (self , reason ) :
309+ def _error_response (self , reason : str ) -> None :
307310 raise CSRFError (reason )
308311
309312
@@ -318,7 +321,7 @@ class CSRFError(BadRequest):
318321 description = "CSRF validation failed."
319322
320323
321- def same_origin (current_uri , compare_uri ) :
324+ def same_origin (current_uri : str , compare_uri : str ) -> bool :
322325 current = urlparse (current_uri )
323326 compare = urlparse (compare_uri )
324327
0 commit comments