|
1 | 1 | import time |
| 2 | +from dataclasses import dataclass, field |
2 | 3 | from datetime import datetime, timedelta |
3 | 4 | from operator import add, sub |
4 | 5 | from typing import ( |
5 | | - TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, TypeVar, Union, |
| 6 | + TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, TypeVar, Union, overload, |
6 | 7 | ) |
7 | 8 |
|
8 | 9 | from jwt import PyJWT |
|
19 | 20 |
|
20 | 21 | R = TypeVar("R") |
21 | 22 | DAY = 86400 |
22 | | - |
23 | | - |
24 | | -class JWT: |
25 | | - __slots__ = ( |
26 | | - "__private_key", "__public_key", "__jwt", |
27 | | - "__expires", "__nbf_delta", "__algorithm", |
28 | | - "__algorithms", |
29 | | - ) |
30 | | - |
31 | | - DEFAULT_EXPIRATION = timedelta(days=31).total_seconds() |
32 | | - NBF_DELTA = 20 |
33 | | - ALGORITHMS = tuple(AlgorithmType.__args__) |
| 23 | +DEFAULT_EXPIRATION = timedelta(days=31).total_seconds() |
| 24 | +NBF_DELTA = 20 |
| 25 | +ALGORITHMS = tuple(AlgorithmType.__args__) |
| 26 | + |
| 27 | + |
| 28 | +def date_to_timestamp( |
| 29 | + value: DateType, |
| 30 | + default: Callable[[], R], |
| 31 | + timedelta_func: Callable[[float, float], int] = add, |
| 32 | +) -> Union[int, float, R]: |
| 33 | + if isinstance(value, timedelta): |
| 34 | + return timedelta_func(time.time(), value.total_seconds()) |
| 35 | + elif isinstance(value, datetime): |
| 36 | + return value.timestamp() |
| 37 | + elif isinstance(value, (int, float)): |
| 38 | + return value |
| 39 | + elif value is Ellipsis: |
| 40 | + return default() |
| 41 | + |
| 42 | + raise ValueError(type(value)) |
| 43 | + |
| 44 | + |
| 45 | +@dataclass(frozen=True, init=False) |
| 46 | +class JWTDecoder: |
| 47 | + jwt: PyJWT = field(repr=False, compare=False) |
| 48 | + public_key: RSAPublicKey = field(repr=False, compare=False) |
| 49 | + expires: Union[int, float] |
| 50 | + nbf_delta: Union[int, float] |
| 51 | + algorithm: AlgorithmType |
| 52 | + algorithms: Sequence[AlgorithmType] |
34 | 53 |
|
35 | 54 | def __init__( |
36 | 55 | self, |
37 | | - key: Optional[Union[RSAPrivateKey, RSAPublicKey]], |
38 | | - *, expires: Optional[int] = None, |
39 | | - nbf_delta: Optional[int] = None, |
| 56 | + key: RSAPublicKey, |
| 57 | + *, options: dict[str, Any] | None = None, |
| 58 | + expires: int | float = DEFAULT_EXPIRATION, |
| 59 | + nbf_delta: int | float = NBF_DELTA, |
40 | 60 | algorithm: AlgorithmType = "RS512", |
41 | 61 | algorithms: Sequence[AlgorithmType] = ALGORITHMS, |
42 | | - options: Optional[Dict[str, Any]] = None, |
43 | 62 | ): |
44 | | - self.__public_key: RSAPublicKey |
45 | | - self.__private_key: Optional[RSAPrivateKey] |
46 | | - |
47 | | - if isinstance(key, RSAPrivateKey): |
48 | | - self.__public_key = key.public_key() |
49 | | - self.__private_key = key |
50 | | - elif isinstance(key, RSAPublicKey): |
51 | | - self.__public_key = key |
52 | | - self.__private_key = None |
53 | | - else: |
54 | | - raise ValueError("You must provide either a public or private key") |
55 | | - |
56 | | - self.__jwt = PyJWT(options) |
57 | | - self.__expires = expires or self.DEFAULT_EXPIRATION |
58 | | - self.__nbf_delta = nbf_delta or self.NBF_DELTA |
59 | | - self.__algorithm = algorithm |
60 | | - self.__algorithms = list(algorithms) |
61 | | - |
62 | | - @staticmethod |
63 | | - def _date_to_timestamp( |
64 | | - value: DateType, |
65 | | - default: Callable[[], R], |
66 | | - timedelta_func: Callable[[float, float], int] = add, |
67 | | - ) -> Union[int, float, R]: |
68 | | - if isinstance(value, timedelta): |
69 | | - return timedelta_func(time.time(), value.total_seconds()) |
70 | | - elif isinstance(value, datetime): |
71 | | - return value.timestamp() |
72 | | - elif isinstance(value, (int, float)): |
73 | | - return value |
74 | | - elif value is Ellipsis: |
75 | | - return default() |
76 | | - |
77 | | - raise ValueError(type(value)) |
78 | | - |
79 | | - def encode( |
80 | | - self, |
81 | | - expired: DateType = ..., |
82 | | - nbf: DateType = ..., |
83 | | - **claims: Any, |
84 | | - ) -> str: |
85 | | - if not self.__private_key: |
86 | | - raise RuntimeError("Can't encode without private key") |
87 | | - |
88 | | - claims.update( |
89 | | - dict( |
90 | | - exp=int( |
91 | | - self._date_to_timestamp( |
92 | | - expired, |
93 | | - lambda: time.time() + self.__expires, |
94 | | - ), |
95 | | - ), |
96 | | - nbf=int( |
97 | | - self._date_to_timestamp( |
98 | | - nbf, |
99 | | - lambda: time.time() - self.__nbf_delta, |
100 | | - timedelta_func=sub, |
101 | | - ), |
102 | | - ), |
103 | | - ), |
104 | | - ) |
105 | | - |
106 | | - return self.__jwt.encode( |
107 | | - claims, |
108 | | - self.__private_key, |
109 | | - algorithm=self.__algorithm, |
110 | | - ) |
111 | | - |
112 | | - def decode( |
113 | | - self, token: str, verify: bool = True, **kwargs: Any, |
114 | | - ) -> Dict[str, Any]: |
115 | | - return self.__jwt.decode( |
116 | | - token, |
117 | | - key=self.__public_key, |
118 | | - verify=verify, |
119 | | - algorithms=self.__algorithms, |
120 | | - **kwargs, |
121 | | - ) |
| 63 | + super().__setattr__('public_key', key) |
| 64 | + super().__setattr__('jwt', PyJWT(options)) |
| 65 | + super().__setattr__('expires', expires) |
| 66 | + super().__setattr__('nbf_delta', nbf_delta) |
| 67 | + super().__setattr__('algorithm', algorithm) |
| 68 | + super().__setattr__('algorithms', algorithms) |
| 69 | + |
| 70 | + def decode(self, token: str, verify: bool = True, **kwargs: Any) -> Dict[str, Any]: |
| 71 | + return self.jwt.decode(token, key=self.public_key, verify=verify, algorithms=self.algorithms, **kwargs) |
| 72 | + |
| 73 | + |
| 74 | +@dataclass(frozen=True, init=False) |
| 75 | +class JWTSigner(JWTDecoder): |
| 76 | + private_key: RSAPrivateKey = field(repr=False, compare=False) |
| 77 | + |
| 78 | + def __init__(self, key: RSAPrivateKey, *, options: Optional[Dict[str, Any]] = None, **kwargs: Any): |
| 79 | + super(JWTDecoder, self).__setattr__('private_key', key) |
| 80 | + super().__init__(key.public_key(), options=options, **kwargs) |
| 81 | + |
| 82 | + def encode(self, expired: DateType = ..., nbf: DateType = ..., **claims: Any) -> str: |
| 83 | + claims.setdefault('exp', int(date_to_timestamp(expired, lambda: time.time() + self.expires))) |
| 84 | + claims.setdefault('nbf', int(date_to_timestamp(nbf, lambda: time.time() - self.nbf_delta, timedelta_func=sub))) |
| 85 | + return self.jwt.encode(claims, self.private_key, algorithm=self.algorithm) |
| 86 | + |
| 87 | + |
| 88 | +@overload |
| 89 | +def JWT( |
| 90 | + key: RSAPrivateKey, *, |
| 91 | + options: dict[str, Any] | None = None, |
| 92 | + expires: int | float = DEFAULT_EXPIRATION, |
| 93 | + nbf_delta: int | float = NBF_DELTA, |
| 94 | + algorithm: AlgorithmType = "RS512", |
| 95 | + algorithms: Sequence[AlgorithmType] = ALGORITHMS, |
| 96 | +) -> JWTSigner: ... |
| 97 | + |
| 98 | + |
| 99 | +@overload |
| 100 | +def JWT( # type: ignore[overload-cannot-match] |
| 101 | + key: RSAPublicKey, *, |
| 102 | + options: dict[str, Any] | None = None, |
| 103 | + expires: int | float = DEFAULT_EXPIRATION, |
| 104 | + nbf_delta: int | float = NBF_DELTA, |
| 105 | + algorithm: AlgorithmType = "RS512", |
| 106 | + algorithms: Sequence[AlgorithmType] = ALGORITHMS, |
| 107 | +) -> JWTDecoder: ... |
| 108 | + |
| 109 | + |
| 110 | +def JWT( |
| 111 | + key: Union[RSAPrivateKey, RSAPublicKey], |
| 112 | + *, |
| 113 | + options: dict[str, Any] | None = None, |
| 114 | + expires: int | float = DEFAULT_EXPIRATION, |
| 115 | + nbf_delta: int | float = NBF_DELTA, |
| 116 | + algorithm: AlgorithmType = "RS512", |
| 117 | + algorithms: Sequence[AlgorithmType] = ALGORITHMS, |
| 118 | +) -> Union[JWTSigner, JWTDecoder]: |
| 119 | + kwargs = dict( |
| 120 | + expires=expires, |
| 121 | + nbf_delta=nbf_delta, |
| 122 | + algorithm=algorithm, |
| 123 | + algorithms=algorithms, |
| 124 | + options=options, |
| 125 | + ) |
| 126 | + |
| 127 | + if isinstance(key, RSAPrivateKey): |
| 128 | + return JWTSigner(key, **kwargs) |
| 129 | + elif isinstance(key, RSAPublicKey): |
| 130 | + return JWTDecoder(key, **kwargs) |
| 131 | + else: |
| 132 | + raise TypeError(f"Invalid key type: {type(key)}") |
0 commit comments