1+ import binascii
2+
13import six
2- from pyasn1 .codec .der import encoder
3- from pyasn1 .type import univ
4+ from pyasn1 .codec .der import decoder , encoder
5+ from pyasn1 .error import PyAsn1Error
6+ from pyasn1 .type import namedtype , univ
47
58import rsa as pyrsa
69import rsa .pem as pyrsa_pem
1215from jose .utils import base64_to_long , long_to_base64
1316
1417
15- PKCS8_RSA_HEADER = b'0\x82 \x04 \xbd \x02 \x01 \x00 0\r \x06 \t *\x86 H\x86 \xf7 \r \x01 \x01 \x01 \x05 \x00 '
18+ LEGACY_INVALID_PKCS8_RSA_HEADER = binascii .unhexlify (
19+ "30" # sequence
20+ "8204BD" # DER-encoded sequence contents length of 1213 bytes -- INCORRECT STATIC LENGTH
21+ "020100" # integer: 0 -- Version
22+ "30" # sequence
23+ "0D" # DER-encoded sequence contents length of 13 bytes -- PrivateKeyAlgorithmIdentifier
24+ "06092A864886F70D010101" # OID -- rsaEncryption
25+ "0500" # NULL -- parameters
26+ )
27+ ASN1_SEQUENCE_ID = binascii .unhexlify ("30" )
28+ RSA_ENCRYPTION_ASN1_OID = "1.2.840.113549.1.1.1"
29+
1630# Functions gcd and rsa_recover_prime_factors were copied from cryptography 1.9
1731# to enable pure python rsa module to be in compliance with section 6.3.1 of RFC7518
1832# which requires only private exponent (d) for private key.
@@ -83,6 +97,65 @@ def pem_to_spki(pem, fmt='PKCS8'):
8397 return key .to_pem (fmt )
8498
8599
100+ def _legacy_private_key_pkcs8_to_pkcs1 (pkcs8_key ):
101+ """Legacy RSA private key PKCS8-to-PKCS1 conversion.
102+
103+ .. warning::
104+
105+ This is incorrect parsing and only works because the legacy PKCS1-to-PKCS8
106+ encoding was also incorrect.
107+ """
108+ # Only allow this processing if the prefix matches
109+ # AND the following byte indicates an ASN1 sequence,
110+ # as we would expect with the legacy encoding.
111+ if not pkcs8_key .startswith (LEGACY_INVALID_PKCS8_RSA_HEADER + ASN1_SEQUENCE_ID ):
112+ raise ValueError ("Invalid private key encoding" )
113+
114+ return pkcs8_key [len (LEGACY_INVALID_PKCS8_RSA_HEADER ):]
115+
116+
117+ class PKCS8RsaPrivateKeyAlgorithm (univ .Sequence ):
118+ """ASN1 structure for recording RSA PrivateKeyAlgorithm identifiers."""
119+ componentType = namedtype .NamedTypes (
120+ namedtype .NamedType ("rsaEncryption" , univ .ObjectIdentifier ()),
121+ namedtype .NamedType ("parameters" , univ .Null ())
122+ )
123+
124+
125+ class PKCS8PrivateKey (univ .Sequence ):
126+ """ASN1 structure for recording PKCS8 private keys."""
127+ componentType = namedtype .NamedTypes (
128+ namedtype .NamedType ("version" , univ .Integer ()),
129+ namedtype .NamedType ("privateKeyAlgorithm" , PKCS8RsaPrivateKeyAlgorithm ()),
130+ namedtype .NamedType ("privateKey" , univ .OctetString ())
131+ )
132+
133+
134+ def _private_key_pkcs8_to_pkcs1 (pkcs8_key ):
135+ """Convert a PKCS8-encoded RSA private key to PKCS1."""
136+ decoded_values = decoder .decode (pkcs8_key , asn1Spec = PKCS8PrivateKey ())
137+
138+ try :
139+ decoded_key = decoded_values [0 ]
140+ except IndexError :
141+ raise ValueError ("Invalid private key encoding" )
142+
143+ return decoded_key ["privateKey" ]
144+
145+
146+ def _private_key_pkcs1_to_pkcs8 (pkcs1_key ):
147+ """Convert a PKCS1-encoded RSA private key to PKCS8."""
148+ algorithm = PKCS8RsaPrivateKeyAlgorithm ()
149+ algorithm ["rsaEncryption" ] = RSA_ENCRYPTION_ASN1_OID
150+
151+ pkcs8_key = PKCS8PrivateKey ()
152+ pkcs8_key ["version" ] = 0
153+ pkcs8_key ["privateKeyAlgorithm" ] = algorithm
154+ pkcs8_key ["privateKey" ] = pkcs1_key
155+
156+ return encoder .encode (pkcs8_key )
157+
158+
86159class RSAKey (Key ):
87160 SHA256 = 'SHA-256'
88161 SHA384 = 'SHA-384'
@@ -121,12 +194,15 @@ def __init__(self, key, algorithm):
121194 self ._prepared_key = pyrsa .PrivateKey .load_pkcs1 (key )
122195 except ValueError :
123196 try :
124- # python-rsa does not support PKCS8 yet so we have to manually remove OID
125197 der = pyrsa_pem .load_pem (key , b'PRIVATE KEY' )
126- header , der = der [:22 ], der [22 :]
127- if header != PKCS8_RSA_HEADER :
128- raise ValueError ("Invalid PKCS8 header" )
129- self ._prepared_key = pyrsa .PrivateKey ._load_pkcs1_der (der )
198+ try :
199+ pkcs1_key = _private_key_pkcs8_to_pkcs1 (der )
200+ except PyAsn1Error :
201+ # If the key was encoded using the old, invalid,
202+ # encoding then pyasn1 will throw an error attempting
203+ # to parse the key.
204+ pkcs1_key = _legacy_private_key_pkcs8_to_pkcs1 (der )
205+ self ._prepared_key = pyrsa .PrivateKey .load_pkcs1 (pkcs1_key , format = "DER" )
130206 except ValueError as e :
131207 raise JWKError (e )
132208 return
@@ -183,7 +259,8 @@ def to_pem(self, pem_format='PKCS8'):
183259 if isinstance (self ._prepared_key , pyrsa .PrivateKey ):
184260 der = self ._prepared_key .save_pkcs1 (format = 'DER' )
185261 if pem_format == 'PKCS8' :
186- pem = pyrsa_pem .save_pem (PKCS8_RSA_HEADER + der , pem_marker = 'PRIVATE KEY' )
262+ pkcs8_der = _private_key_pkcs1_to_pkcs8 (der )
263+ pem = pyrsa_pem .save_pem (pkcs8_der , pem_marker = 'PRIVATE KEY' )
187264 elif pem_format == 'PKCS1' :
188265 pem = pyrsa_pem .save_pem (der , pem_marker = 'RSA PRIVATE KEY' )
189266 else :
@@ -196,7 +273,7 @@ def to_pem(self, pem_format='PKCS8'):
196273 der = encoder .encode (asn_key )
197274
198275 header = PubKeyHeader ()
199- header ['oid' ] = univ .ObjectIdentifier ('1.2.840.113549.1.1.1' )
276+ header ['oid' ] = univ .ObjectIdentifier (RSA_ENCRYPTION_ASN1_OID )
200277 pub_key = OpenSSLPubKey ()
201278 pub_key ['header' ] = header
202279 pub_key ['key' ] = univ .BitString .fromOctetString (der )
0 commit comments