Skip to content

Commit db098e2

Browse files
committed
[fix] add optional algorithm argument
1 parent 2bd643e commit db098e2

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

jwt_rsa/token.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,25 @@
1212

1313
class JWT:
1414
__slots__ = ('__private_key', '__public_key', '__jwt',
15-
'__expires', '__nbf_delta')
15+
'__expires', '__nbf_delta', '__algorithm')
1616

1717
DEFAULT_EXPIRATION = 86400 * 30 # one month
1818
NBF_DELTA = 20
19+
ALGORITHMS = tuple({
20+
'RS256', 'RS384', 'RS512', 'ES256', 'ES384',
21+
'ES521', 'ES512', 'PS256', 'PS384', 'PS512'
22+
})
1923

2024
def __init__(self, private_key: RSAPrivateKey=None,
2125
public_key: RSAPublicKey=None, expires=None,
22-
nbf_delta=None):
26+
nbf_delta=None, algorithm="RS512"):
2327

2428
self.__private_key = private_key
2529
self.__public_key = public_key
26-
self.__jwt = PyJWT(algorithms={'RS512'})
30+
self.__jwt = PyJWT(algorithms=self.ALGORITHMS)
2731
self.__expires = expires or self.DEFAULT_EXPIRATION
2832
self.__nbf_delta = nbf_delta or self.NBF_DELTA
33+
self.__algorithm = algorithm
2934

3035
def _date_to_timestamp(self, value, default, timedelta_func=add):
3136
if isinstance(value, timedelta):
@@ -64,18 +69,18 @@ def encode(self, expired: DateType=..., nbf: DateType=..., **claims) -> str:
6469
return self.__jwt.encode(
6570
claims,
6671
self.__private_key,
67-
algorithm='RS512',
72+
algorithm=self.__algorithm,
6873
).decode()
6974

7075
def decode(self, token: str, verify=True, **kwargs) -> dict:
7176
if not self.__public_key:
7277
raise RuntimeError("Can't decode without public key")
7378

7479
return self.__jwt.decode(
75-
token.encode(),
80+
token,
7681
key=self.__public_key,
7782
verify=verify,
78-
algorithms={'RS512'},
83+
algorithms=self.ALGORITHMS,
7984
**kwargs
8085
)
8186

0 commit comments

Comments
 (0)