Skip to content

Commit b229393

Browse files
committed
Separate mypy linter rule
1 parent 49b1847 commit b229393

File tree

8 files changed

+73
-44
lines changed

8 files changed

+73
-44
lines changed

.github/workflows/pythonpackage.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,32 @@ on:
77
branches: [ master ]
88

99
jobs:
10+
mypy:
11+
runs-on: ubuntu-latest
12+
strategy:
13+
fail-fast: false
14+
15+
steps:
16+
- uses: actions/checkout@v2
17+
18+
- name: Setup python3.10
19+
uses: actions/setup-python@v2
20+
with:
21+
python-version: "3.10"
22+
23+
- name: Install poetry
24+
run: python -m pip install poetry
25+
26+
- name: Install dependencies
27+
run: poetry install
28+
env:
29+
FORCE_COLOR: yes
30+
31+
- name: Run mypy
32+
run: poetry run mypy jwt_rsa
33+
env:
34+
FORCE_COLOR: yes
35+
1036
tests:
1137
runs-on: ubuntu-latest
1238
strategy:

jwt_rsa/cli.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
from argparse import ArgumentParser
33
from pathlib import Path
44

5-
from jwt_rsa.types import AlgorithmType
6-
75
from . import convert, issue, key_tester, keygen, pubkey, verify
6+
from .token import ALGORITHMS
87

98

109
parser = ArgumentParser()
@@ -20,7 +19,7 @@
2019
"--kid", dest="kid", type=str, default="", help="Key ID, will be generated if missing",
2120
)
2221
keygen_parser.add_argument(
23-
"-a", "--algorithm", choices=AlgorithmType.__args__,
22+
"-a", "--algorithm", choices=ALGORITHMS,
2423
help="Key ID, will be generated if missing", default="RS512",
2524
)
2625
keygen_parser.add_argument("-u", "--use", dest="use", type=str, default="sig", choices=["sig", "enc"])

jwt_rsa/issue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969

7070

7171
def main(arguments: SimpleNamespace) -> None:
72-
jwt = JWT(private_key=load_private_key(arguments.private_key))
72+
jwt = JWT(load_private_key(arguments.private_key))
7373

7474
whoami = pwd.getpwuid(os.getuid())
7575

jwt_rsa/rsa.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import base64
22
import json
33
from pathlib import Path
4-
from typing import NamedTuple, Optional, TypedDict, Union, overload
4+
from typing import NamedTuple, Optional, TypedDict, overload
55

66
from cryptography.hazmat.backends import default_backend
77

@@ -11,6 +11,11 @@
1111

1212

1313
class KeyPair(NamedTuple):
14+
private: RSAPrivateKey
15+
public: RSAPublicKey
16+
17+
18+
class JWKKeyPair(NamedTuple):
1419
private: Optional[RSAPrivateKey]
1520
public: RSAPublicKey
1621

@@ -80,8 +85,8 @@ def load_jwk_private_key(jwk: RSAJWKPrivateKey) -> RSAPrivateKey:
8085
return private_numbers.private_key(default_backend())
8186

8287

83-
def load_jwk(jwk: Union[RSAJWKPublicKey, RSAJWKPrivateKey, str]) -> KeyPair:
84-
jwk_dict: Union[RSAJWKPublicKey, RSAJWKPrivateKey]
88+
def load_jwk(jwk: RSAJWKPublicKey | RSAJWKPrivateKey | str) -> JWKKeyPair:
89+
jwk_dict: RSAJWKPublicKey | RSAJWKPrivateKey
8590

8691
if isinstance(jwk, str):
8792
jwk_dict = json.loads(jwk)
@@ -92,10 +97,10 @@ def load_jwk(jwk: Union[RSAJWKPublicKey, RSAJWKPrivateKey, str]) -> KeyPair:
9297
private_key = load_jwk_private_key(jwk_dict) # type: ignore
9398
public_key = private_key.public_key()
9499
else: # Public key
95-
public_key = load_jwk_public_key(jwk_dict) # type: ignore
100+
public_key = load_jwk_public_key(jwk_dict)
96101
private_key = None
97102

98-
return KeyPair(private=private_key, public=public_key)
103+
return JWKKeyPair(private=private_key, public=public_key)
99104

100105

