Skip to content

Commit d34228a

Browse files
committed
Add new API
1 parent 9baab51 commit d34228a

File tree

1 file changed

+106
-95
lines changed

1 file changed

+106
-95
lines changed

jwt_rsa/token.py

Lines changed: 106 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import time
2+
from dataclasses import dataclass, field
23
from datetime import datetime, timedelta
34
from operator import add, sub
45
from typing import (
5-
TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, TypeVar, Union,
6+
TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, TypeVar, Union, overload,
67
)
78

89
from jwt import PyJWT
@@ -19,103 +20,113 @@
1920

2021
R = TypeVar("R")
2122
DAY = 86400
22-
23-
24-
class JWT:
25-
__slots__ = (
26-
"__private_key", "__public_key", "__jwt",
27-
"__expires", "__nbf_delta", "__algorithm",
28-
"__algorithms",
29-
)
30-
31-
DEFAULT_EXPIRATION = timedelta(days=31).total_seconds()
32-
NBF_DELTA = 20
33-
ALGORITHMS = tuple(AlgorithmType.__args__)
23+
DEFAULT_EXPIRATION = timedelta(days=31).total_seconds()
24+
NBF_DELTA = 20
25+
ALGORITHMS = tuple(AlgorithmType.__args__)
26+
27+
28+
def date_to_timestamp(
29+
value: DateType,
30+
default: Callable[[], R],
31+
timedelta_func: Callable[[float, float], int] = add,
32+
) -> Union[int, float, R]:
33+
if isinstance(value, timedelta):
34+
return timedelta_func(time.time(), value.total_seconds())
35+
elif isinstance(value, datetime):
36+
return value.timestamp()
37+
elif isinstance(value, (int, float)):
38+
return value
39+
elif value is Ellipsis:
40+
return default()
41+
42+
raise ValueError(type(value))
43+
44+
45+
@dataclass(frozen=True, init=False)
46+
class JWTDecoder:
47+
jwt: PyJWT = field(repr=False, compare=False)
48+
public_key: RSAPublicKey = field(repr=False, compare=False)
49+
expires: Union[int, float]
50+
nbf_delta: Union[int, float]
51+
algorithm: AlgorithmType
52+
algorithms: Sequence[AlgorithmType]
3453

