11import base64
22import json
33from pathlib import Path
4- from typing import NamedTuple , Optional , TypedDict , Union , overload
4+ from typing import NamedTuple , Optional , TypedDict , overload
55
66from cryptography .hazmat .backends import default_backend
77
1111
1212
1313class KeyPair (NamedTuple ):
14+ private : RSAPrivateKey
15+ public : RSAPublicKey
16+
17+
18+ class JWKKeyPair (NamedTuple ):
1419 private : Optional [RSAPrivateKey ]
1520 public : RSAPublicKey
1621
@@ -80,8 +85,8 @@ def load_jwk_private_key(jwk: RSAJWKPrivateKey) -> RSAPrivateKey:
8085 return private_numbers .private_key (default_backend ())
8186
8287
83- def load_jwk (jwk : Union [ RSAJWKPublicKey , RSAJWKPrivateKey , str ] ) -> KeyPair :
84- jwk_dict : Union [ RSAJWKPublicKey , RSAJWKPrivateKey ]
88+ def load_jwk (jwk : RSAJWKPublicKey | RSAJWKPrivateKey | str ) -> JWKKeyPair :
89+ jwk_dict : RSAJWKPublicKey | RSAJWKPrivateKey
8590
8691 if isinstance (jwk , str ):
8792 jwk_dict = json .loads (jwk )
@@ -92,10 +97,10 @@ def load_jwk(jwk: Union[RSAJWKPublicKey, RSAJWKPrivateKey, str]) -> KeyPair:
9297 private_key = load_jwk_private_key (jwk_dict ) # type: ignore
9398 public_key = private_key .public_key ()
9499 else : # Public key
95- public_key = load_jwk_public_key (jwk_dict ) # type: ignore
100+ public_key = load_jwk_public_key (jwk_dict )
96101 private_key = None
97102
98- return KeyPair (private = private_key , public = public_key )
103+ return JWKKeyPair (private = private_key , public = public_key )
99104
100105
101106def int_to_base64url (value : int ) -> str :
@@ -106,24 +111,24 @@ def int_to_base64url(value: int) -> str:
106111
107112@overload
108113def rsa_to_jwk (
109- key : RSAPublicKey , * , kid : str = "" , alg : AlgorithmType = "RS256" , use : str = "sig"
114+ key : RSAPublicKey , * , kid : str = "" , alg : AlgorithmType = "RS256" , use : str = "sig" ,
110115) -> RSAJWKPublicKey : ...
111116
112117
113118@overload
114- def rsa_to_jwk ( # type: ignore[overload-cannot-match]
119+ def rsa_to_jwk (
115120 key : RSAPrivateKey , * , kid : str = "" , alg : AlgorithmType = "RS256" , use : str = "sig" ,
116121) -> RSAJWKPrivateKey : ...
117122
118123
119124def rsa_to_jwk (
120- key : Union [ RSAPrivateKey , RSAPublicKey ] ,
125+ key : RSAPrivateKey | RSAPublicKey ,
121126 * ,
122127 kid : str = "" ,
123128 alg : AlgorithmType = "RS256" ,
124129 use : str = "sig" ,
125130 kty : str = "RSA" ,
126- ) -> Union [ RSAJWKPublicKey , RSAJWKPrivateKey ] :
131+ ) -> RSAJWKPublicKey | RSAJWKPrivateKey :
127132 if isinstance (key , RSAPublicKey ):
128133 public_numbers = key .public_numbers ()
129134 private_numbers = None
@@ -161,12 +166,14 @@ def rsa_to_jwk(
161166 )
162167
163168
164- def load_private_key (data : Union [ str , RSAJWKPrivateKey , Path ] ) -> RSAPrivateKey :
169+ def load_private_key (data : str | RSAJWKPrivateKey | Path ) -> RSAPrivateKey :
165170 if isinstance (data , Path ):
166171 data = data .read_text ()
167172 if isinstance (data , str ):
168173 if data .startswith ("-----BEGIN " ):
169- return serialization .load_pem_private_key (data .encode (), None , default_backend ())
174+ result = serialization .load_pem_private_key (data .encode (), None , default_backend ())
175+ assert isinstance (result , RSAPrivateKey )
176+ return result
170177 if data .strip ().startswith ("{" ):
171178 return load_jwk_private_key (json .loads (data ))
172179 if isinstance (data , dict ):
@@ -177,12 +184,14 @@ def load_private_key(data: Union[str, RSAJWKPrivateKey, Path]) -> RSAPrivateKey:
177184 return key
178185
179186
180- def load_public_key (data : Union [ str , RSAJWKPublicKey , Path ] ) -> RSAPublicKey :
187+ def load_public_key (data : str | RSAJWKPublicKey | Path ) -> RSAPublicKey :
181188 if isinstance (data , Path ):
182189 data = data .read_text ()
183190 if isinstance (data , str ):
184191 if data .startswith ("-----BEGIN " ):
185- return serialization .load_pem_public_key (data .encode (), default_backend ())
192+ result = serialization .load_pem_public_key (data .encode (), default_backend ())
193+ assert isinstance (result , RSAPublicKey )
194+ return result
186195 if data .strip ().startswith ("{" ):
187196 return load_jwk_public_key (json .loads (data ))
188197 if isinstance (data , dict ):
0 commit comments