11import six
2- from pyasn1 .codec .der import encoder
3- from pyasn1 .type import univ
2+ from pyasn1 .codec .der import decoder , encoder
3+ from pyasn1 .error import PyAsn1Error
4+ from pyasn1 .type import namedtype , univ
45
56import rsa as pyrsa
67import rsa .pem as pyrsa_pem
1213from jose .utils import base64_to_long , long_to_base64
1314
1415
15- PKCS8_RSA_HEADER = b'0\x82 \x04 \xbd \x02 \x01 \x00 0\r \x06 \t *\x86 H\x86 \xf7 \r \x01 \x01 \x01 \x05 \x00 '
16+ LEGACY_INVALID_PKCS8_RSA_HEADER = b'0\x82 \x04 \xbd \x02 \x01 \x00 0\r \x06 \t *\x86 H\x86 \xf7 \r \x01 \x01 \x01 \x05 \x00 '
17+ RSA_ENCRYPTION_ASN1_OID = "1.2.840.113549.1.1.1"
18+
1619# Functions gcd and rsa_recover_prime_factors were copied from cryptography 1.9
1720# to enable pure python rsa module to be in compliance with section 6.3.1 of RFC7518
1821# which requires only private exponent (d) for private key.
@@ -83,6 +86,59 @@ def pem_to_spki(pem, fmt='PKCS8'):
8386 return key .to_pem (fmt )
8487
8588
89+ def _legacy_private_key_pkcs8_to_pkcs1 (der_key ):
90+ """Legacy RSA private key PKCS8-to-PKCS1 conversion.
91+
92+ .. warning::
93+
94+ This is incorrect parsing and only works because the legacy PKCS1-to-PKCS8
95+ encoding was also incorrect.
96+ """
97+ return der_key [len (LEGACY_INVALID_PKCS8_RSA_HEADER ):]
98+
99+
100+ class PKCS8RsaPrivateKeyAlgorithm (univ .Sequence ):
101+ """ASN1 structure for recording RSA PrivateKeyAlgorithm identifiers."""
102+ componentType = namedtype .NamedTypes (
103+ namedtype .NamedType ("rsaEncryption" , univ .ObjectIdentifier ()),
104+ namedtype .NamedType ("parameters" , univ .Null ())
105+ )
106+
107+
108+ class PKCS8PrivateKey (univ .Sequence ):
109+ """ASN1 structure for recording PKCS8 private keys."""
110+ componentType = namedtype .NamedTypes (
111+ namedtype .NamedType ("version" , univ .Integer ()),
112+ namedtype .NamedType ("privateKeyAlgorithm" , PKCS8RsaPrivateKeyAlgorithm ()),
113+ namedtype .NamedType ("privateKey" , univ .OctetString ())
114+ )
115+
116+
117+ def _private_key_pkcs8_to_pkcs1 (pkcs8_key ):
118+ """Convert a PKCS8-encoded RSA private key to PKCS1."""
119+ decoded_values = decoder .decode (pkcs8_key , asn1Spec = PKCS8PrivateKey ())
120+
121+ try :
122+ decoded_key = decoded_values [0 ]
123+ except IndexError :
124+ raise ValueError ("Invalid private key encoding" )
125+
126+ return decoded_key ["privateKey" ]
127+
128+
129+ def _private_key_pkcs1_to_pkcs8 (pkcs1_key ):
130+ """Convert a PKCS1-encoded RSA private key to PKCS8."""
131+ algorithm = PKCS8RsaPrivateKeyAlgorithm ()
132+ algorithm ["rsaEncryption" ] = RSA_ENCRYPTION_ASN1_OID
133+
134+ pkcs8_key = PKCS8PrivateKey ()
135+ pkcs8_key ["version" ] = 0
136+ pkcs8_key ["privateKeyAlgorithm" ] = algorithm
137+ pkcs8_key ["privateKey" ] = pkcs1_key
138+
139+ return encoder .encode (pkcs8_key )
140+
141+
86142class RSAKey (Key ):
87143 SHA256 = 'SHA-256'
88144 SHA384 = 'SHA-384'
@@ -121,12 +177,15 @@ def __init__(self, key, algorithm):
121177 self ._prepared_key = pyrsa .PrivateKey .load_pkcs1 (key )
122178 except ValueError :
123179 try :
124- # python-rsa does not support PKCS8 yet so we have to manually remove OID
125180 der = pyrsa_pem .load_pem (key , b'PRIVATE KEY' )
126- header , der = der [:22 ], der [22 :]
127- if header != PKCS8_RSA_HEADER :
128- raise ValueError ("Invalid PKCS8 header" )
129- self ._prepared_key = pyrsa .PrivateKey ._load_pkcs1_der (der )
181+ try :
182+ pkcs1_key = _private_key_pkcs8_to_pkcs1 (der )
183+ except PyAsn1Error :
184+ # If the key was encoded using the old, invalid,
185+ # encoding then pyasn1 will throw an error attempting
186+ # to parse the key.
187+ pkcs1_key = _legacy_private_key_pkcs8_to_pkcs1 (der )
188+ self ._prepared_key = pyrsa .PrivateKey .load_pkcs1 (pkcs1_key , format = "DER" )
130189 except ValueError as e :
131190 raise JWKError (e )
132191 return
@@ -183,7 +242,8 @@ def to_pem(self, pem_format='PKCS8'):
183242 if isinstance (self ._prepared_key , pyrsa .PrivateKey ):
184243 der = self ._prepared_key .save_pkcs1 (format = 'DER' )
185244 if pem_format == 'PKCS8' :
186- pem = pyrsa_pem .save_pem (PKCS8_RSA_HEADER + der , pem_marker = 'PRIVATE KEY' )
245+ pkcs8_der = _private_key_pkcs1_to_pkcs8 (der )
246+ pem = pyrsa_pem .save_pem (pkcs8_der , pem_marker = 'PRIVATE KEY' )
187247 elif pem_format == 'PKCS1' :
188248 pem = pyrsa_pem .save_pem (der , pem_marker = 'RSA PRIVATE KEY' )
189249 else :
@@ -196,7 +256,7 @@ def to_pem(self, pem_format='PKCS8'):
196256 der = encoder .encode (asn_key )
197257
198258 header = PubKeyHeader ()
199- header ['oid' ] = univ .ObjectIdentifier ('1.2.840.113549.1.1.1' )
259+ header ['oid' ] = univ .ObjectIdentifier (RSA_ENCRYPTION_ASN1_OID )
200260 pub_key = OpenSSLPubKey ()
201261 pub_key ['header' ] = header
202262 pub_key ['key' ] = univ .BitString .fromOctetString (der )
0 commit comments