11import time
2- from datetime import timedelta , datetime
3- from operator import sub , add
4- from typing import Union
2+ from datetime import datetime , timedelta
3+ from operator import add , sub
4+ from typing import Any , Callable , Dict , Optional , TYPE_CHECKING , TypeVar , \
5+ Union
56
67from jwt import PyJWT
8+
79from jwt_rsa .rsa import RSAPrivateKey , RSAPublicKey
810
11+ if TYPE_CHECKING :
12+ # pylama:ignore=E0602
13+ DateType = Union [timedelta , datetime , float , int , ellipsis ]
14+ else :
15+ DateType = Union [timedelta , datetime , float , int , type (Ellipsis )]
916
10- DateType = Union [ timedelta , datetime , float , int , type ( Ellipsis )]
17+ R = TypeVar ( 'R' )
1118
1219
1320class JWT :
@@ -21,9 +28,14 @@ class JWT:
2128 'ES521' , 'ES512' , 'PS256' , 'PS384' , 'PS512'
2229 })
2330
24- def __init__ (self , private_key : RSAPrivateKey = None ,
25- public_key : RSAPublicKey = None , expires = None ,
26- nbf_delta = None , algorithm = "RS512" ):
31+ def __init__ (
32+ self ,
33+ private_key : Optional [RSAPrivateKey ] = None ,
34+ public_key : Optional [RSAPublicKey ] = None ,
35+ expires : Optional [int ] = None ,
36+ nbf_delta : Optional [int ] = None ,
37+ algorithm : str = "RS512"
38+ ):
2739
2840 self .__private_key = private_key
2941 self .__public_key = public_key
@@ -32,7 +44,12 @@ def __init__(self, private_key: RSAPrivateKey=None,
3244 self .__nbf_delta = nbf_delta or self .NBF_DELTA
3345 self .__algorithm = algorithm
3446
35- def _date_to_timestamp (self , value , default , timedelta_func = add ):
47+ def _date_to_timestamp (
48+ self ,
49+ value : DateType ,
50+ default : Callable [[], R ],
51+ timedelta_func : Callable [[float , float ], int ] = add
52+ ) -> Union [int , float , R ]:
3653 if isinstance (value , timedelta ):
3754 return timedelta_func (time .time (), value .total_seconds ())
3855 elif isinstance (value , datetime ):
@@ -44,7 +61,12 @@ def _date_to_timestamp(self, value, default, timedelta_func=add):
4461
4562 raise ValueError (type (value ))
4663
47- def encode (self , expired : DateType = ..., nbf : DateType = ..., ** claims ) -> str :
64+ def encode (
65+ self ,
66+ expired : DateType = ...,
67+ nbf : DateType = ...,
68+ ** claims : int
69+ ) -> str :
4870 if not self .__private_key :
4971 raise RuntimeError ("Can't encode without private key" )
5072
@@ -72,7 +94,9 @@ def encode(self, expired: DateType=..., nbf: DateType=..., **claims) -> str:
7294 algorithm = self .__algorithm ,
7395 ).decode ()
7496
75- def decode (self , token : str , verify = True , ** kwargs ) -> dict :
97+ def decode (
98+ self , token : str , verify : bool = True , ** kwargs : Any
99+ ) -> Dict [str , Any ]:
76100 if not self .__public_key :
77101 raise RuntimeError ("Can't decode without public key" )
78102
0 commit comments