Skip to content

Commit 9a4ec31

Browse files
committed
fix rsa_backend RSA private key PKCS8 encoding
1 parent 5dd8cbc commit 9a4ec31

File tree

1 file changed

+70
-10
lines changed

1 file changed

+70
-10
lines changed

jose/backends/rsa_backend.py

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import six
2-
from pyasn1.codec.der import encoder
3-
from pyasn1.type import univ
2+
from pyasn1.codec.der import decoder, encoder
3+
from pyasn1.error import PyAsn1Error
4+
from pyasn1.type import namedtype, univ
45

56
import rsa as pyrsa
67
import rsa.pem as pyrsa_pem
@@ -12,7 +13,9 @@
1213
from jose.utils import base64_to_long, long_to_base64
1314

1415

15-
PKCS8_RSA_HEADER = b'0\x82\x04\xbd\x02\x01\x000\r\x06\t*\x86H\x86\xf7\r\x01\x01\x01\x05\x00'
16+
LEGACY_INVALID_PKCS8_RSA_HEADER = b'0\x82\x04\xbd\x02\x01\x000\r\x06\t*\x86H\x86\xf7\r\x01\x01\x01\x05\x00'
17+
RSA_ENCRYPTION_ASN1_OID = "1.2.840.113549.1.1.1"
18+
1619
# Functions gcd and rsa_recover_prime_factors were copied from cryptography 1.9
1720
# to enable pure python rsa module to be in compliance with section 6.3.1 of RFC7518
1821
# which requires only private exponent (d) for private key.
@@ -83,6 +86,59 @@ def pem_to_spki(pem, fmt='PKCS8'):
8386
return key.to_pem(fmt)
8487

8588

89+
def _legacy_private_key_pkcs8_to_pkcs1(der_key):
90+
"""Legacy RSA private key PKCS8-to-PKCS1 conversion.
91+
92+
.. warning::
93+
94+
This is incorrect parsing and only works because the legacy PKCS1-to-PKCS8
95+
encoding was also incorrect.
96+
"""
97+
return der_key[len(LEGACY_INVALID_PKCS8_RSA_HEADER):]
98+
99+
100+
class PKCS8RsaPrivateKeyAlgorithm(univ.Sequence):
101+
"""ASN1 structure for recording RSA PrivateKeyAlgorithm identifiers."""
102+
componentType = namedtype.NamedTypes(
103+
namedtype.NamedType("rsaEncryption", univ.ObjectIdentifier()),
104+
namedtype.NamedType("parameters", univ.Null())
105+
)
106+
107+
108+
class PKCS8PrivateKey(univ.Sequence):
109+
"""ASN1 structure for recording PKCS8 private keys."""
110+
componentType = namedtype.NamedTypes(
111+
namedtype.NamedType("version", univ.Integer()),
112+
namedtype.NamedType("privateKeyAlgorithm", PKCS8RsaPrivateKeyAlgorithm()),
113+
namedtype.NamedType("privateKey", univ.OctetString())
114+
)
115+
116+
117+
def _private_key_pkcs8_to_pkcs1(pkcs8_key):
118+
"""Convert a PKCS8-encoded RSA private key to PKCS1."""
119+
decoded_values = decoder.decode(pkcs8_key, asn1Spec=PKCS8PrivateKey())
120+
121+
try:
122+
decoded_key = decoded_values[0]
123+
except IndexError:
124+
raise ValueError("Invalid private key encoding")
125+
126+
return decoded_key["privateKey"]
127+
128+
129+
def _private_key_pkcs1_to_pkcs8(pkcs1_key):
130+
"""Convert a PKCS1-encoded RSA private key to PKCS8."""
131+
algorithm = PKCS8RsaPrivateKeyAlgorithm()
132+
algorithm["rsaEncryption"] = RSA_ENCRYPTION_ASN1_OID
133+
134+
pkcs8_key = PKCS8PrivateKey()
135+
pkcs8_key["version"] = 0
136+
pkcs8_key["privateKeyAlgorithm"] = algorithm
137+
pkcs8_key["privateKey"] = pkcs1_key
138+
139+
return encoder.encode(pkcs8_key)
140+
141+
86142
class RSAKey(Key):
87143
SHA256 = 'SHA-256'
88144
SHA384 = 'SHA-384'
@@ -121,12 +177,15 @@ def __init__(self, key, algorithm):
121177
self._prepared_key = pyrsa.PrivateKey.load_pkcs1(key)
122178
except ValueError:
123179
try:
124-
# python-rsa does not support PKCS8 yet so we have to manually remove OID
125180
der = pyrsa_pem.load_pem(key, b'PRIVATE KEY')
126-
header, der = der[:22], der[22:]
127-
if header != PKCS8_RSA_HEADER:
128-
raise ValueError("Invalid PKCS8 header")
129-
self._prepared_key = pyrsa.PrivateKey._load_pkcs1_der(der)
181+
try:
182+
pkcs1_key = _private_key_pkcs8_to_pkcs1(der)
183+
except PyAsn1Error:
184+
# If the key was encoded using the old, invalid,
185+
# encoding then pyasn1 will throw an error attempting
186+
# to parse the key.
187+
pkcs1_key = _legacy_private_key_pkcs8_to_pkcs1(der)
188+
self._prepared_key = pyrsa.PrivateKey.load_pkcs1(pkcs1_key, format="DER")
130189
except ValueError as e:
131190
raise JWKError(e)
132191
return
@@ -183,7 +242,8 @@ def to_pem(self, pem_format='PKCS8'):
183242
if isinstance(self._prepared_key, pyrsa.PrivateKey):
184243
der = self._prepared_key.save_pkcs1(format='DER')
185244
if pem_format == 'PKCS8':
186-
pem = pyrsa_pem.save_pem(PKCS8_RSA_HEADER + der, pem_marker='PRIVATE KEY')
245+
pkcs8_der = _private_key_pkcs1_to_pkcs8(der)
246+
pem = pyrsa_pem.save_pem(pkcs8_der, pem_marker='PRIVATE KEY')
187247
elif pem_format == 'PKCS1':
188248
pem = pyrsa_pem.save_pem(der, pem_marker='RSA PRIVATE KEY')
189249
else:
@@ -196,7 +256,7 @@ def to_pem(self, pem_format='PKCS8'):
196256
der = encoder.encode(asn_key)
197257

198258
header = PubKeyHeader()
199-
header['oid'] = univ.ObjectIdentifier('1.2.840.113549.1.1.1')
259+
header['oid'] = univ.ObjectIdentifier(RSA_ENCRYPTION_ASN1_OID)
200260
pub_key = OpenSSLPubKey()
201261
pub_key['header'] = header
202262
pub_key['key'] = univ.BitString.fromOctetString(der)

0 commit comments

Comments
 (0)