|
12 | 12 |
|
13 | 13 | class JWT: |
14 | 14 | __slots__ = ('__private_key', '__public_key', '__jwt', |
15 | | - '__expires', '__nbf_delta') |
| 15 | + '__expires', '__nbf_delta', '__algorithm') |
16 | 16 |
|
17 | 17 | DEFAULT_EXPIRATION = 86400 * 30 # one month |
18 | 18 | NBF_DELTA = 20 |
| 19 | + ALGORITHMS = tuple({ |
| 20 | + 'RS256', 'RS384', 'RS512', 'ES256', 'ES384', |
| 21 | + 'ES521', 'ES512', 'PS256', 'PS384', 'PS512' |
| 22 | + }) |
19 | 23 |
|
20 | 24 | def __init__(self, private_key: RSAPrivateKey=None, |
21 | 25 | public_key: RSAPublicKey=None, expires=None, |
22 | | - nbf_delta=None): |
| 26 | + nbf_delta=None, algorithm="RS512"): |
23 | 27 |
|
24 | 28 | self.__private_key = private_key |
25 | 29 | self.__public_key = public_key |
26 | | - self.__jwt = PyJWT(algorithms={'RS512'}) |
| 30 | + self.__jwt = PyJWT(algorithms=self.ALGORITHMS) |
27 | 31 | self.__expires = expires or self.DEFAULT_EXPIRATION |
28 | 32 | self.__nbf_delta = nbf_delta or self.NBF_DELTA |
| 33 | + self.__algorithm = algorithm |
29 | 34 |
|
30 | 35 | def _date_to_timestamp(self, value, default, timedelta_func=add): |
31 | 36 | if isinstance(value, timedelta): |
@@ -64,18 +69,18 @@ def encode(self, expired: DateType=..., nbf: DateType=..., **claims) -> str: |
64 | 69 | return self.__jwt.encode( |
65 | 70 | claims, |
66 | 71 | self.__private_key, |
67 | | - algorithm='RS512', |
| 72 | + algorithm=self.__algorithm, |
68 | 73 | ).decode() |
69 | 74 |
|
70 | 75 | def decode(self, token: str, verify=True, **kwargs) -> dict: |
71 | 76 | if not self.__public_key: |
72 | 77 | raise RuntimeError("Can't decode without public key") |
73 | 78 |
|
74 | 79 | return self.__jwt.decode( |
75 | | - token.encode(), |
| 80 | + token, |
76 | 81 | key=self.__public_key, |
77 | 82 | verify=verify, |
78 | | - algorithms={'RS512'}, |
| 83 | + algorithms=self.ALGORITHMS, |
79 | 84 | **kwargs |
80 | 85 | ) |
81 | 86 |
|
|
0 commit comments