Skip to content

Commit 5ecd789

Browse files
authored
Merge pull request #4 from mosquito/bugfixes
Refactor JWT Class into JWTDecoder and JWTSigner, Improve Type Safety and Test Coverage
2 parents 9baab51 + a213617 commit 5ecd789

File tree

11 files changed

+213
-174
lines changed

11 files changed

+213
-174
lines changed

.github/workflows/pythonpackage.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,32 @@ on:
77
branches: [ master ]
88

99
jobs:
10+
mypy:
11+
runs-on: ubuntu-latest
12+
strategy:
13+
fail-fast: false
14+
15+
steps:
16+
- uses: actions/checkout@v2
17+
18+
- name: Setup python3.10
19+
uses: actions/setup-python@v2
20+
with:
21+
python-version: "3.10"
22+
23+
- name: Install poetry
24+
run: python -m pip install poetry
25+
26+
- name: Install dependencies
27+
run: poetry install
28+
env:
29+
FORCE_COLOR: yes
30+
31+
- name: Run mypy
32+
run: poetry run mypy jwt_rsa
33+
env:
34+
FORCE_COLOR: yes
35+
1036
tests:
1137
runs-on: ubuntu-latest
1238
strategy:

jwt_rsa/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
RSAJWKPrivateKey, RSAJWKPublicKey, generate_rsa, load_private_key,
33
load_public_key, rsa_to_jwk,
44
)
5-
from .token import JWT
5+
from .token import JWT, JWTDecoder, JWTSigner
66
from .types import RSAPrivateKey, RSAPublicKey
77

88

99
__all__ = (
1010
"JWT",
11+
"JWTDecoder",
12+
"JWTSigner",
1113
"RSAJWKPrivateKey",
1214
"RSAJWKPublicKey",
1315
"RSAPrivateKey",

jwt_rsa/cli.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
from argparse import ArgumentParser
33
from pathlib import Path
44

5-
from jwt_rsa.types import AlgorithmType
6-
75
from . import convert, issue, key_tester, keygen, pubkey, verify
6+
from .token import ALGORITHMS
87

98

109
parser = ArgumentParser()
@@ -20,7 +19,7 @@
2019
"--kid", dest="kid", type=str, default="", help="Key ID, will be generated if missing",
2120
)
2221
keygen_parser.add_argument(
23-
"-a", "--algorithm", choices=AlgorithmType.__args__,
22+
"-a", "--algorithm", choices=ALGORITHMS,
2423
help="Key ID, will be generated if missing", default="RS512",
2524
)
2625
keygen_parser.add_argument("-u", "--use", dest="use", type=str, default="sig", choices=["sig", "enc"])

jwt_rsa/issue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969

7070

7171
def main(arguments: SimpleNamespace) -> None:
72-
jwt = JWT(private_key=load_private_key(arguments.private_key))
72+
jwt = JWT(load_private_key(arguments.private_key))
7373

7474
whoami = pwd.getpwuid(os.getuid())
7575

jwt_rsa/rsa.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import base64
22
import json
33
from pathlib import Path
4-
from typing import NamedTuple, Optional, TypedDict, Union, overload
4+
from typing import NamedTuple, Optional, TypedDict, overload
55

66
from cryptography.hazmat.backends import default_backend
77

@@ -11,6 +11,11 @@
1111

1212

1313
class KeyPair(NamedTuple):
14+
private: RSAPrivateKey
15+
public: RSAPublicKey
16+
17+
18+
class JWKKeyPair(NamedTuple):
1419
private: Optional[RSAPrivateKey]
1520
public: RSAPublicKey
1621

@@ -80,8 +85,8 @@ def load_jwk_private_key(jwk: RSAJWKPrivateKey) -> RSAPrivateKey:
8085
return private_numbers.private_key(default_backend())
8186

8287

