22import ecdsa
33from ecdsa .util import sigdecode_string , sigencode_string , sigdecode_der , sigencode_der
44
5- from jose .jwk import Key , base64_to_long
5+ from jose .backends .base import Key
6+ from jose .utils import base64_to_long
67from jose .constants import ALGORITHMS
78from jose .exceptions import JWKError
89
910from cryptography .exceptions import InvalidSignature
1011from cryptography .hazmat .backends import default_backend
11- from cryptography .hazmat .backends .openssl .rsa import _RSAPublicKey
12- from cryptography .hazmat .primitives import hashes
12+ from cryptography .hazmat .primitives import hashes , serialization
1313from cryptography .hazmat .primitives .asymmetric import ec , rsa , padding
1414from cryptography .hazmat .primitives .serialization import load_pem_private_key , load_pem_public_key
1515
@@ -34,10 +34,15 @@ def __init__(self, key, algorithm, cryptography_backend=default_backend):
3434 ALGORITHMS .ES384 : self .SHA384 ,
3535 ALGORITHMS .ES512 : self .SHA512
3636 }.get (algorithm )
37+ self ._algorithm = algorithm
3738
3839 self .curve = self .CURVE_MAP .get (self .hash_alg )
3940 self .cryptography_backend = cryptography_backend
4041
42+ if hasattr (key , 'public_bytes' ) or hasattr (key , 'private_bytes' ):
43+ self .prepared_key = key
44+ return
45+
4146 if isinstance (key , (ecdsa .SigningKey , ecdsa .VerifyingKey )):
4247 # convert to PEM and let cryptography below load it as PEM
4348 key = key .to_pem ().decode ('utf-8' )
@@ -93,6 +98,25 @@ def verify(self, msg, sig):
9398 except :
9499 return False
95100
101+ def public_key (self ):
102+ if hasattr (self .prepared_key , 'public_bytes' ):
103+ return self
104+ return self .__class__ (self .prepared_key .public_key (), self ._algorithm )
105+
106+ def to_pem (self ):
107+ if hasattr (self .prepared_key , 'public_bytes' ):
108+ pem = self .prepared_key .public_bytes (
109+ encoding = serialization .Encoding .PEM ,
110+ format = serialization .PublicFormat .SubjectPublicKeyInfo
111+ )
112+ return pem .decode ('utf-8' )
113+ pem = self .prepared_key .private_bytes (
114+ encoding = serialization .Encoding .PEM ,
115+ format = serialization .PrivateFormat .TraditionalOpenSSL ,
116+ encryption_algorithm = serialization .NoEncryption ()
117+ )
118+ return pem .decode ('utf-8' )
119+
96120
97121class CryptographyRSAKey (Key ):
98122 SHA256 = hashes .SHA256
@@ -108,10 +132,12 @@ def __init__(self, key, algorithm, cryptography_backend=default_backend):
108132 ALGORITHMS .RS384 : self .SHA384 ,
109133 ALGORITHMS .RS512 : self .SHA512
110134 }.get (algorithm )
135+ self ._algorithm = algorithm
111136
112137 self .cryptography_backend = cryptography_backend
113138
114- if isinstance (key , _RSAPublicKey ):
139+ # if it conforms to RSAPublicKey interface
140+ if hasattr (key , 'public_bytes' ) and hasattr (key , 'public_numbers' ):
115141 self .prepared_key = key
116142 return
117143
@@ -166,3 +192,21 @@ def verify(self, msg, sig):
166192 return True
167193 except InvalidSignature :
168194 return False
195+
196+ def public_key (self ):
197+ if hasattr (self .prepared_key , 'public_bytes' ):
198+ return self
199+ return self .__class__ (self .prepared_key .public_key (), self ._algorithm )
200+
201+ def to_pem (self ):
202+ if hasattr (self .prepared_key , 'public_bytes' ):
203+ return self .prepared_key .public_bytes (
204+ encoding = serialization .Encoding .PEM ,
205+ format = serialization .PublicFormat .SubjectPublicKeyInfo
206+ ).decode ('utf-8' )
207+
208+ return self .prepared_key .private_bytes (
209+ encoding = serialization .Encoding .PEM ,
210+ format = serialization .PrivateFormat .TraditionalOpenSSL ,
211+ encryption_algorithm = serialization .NoEncryption ()
212+ ).decode ('utf-8' )
0 commit comments