1+ from __future__ import annotations
2+
13import binascii
24import json
35import warnings
4- from collections .abc import Mapping
5- from typing import Any , Dict , List , Optional , Type
6+ from typing import Any , Type
67
78from .algorithms import (
89 Algorithm ,
2324class PyJWS :
2425 header_typ = "JWT"
2526
26- def __init__ (self , algorithms = None , options = None ):
27+ def __init__ (self , algorithms = None , options = None ) -> None :
2728 self ._algorithms = get_default_algorithms ()
2829 self ._valid_algs = (
2930 set (algorithms ) if algorithms is not None else set (self ._algorithms )
@@ -39,10 +40,10 @@ def __init__(self, algorithms=None, options=None):
3940 self .options = {** self ._get_default_options (), ** options }
4041
4142 @staticmethod
42- def _get_default_options ():
43+ def _get_default_options () -> dict [ str , bool ] :
4344 return {"verify_signature" : True }
4445
45- def register_algorithm (self , alg_id , alg_obj ) :
46+ def register_algorithm (self , alg_id : str , alg_obj : Algorithm ) -> None :
4647 """
4748 Registers a new Algorithm for use when creating and verifying tokens.
4849 """
@@ -55,7 +56,7 @@ def register_algorithm(self, alg_id, alg_obj):
5556 self ._algorithms [alg_id ] = alg_obj
5657 self ._valid_algs .add (alg_id )
5758
58- def unregister_algorithm (self , alg_id ) :
59+ def unregister_algorithm (self , alg_id : str ) -> None :
5960 """
6061 Unregisters an Algorithm for use when creating and verifying tokens
6162 Throws KeyError if algorithm is not registered.
@@ -69,7 +70,7 @@ def unregister_algorithm(self, alg_id):
6970 del self ._algorithms [alg_id ]
7071 self ._valid_algs .remove (alg_id )
7172
72- def get_algorithms (self ):
73+ def get_algorithms (self ) -> list [ str ] :
7374 """
7475 Returns a list of supported values for the 'alg' parameter.
7576 """
@@ -96,9 +97,9 @@ def encode(
9697 self ,
9798 payload : bytes ,
9899 key : str ,
99- algorithm : Optional [ str ] = "HS256" ,
100- headers : Optional [ Dict [ str , Any ]] = None ,
101- json_encoder : Optional [ Type [json .JSONEncoder ]] = None ,
100+ algorithm : str | None = "HS256" ,
101+ headers : dict [ str , Any ] | None = None ,
102+ json_encoder : Type [json .JSONEncoder ] | None = None ,
102103 is_payload_detached : bool = False ,
103104 ) -> str :
104105 segments = []
@@ -117,7 +118,7 @@ def encode(
117118 is_payload_detached = True
118119
119120 # Header
120- header = {"typ" : self .header_typ , "alg" : algorithm_ } # type: Dict[str, Any]
121+ header : dict [ str , Any ] = {"typ" : self .header_typ , "alg" : algorithm_ }
121122
122123 if headers :
123124 self ._validate_headers (headers )
@@ -165,11 +166,11 @@ def decode_complete(
165166 self ,
166167 jwt : str ,
167168 key : str = "" ,
168- algorithms : Optional [ List [ str ]] = None ,
169- options : Optional [ Dict [ str , Any ]] = None ,
170- detached_payload : Optional [ bytes ] = None ,
169+ algorithms : list [ str ] | None = None ,
170+ options : dict [ str , Any ] | None = None ,
171+ detached_payload : bytes | None = None ,
171172 ** kwargs ,
172- ) -> Dict [str , Any ]:
173+ ) -> dict [str , Any ]:
173174 if kwargs :
174175 warnings .warn (
175176 "passing additional kwargs to decode_complete() is deprecated "
@@ -210,9 +211,9 @@ def decode(
210211 self ,
211212 jwt : str ,
212213 key : str = "" ,
213- algorithms : Optional [ List [ str ]] = None ,
214- options : Optional [ Dict [ str , Any ]] = None ,
215- detached_payload : Optional [ bytes ] = None ,
214+ algorithms : list [ str ] | None = None ,
215+ options : dict [ str , Any ] | None = None ,
216+ detached_payload : bytes | None = None ,
216217 ** kwargs ,
217218 ) -> str :
218219 if kwargs :
@@ -227,7 +228,7 @@ def decode(
227228 )
228229 return decoded ["payload" ]
229230
230- def get_unverified_header (self , jwt ) :
231+ def get_unverified_header (self , jwt : str | bytes ) -> dict :
231232 """Returns back the JWT header parameters as a dict()
232233
233234 Note: The signature is not verified so the header parameters
@@ -238,7 +239,7 @@ def get_unverified_header(self, jwt):
238239
239240 return headers
240241
241- def _load (self , jwt ) :
242+ def _load (self , jwt : str | bytes ) -> tuple [ bytes , bytes , dict , bytes ] :
242243 if isinstance (jwt , str ):
243244 jwt = jwt .encode ("utf-8" )
244245
@@ -261,7 +262,7 @@ def _load(self, jwt):
261262 except ValueError as e :
262263 raise DecodeError (f"Invalid header string: { e } " ) from e
263264
264- if not isinstance (header , Mapping ):
265+ if not isinstance (header , dict ):
265266 raise DecodeError ("Invalid header string: must be a json object" )
266267
267268 try :
@@ -278,16 +279,16 @@ def _load(self, jwt):
278279
279280 def _verify_signature (
280281 self ,
281- signing_input ,
282- header ,
283- signature ,
284- key = "" ,
285- algorithms = None ,
286- ):
282+ signing_input : bytes ,
283+ header : dict ,
284+ signature : bytes ,
285+ key : str = "" ,
286+ algorithms : list [ str ] | None = None ,
287+ ) -> None :
287288
288289 alg = header .get ("alg" )
289290
290- if algorithms is not None and alg not in algorithms :
291+ if not alg or ( algorithms is not None and alg not in algorithms ) :
291292 raise InvalidAlgorithmError ("The specified alg value is not allowed" )
292293
293294 try :
@@ -299,11 +300,11 @@ def _verify_signature(
299300 if not alg_obj .verify (signing_input , key , signature ):
300301 raise InvalidSignatureError ("Signature verification failed" )
301302
302- def _validate_headers (self , headers ) :
303+ def _validate_headers (self , headers : dict [ str , Any ]) -> None :
303304 if "kid" in headers :
304305 self ._validate_kid (headers ["kid" ])
305306
306- def _validate_kid (self , kid ) :
307+ def _validate_kid (self , kid : str ) -> None :
307308 if not isinstance (kid , str ):
308309 raise InvalidTokenError ("Key ID header parameter must be a string" )
309310
0 commit comments