83-
def load_jwk(jwk: Union[RSAJWKPublicKey, RSAJWKPrivateKey, str]) -> KeyPair:
84-
jwk_dict: Union[RSAJWKPublicKey, RSAJWKPrivateKey]
88+
def load_jwk(jwk: RSAJWKPublicKey | RSAJWKPrivateKey | str) -> JWKKeyPair:
89+
jwk_dict: RSAJWKPublicKey | RSAJWKPrivateKey
8590

8691
if isinstance(jwk, str):
8792
jwk_dict = json.loads(jwk)
@@ -92,10 +97,10 @@ def load_jwk(jwk: Union[RSAJWKPublicKey, RSAJWKPrivateKey, str]) -> KeyPair:
9297
private_key = load_jwk_private_key(jwk_dict) # type: ignore
9398
public_key = private_key.public_key()
9499
else: # Public key
95-
public_key = load_jwk_public_key(jwk_dict) # type: ignore
100+
public_key = load_jwk_public_key(jwk_dict)
96101
private_key = None
97102

98-
return KeyPair(private=private_key, public=public_key)
103+
return JWKKeyPair(private=private_key, public=public_key)
99104

100105

101106
def int_to_base64url(value: int) -> str:
@@ -106,24 +111,24 @@ def int_to_base64url(value: int) -> str:
106111

107112
@overload
108113
def rsa_to_jwk(
109-
key: RSAPublicKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig"
114+
key: RSAPublicKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig",
110115
) -> RSAJWKPublicKey: ...
111116

112117

113118
@overload
114-
def rsa_to_jwk( # type: ignore[overload-cannot-match]
119+
def rsa_to_jwk(
115120
key: RSAPrivateKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig",
116121
) -> RSAJWKPrivateKey: ...
117122

118123

119124
def rsa_to_jwk(
120-
key: Union[RSAPrivateKey, RSAPublicKey],
125+
key: RSAPrivateKey | RSAPublicKey,
121126
*,
122127
kid: str = "",
123128
alg: AlgorithmType = "RS256",
124129
use: str = "sig",
125130
kty: str = "RSA",
126-
) -> Union[RSAJWKPublicKey, RSAJWKPrivateKey]:
131+
) -> RSAJWKPublicKey | RSAJWKPrivateKey:
127132
if isinstance(key, RSAPublicKey):
128133
public_numbers = key.public_numbers()
129134
private_numbers = None
@@ -161,12 +166,14 @@ def rsa_to_jwk(
161166
)
162167

163168

164-
def load_private_key(data: Union[str, RSAJWKPrivateKey, Path]) -> RSAPrivateKey:
169+
def load_private_key(data: str | RSAJWKPrivateKey | Path) -> RSAPrivateKey:
165170
if isinstance(data, Path):
166171
data = data.read_text()
167172
if isinstance(data, str):
168173
if data.startswith("-----BEGIN "):
169-
return serialization.load_pem_private_key(data.encode(), None, default_backend())
174+
result = serialization.load_pem_private_key(data.encode(), None, default_backend())
175+
assert isinstance(result, RSAPrivateKey)
176+
return result
170177
if data.strip().startswith("{"):
171178
return load_jwk_private_key(json.loads(data))
172179
if isinstance(data, dict):
@@ -177,12 +184,14 @@ def load_private_key(data: Union[str, RSAJWKPrivateKey, Path]) -> RSAPrivateKey:
177184
return key
178185

179186

180-
def load_public_key(data: Union[str, RSAJWKPublicKey, Path]) -> RSAPublicKey:
187+
def load_public_key(data: str | RSAJWKPublicKey | Path) -> RSAPublicKey:
181188
if isinstance(data, Path):
182189
data = data.read_text()
183190
if isinstance(data, str):
184191
if data.startswith("-----BEGIN "):
185-
return serialization.load_pem_public_key(data.encode(), default_backend())
192+
result = serialization.load_pem_public_key(data.encode(), default_backend())
193+
assert isinstance(result, RSAPublicKey)
194+
return result
186195
if data.strip().startswith("{"):
187196
return load_jwk_public_key(json.loads(data))
188197
if isinstance(data, dict):

0 commit comments

Comments
 (0)