3554
def __init__(
3655
self,
37-
key: Optional[Union[RSAPrivateKey, RSAPublicKey]],
38-
*, expires: Optional[int] = None,
39-
nbf_delta: Optional[int] = None,
56+
key: RSAPublicKey,
57+
*, options: dict[str, Any] | None = None,
58+
expires: int | float = DEFAULT_EXPIRATION,
59+
nbf_delta: int | float = NBF_DELTA,
4060
algorithm: AlgorithmType = "RS512",
4161
algorithms: Sequence[AlgorithmType] = ALGORITHMS,
42-
options: Optional[Dict[str, Any]] = None,
4362
):
44-
self.__public_key: RSAPublicKey
45-
self.__private_key: Optional[RSAPrivateKey]
46-
47-
if isinstance(key, RSAPrivateKey):
48-
self.__public_key = key.public_key()
49-
self.__private_key = key
50-
elif isinstance(key, RSAPublicKey):
51-
self.__public_key = key
52-
self.__private_key = None
53-
else:
54-
raise ValueError("You must provide either a public or private key")
55-
56-
self.__jwt = PyJWT(options)
57-
self.__expires = expires or self.DEFAULT_EXPIRATION
58-
self.__nbf_delta = nbf_delta or self.NBF_DELTA
59-
self.__algorithm = algorithm
60-
self.__algorithms = list(algorithms)
61-
62-
@staticmethod
63-
def _date_to_timestamp(
64-
value: DateType,
65-
default: Callable[[], R],
66-
timedelta_func: Callable[[float, float], int] = add,
67-
) -> Union[int, float, R]:
68-
if isinstance(value, timedelta):
69-
return timedelta_func(time.time(), value.total_seconds())
70-
elif isinstance(value, datetime):
71-
return value.timestamp()
72-
elif isinstance(value, (int, float)):
73-
return value
74-
elif value is Ellipsis:
75-
return default()
76-
77-
raise ValueError(type(value))
78-
79-
def encode(
80-
self,
81-
expired: DateType = ...,
82-
nbf: DateType = ...,
83-
**claims: Any,
84-
) -> str:
85-
if not self.__private_key:
86-
raise RuntimeError("Can't encode without private key")
87-
88-
claims.update(
89-
dict(
90-
exp=int(
91-
self._date_to_timestamp(
92-
expired,
93-
lambda: time.time() + self.__expires,
94-
),
95-
),
96-
nbf=int(
97-
self._date_to_timestamp(
98-
nbf,
99-
lambda: time.time() - self.__nbf_delta,
100-
timedelta_func=sub,
101-
),
102-
),
103-
),
104-
)
105-
106-
return self.__jwt.encode(
107-
claims,
108-
self.__private_key,
109-
algorithm=self.__algorithm,
110-
)
111-
112-
def decode(
113-
self, token: str, verify: bool = True, **kwargs: Any,
114-
) -> Dict[str, Any]:
115-
return self.__jwt.decode(
116-
token,
117-
key=self.__public_key,
118-
verify=verify,
119-
algorithms=self.__algorithms,
120-
**kwargs,
121-
)
63+
super().__setattr__('public_key', key)
64+
super().__setattr__('jwt', PyJWT(options))
65+
super().__setattr__('expires', expires)
66+
super().__setattr__('nbf_delta', nbf_delta)
67+
super().__setattr__('algorithm', algorithm)
68+
super().__setattr__('algorithms', algorithms)
69+
70+
def decode(self, token: str, verify: bool = True, **kwargs: Any) -> Dict[str, Any]:
71+
return self.jwt.decode(token, key=self.public_key, verify=verify, algorithms=self.algorithms, **kwargs)
72+
73+
74+
@dataclass(frozen=True, init=False)
75+
class JWTSigner(JWTDecoder):
76+
private_key: RSAPrivateKey = field(repr=False, compare=False)
77+
78+
def __init__(self, key: RSAPrivateKey, *, options: Optional[Dict[str, Any]] = None, **kwargs: Any):
79+
super(JWTDecoder, self).__setattr__('private_key', key)
80+
super().__init__(key.public_key(), options=options, **kwargs)
81+
82+
def encode(self, expired: DateType = ..., nbf: DateType = ..., **claims: Any) -> str:
83+
claims.setdefault('exp', int(date_to_timestamp(expired, lambda: time.time() + self.expires)))
84+
claims.setdefault('nbf', int(date_to_timestamp(nbf, lambda: time.time() - self.nbf_delta, timedelta_func=sub)))
85+
return self.jwt.encode(claims, self.private_key, algorithm=self.algorithm)
86+
87+
88+
@overload
89+
def JWT(
90+
key: RSAPrivateKey, *,
91+
options: dict[str, Any] | None = None,
92+
expires: int | float = DEFAULT_EXPIRATION,
93+
nbf_delta: int | float = NBF_DELTA,
94+
algorithm: AlgorithmType = "RS512",
95+
algorithms: Sequence[AlgorithmType] = ALGORITHMS,
96+
) -> JWTSigner: ...
97+
98+
99+
@overload
100+
def JWT( # type: ignore[overload-cannot-match]
101+
key: RSAPublicKey, *,
102+
options: dict[str, Any] | None = None,
103+
expires: int | float = DEFAULT_EXPIRATION,
104+
nbf_delta: int | float = NBF_DELTA,
105+
algorithm: AlgorithmType = "RS512",
106+
algorithms: Sequence[AlgorithmType] = ALGORITHMS,
107+
) -> JWTDecoder: ...
108+
109+
110+
def JWT(
111+
key: Union[RSAPrivateKey, RSAPublicKey],
112+
*,
113+
options: dict[str, Any] | None = None,
114+
expires: int | float = DEFAULT_EXPIRATION,
115+
nbf_delta: int | float = NBF_DELTA,
116+
algorithm: AlgorithmType = "RS512",
117+
algorithms: Sequence[AlgorithmType] = ALGORITHMS,
118+
) -> Union[JWTSigner, JWTDecoder]:
119+
kwargs = dict(
120+
expires=expires,
121+
nbf_delta=nbf_delta,
122+
algorithm=algorithm,
123+
algorithms=algorithms,
124+
options=options,
125+
)
126+
127+
if isinstance(key, RSAPrivateKey):
128+
return JWTSigner(key, **kwargs)
129+
elif isinstance(key, RSAPublicKey):
130+
return JWTDecoder(key, **kwargs)
131+
else:
132+
raise TypeError(f"Invalid key type: {type(key)}")

0 commit comments

Comments
 (0)