101106
def int_to_base64url(value: int) -> str:
@@ -111,19 +116,19 @@ def rsa_to_jwk(
111116

112117

113118
@overload
114-
def rsa_to_jwk( # type: ignore[overload-cannot-match]
119+
def rsa_to_jwk(
115120
key: RSAPrivateKey, *, kid: str = "", alg: AlgorithmType = "RS256", use: str = "sig",
116121
) -> RSAJWKPrivateKey: ...
117122

118123

119124
def rsa_to_jwk(
120-
key: Union[RSAPrivateKey, RSAPublicKey],
125+
key: RSAPrivateKey | RSAPublicKey,
121126
*,
122127
kid: str = "",
123128
alg: AlgorithmType = "RS256",
124129
use: str = "sig",
125130
kty: str = "RSA",
126-
) -> Union[RSAJWKPublicKey, RSAJWKPrivateKey]:
131+
) -> RSAJWKPublicKey | RSAJWKPrivateKey:
127132
if isinstance(key, RSAPublicKey):
128133
public_numbers = key.public_numbers()
129134
private_numbers = None
@@ -161,12 +166,14 @@ def rsa_to_jwk(
161166
)
162167

163168

164-
def load_private_key(data: Union[str, RSAJWKPrivateKey, Path]) -> RSAPrivateKey:
169+
def load_private_key(data: str | RSAJWKPrivateKey | Path) -> RSAPrivateKey:
165170
if isinstance(data, Path):
166171
data = data.read_text()
167172
if isinstance(data, str):
168173
if data.startswith("-----BEGIN "):
169-
return serialization.load_pem_private_key(data.encode(), None, default_backend())
174+
result = serialization.load_pem_private_key(data.encode(), None, default_backend())
175+
assert isinstance(result, RSAPrivateKey)
176+
return result
170177
if data.strip().startswith("{"):
171178
return load_jwk_private_key(json.loads(data))
172179
if isinstance(data, dict):
@@ -177,12 +184,14 @@ def load_private_key(data: Union[str, RSAJWKPrivateKey, Path]) -> RSAPrivateKey:
177184
return key
178185

179186

180-
def load_public_key(data: Union[str, RSAJWKPublicKey, Path]) -> RSAPublicKey:
187+
def load_public_key(data: str | RSAJWKPublicKey | Path) -> RSAPublicKey:
181188
if isinstance(data, Path):
182189
data = data.read_text()
183190
if isinstance(data, str):
184191
if data.startswith("-----BEGIN "):
185-
return serialization.load_pem_public_key(data.encode(), default_backend())
192+
result = serialization.load_pem_public_key(data.encode(), default_backend())
193+
assert isinstance(result, RSAPublicKey)
194+
return result
186195
if data.strip().startswith("{"):
187196
return load_jwk_public_key(json.loads(data))
188197
if isinstance(data, dict):

jwt_rsa/token.py

Lines changed: 13 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,25 @@
22
from dataclasses import dataclass, field
33
from datetime import datetime, timedelta
44
from operator import add, sub
5-
from typing import (
6-
TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, TypeVar, Union, overload,
7-
)
5+
from types import EllipsisType
6+
from typing import Any, Callable, Dict, Optional, Sequence, TypeVar, overload
87

98
from jwt import PyJWT
109

11-
from .types import AlgorithmType, RSAPrivateKey, RSAPublicKey
12-
13-
14-
if TYPE_CHECKING:
15-
# pylama:ignore=E0602
16-
DateType = Union[timedelta, datetime, float, int, ellipsis]
17-
else:
18-
DateType = Union[timedelta, datetime, float, int, type(Ellipsis)]
19-
10+
from .types import AlgorithmType, RSAPrivateKey, RSAPublicKey, DateType
2011

2112
R = TypeVar("R")
2213
DAY = 86400
2314
DEFAULT_EXPIRATION = timedelta(days=31).total_seconds()
2415
NBF_DELTA = 20
25-
ALGORITHMS = tuple(AlgorithmType.__args__)
16+
ALGORITHMS: Sequence[AlgorithmType] = ("RS256", "RS384", "RS512")
2617

2718

2819
def date_to_timestamp(
29-
value: DateType,
20+
value: DateType | EllipsisType,
3021
default: Callable[[], R],
3122
timedelta_func: Callable[[float, float], int] = add,
32-
) -> Union[int, float, R]:
23+
) -> int | float | R:
3324
if isinstance(value, timedelta):
3425
return timedelta_func(time.time(), value.total_seconds())
3526
elif isinstance(value, datetime):
@@ -46,8 +37,8 @@ def date_to_timestamp(
4637
class JWTDecoder:
4738
jwt: PyJWT = field(repr=False, compare=False)
4839
public_key: RSAPublicKey = field(repr=False, compare=True)
49-
expires: Union[int, float]
50-
nbf_delta: Union[int, float]
40+
expires: int | float
41+
nbf_delta: int | float
5142
algorithm: AlgorithmType
5243
algorithms: Sequence[AlgorithmType]
5344

@@ -79,7 +70,7 @@ def __init__(self, key: RSAPrivateKey, *, options: Optional[Dict[str, Any]] = No
7970
super(JWTDecoder, self).__setattr__('private_key', key)
8071
super().__init__(key.public_key(), options=options, **kwargs)
8172

82-
def encode(self, expired: DateType = ..., nbf: DateType = ..., **claims: Any) -> str:
73+
def encode(self, expired: DateType | EllipsisType = ..., nbf: DateType | EllipsisType = ..., **claims: Any) -> str:
8374
claims.setdefault('exp', int(date_to_timestamp(expired, lambda: time.time() + self.expires)))
8475
claims.setdefault('nbf', int(date_to_timestamp(nbf, lambda: time.time() - self.nbf_delta, timedelta_func=sub)))
8576
return self.jwt.encode(claims, self.private_key, algorithm=self.algorithm)
@@ -97,7 +88,7 @@ def JWT(
9788

9889

9990
@overload
100-
def JWT( # type: ignore[overload-cannot-match]
91+
def JWT(
10192
key: RSAPublicKey, *,
10293
options: dict[str, Any] | None = None,
10394
expires: int | float = DEFAULT_EXPIRATION,
@@ -108,15 +99,15 @@ def JWT( # type: ignore[overload-cannot-match]
10899

109100

110101
def JWT(
111-
key: Union[RSAPrivateKey, RSAPublicKey],
102+
key: RSAPrivateKey | RSAPublicKey,
112103
*,
113104
options: dict[str, Any] | None = None,
114105
expires: int | float = DEFAULT_EXPIRATION,
115106
nbf_delta: int | float = NBF_DELTA,
116107
algorithm: AlgorithmType = "RS512",
117108
algorithms: Sequence[AlgorithmType] = ALGORITHMS,
118-
) -> Union[JWTSigner, JWTDecoder]:
119-
kwargs = dict(
109+
) -> JWTSigner | JWTDecoder:
110+
kwargs: dict[str, Any] = dict(
120111
expires=expires,
121112
nbf_delta=nbf_delta,
122113
algorithm=algorithm,

jwt_rsa/types.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import timedelta, datetime
12
from typing import Literal
23

34
from cryptography.hazmat.primitives import serialization
@@ -11,10 +12,11 @@
1112

1213

1314
AlgorithmType = Literal["RS256", "RS384", "RS512"]
14-
15+
DateType = timedelta | datetime | float | int
1516

1617
__all__ = (
1718
"AlgorithmType",
19+
"DateType",
1820
"RSAPrivateKey",
1921
"RSAPublicKey",
2022
"serialization",

jwt_rsa/verify.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@
33
from types import SimpleNamespace
44

55
from .rsa import generate_rsa, load_private_key, load_public_key
6-
from .token import JWT
6+
from .token import JWT, JWTSigner, JWTDecoder
77

88

99
def main(arguments: SimpleNamespace) -> None:
10+
jwt: JWTSigner | JWTDecoder
1011
if arguments.private_key:
11-
jwt = JWT(private_key=load_private_key(arguments.private_key))
12+
jwt = JWT(load_private_key(arguments.private_key))
1213
elif arguments.public_key:
13-
jwt = JWT(public_key=load_public_key(arguments.public_key))
14+
jwt = JWT(load_public_key(arguments.public_key))
1415
elif not arguments.verify:
15-
jwt = JWT(*generate_rsa(1024))
16+
key_pair = generate_rsa(1024)
17+
jwt = JWT(key_pair.private)
1618
else:
1719
print("Either private or public key must be provided", file=sys.stderr)
1820
exit(1)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ requires = ["poetry-core"]
4646
build-backend = "poetry.core.masonry.api"
4747

4848
[tool.pylama]
49-
linters = "pycodestyle,pyflakes,mccabe,mccabe,mypy"
49+
linters = "pycodestyle,pyflakes,mccabe,mccabe"
5050

5151
[tool.pylama.linter.pycodestyle]
5252
max_line_length = 119

0 commit comments

Comments
 (0)