|
| 1 | +import six |
| 2 | +from pyasn1.codec.der import encoder |
| 3 | +from pyasn1.type import univ |
| 4 | + |
| 5 | +import rsa as pyrsa |
| 6 | +import rsa.pem as pyrsa_pem |
| 7 | +from rsa.asn1 import OpenSSLPubKey, AsnPubKey, PubKeyHeader |
| 8 | + |
| 9 | +from jose.backends.base import Key |
| 10 | +from jose.constants import ALGORITHMS |
| 11 | +from jose.exceptions import JWKError |
| 12 | +from jose.utils import base64_to_long, long_to_base64 |
| 13 | + |
| 14 | + |
| 15 | +PKCS8_RSA_HEADER = b'0\x82\x04\xbd\x02\x01\x000\r\x06\t*\x86H\x86\xf7\r\x01\x01\x01\x05\x00' |
| 16 | +# Functions gcd and rsa_recover_prime_factors were copied from cryptography 1.9 |
| 17 | +# to enable pure python rsa module to be in compliance with section 6.3.1 of RFC7518 |
| 18 | +# which requires only private exponent (d) for private key. |
| 19 | + |
| 20 | +def _gcd(a, b): |
| 21 | + """Calculate the Greatest Common Divisor of a and b. |
| 22 | +
|
| 23 | + Unless b==0, the result will have the same sign as b (so that when |
| 24 | + b is divided by it, the result comes out positive). |
| 25 | + """ |
| 26 | + while b: |
| 27 | + a, b = b, a%b |
| 28 | + return a |
| 29 | + |
| 30 | + |
| 31 | +# Controls the number of iterations rsa_recover_prime_factors will perform |
| 32 | +# to obtain the prime factors. Each iteration increments by 2 so the actual |
| 33 | +# maximum attempts is half this number. |
| 34 | +_MAX_RECOVERY_ATTEMPTS = 1000 |
| 35 | + |
| 36 | + |
| 37 | +def _rsa_recover_prime_factors(n, e, d): |
| 38 | + """ |
| 39 | + Compute factors p and q from the private exponent d. We assume that n has |
| 40 | + no more than two factors. This function is adapted from code in PyCrypto. |
| 41 | + """ |
| 42 | + # See 8.2.2(i) in Handbook of Applied Cryptography. |
| 43 | + ktot = d * e - 1 |
| 44 | + # The quantity d*e-1 is a multiple of phi(n), even, |
| 45 | + # and can be represented as t*2^s. |
| 46 | + t = ktot |
| 47 | + while t % 2 == 0: |
| 48 | + t = t // 2 |
| 49 | + # Cycle through all multiplicative inverses in Zn. |
| 50 | + # The algorithm is non-deterministic, but there is a 50% chance |
| 51 | + # any candidate a leads to successful factoring. |
| 52 | + # See "Digitalized Signatures and Public Key Functions as Intractable |
| 53 | + # as Factorization", M. Rabin, 1979 |
| 54 | + spotted = False |
| 55 | + a = 2 |
| 56 | + while not spotted and a < _MAX_RECOVERY_ATTEMPTS: |
| 57 | + k = t |
| 58 | + # Cycle through all values a^{t*2^i}=a^k |
| 59 | + while k < ktot: |
| 60 | + cand = pow(a, k, n) |
| 61 | + # Check if a^k is a non-trivial root of unity (mod n) |
| 62 | + if cand != 1 and cand != (n - 1) and pow(cand, 2, n) == 1: |
| 63 | + # We have found a number such that (cand-1)(cand+1)=0 (mod n). |
| 64 | + # Either of the terms divides n. |
| 65 | + p = _gcd(cand + 1, n) |
| 66 | + spotted = True |
| 67 | + break |
| 68 | + k *= 2 |
| 69 | + # This value was not any good... let's try another! |
| 70 | + a += 2 |
| 71 | + if not spotted: |
| 72 | + raise ValueError("Unable to compute factors p and q from exponent d.") |
| 73 | + # Found ! |
| 74 | + q, r = divmod(n, p) |
| 75 | + assert r == 0 |
| 76 | + p, q = sorted((p, q), reverse=True) |
| 77 | + return (p, q) |
| 78 | + |
| 79 | + |
| 80 | +def pem_to_spki(pem, fmt='PKCS8'): |
| 81 | + key = RSAKey(pem, ALGORITHMS.RS256) |
| 82 | + return key.to_pem(fmt) |
| 83 | + |
| 84 | + |
| 85 | +class RSAKey(Key): |
| 86 | + SHA256 = 'SHA-256' |
| 87 | + SHA384 = 'SHA-384' |
| 88 | + SHA512 = 'SHA-512' |
| 89 | + |
| 90 | + def __init__(self, key, algorithm): |
| 91 | + if algorithm not in ALGORITHMS.RSA: |
| 92 | + raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm) |
| 93 | + |
| 94 | + self.hash_alg = { |
| 95 | + ALGORITHMS.RS256: self.SHA256, |
| 96 | + ALGORITHMS.RS384: self.SHA384, |
| 97 | + ALGORITHMS.RS512: self.SHA512 |
| 98 | + }.get(algorithm) |
| 99 | + self._algorithm = algorithm |
| 100 | + |
| 101 | + if isinstance(key, dict): |
| 102 | + self._prepared_key = self._process_jwk(key) |
| 103 | + return |
| 104 | + |
| 105 | + if isinstance(key, (pyrsa.PublicKey, pyrsa.PrivateKey)): |
| 106 | + self._prepared_key = key |
| 107 | + return |
| 108 | + |
| 109 | + if isinstance(key, six.string_types): |
| 110 | + key = key.encode('utf-8') |
| 111 | + |
| 112 | + if isinstance(key, six.binary_type): |
| 113 | + try: |
| 114 | + self._prepared_key = pyrsa.PublicKey.load_pkcs1(key) |
| 115 | + except ValueError: |
| 116 | + try: |
| 117 | + self._prepared_key = pyrsa.PublicKey.load_pkcs1_openssl_pem(key) |
| 118 | + except ValueError: |
| 119 | + try: |
| 120 | + self._prepared_key = pyrsa.PrivateKey.load_pkcs1(key) |
| 121 | + except ValueError: |
| 122 | + try: |
| 123 | + # python-rsa does not support PKCS8 yet so we have to manually remove OID |
| 124 | + der = pyrsa_pem.load_pem(key, b'PRIVATE KEY') |
| 125 | + header, der = der[:22], der[22:] |
| 126 | + if header != PKCS8_RSA_HEADER: |
| 127 | + raise ValueError("Invalid PKCS8 header") |
| 128 | + self._prepared_key = pyrsa.PrivateKey._load_pkcs1_der(der) |
| 129 | + except ValueError as e: |
| 130 | + raise JWKError(e) |
| 131 | + return |
| 132 | + raise JWKError('Unable to parse an RSA_JWK from key: %s' % key) |
| 133 | + |
| 134 | + def _process_jwk(self, jwk_dict): |
| 135 | + if not jwk_dict.get('kty') == 'RSA': |
| 136 | + raise JWKError("Incorrect key type. Expected: 'RSA', Recieved: %s" % jwk_dict.get('kty')) |
| 137 | + |
| 138 | + e = base64_to_long(jwk_dict.get('e')) |
| 139 | + n = base64_to_long(jwk_dict.get('n')) |
| 140 | + |
| 141 | + if not 'd' in jwk_dict: |
| 142 | + return pyrsa.PublicKey(e=e, n=n) |
| 143 | + else: |
| 144 | + d = base64_to_long(jwk_dict.get('d')) |
| 145 | + extra_params = ['p', 'q', 'dp', 'dq', 'qi'] |
| 146 | + |
| 147 | + if any(k in jwk_dict for k in extra_params): |
| 148 | + # Precomputed private key parameters are available. |
| 149 | + if not all(k in jwk_dict for k in extra_params): |
| 150 | + # These values must be present when 'p' is according to |
| 151 | + # Section 6.3.2 of RFC7518, so if they are not we raise |
| 152 | + # an error. |
| 153 | + raise JWKError('Precomputed private key parameters are incomplete.') |
| 154 | + |
| 155 | + p = base64_to_long(jwk_dict['p']) |
| 156 | + q = base64_to_long(jwk_dict['q']) |
| 157 | + return pyrsa.PrivateKey(e=e, n=n, d=d, p=p, q=q) |
| 158 | + else: |
| 159 | + p, q = _rsa_recover_prime_factors(n, e, d) |
| 160 | + return pyrsa.PrivateKey(n=n, e=e, d=d, p=p, q=q) |
| 161 | + |
| 162 | + |
| 163 | + |
| 164 | + def sign(self, msg): |
| 165 | + return pyrsa.sign(msg, self._prepared_key, self.hash_alg) |
| 166 | + |
| 167 | + def verify(self, msg, sig): |
| 168 | + try: |
| 169 | + pyrsa.verify(msg, sig, self._prepared_key) |
| 170 | + return True |
| 171 | + except pyrsa.pkcs1.VerificationError: |
| 172 | + return False |
| 173 | + |
| 174 | + def is_public(self): |
| 175 | + return isinstance(self._prepared_key, pyrsa.PublicKey) |
| 176 | + |
| 177 | + def public_key(self): |
| 178 | + if isinstance(self._prepared_key, pyrsa.PublicKey): |
| 179 | + return self |
| 180 | + return self.__class__(pyrsa.PublicKey(n=self._prepared_key.n, e=self._prepared_key.e), self._algorithm) |
| 181 | + |
| 182 | + def to_pem(self, pem_format='PKCS8'): |
| 183 | + |
| 184 | + if isinstance(self._prepared_key, pyrsa.PrivateKey): |
| 185 | + der = self._prepared_key.save_pkcs1(format='DER') |
| 186 | + if pem_format == 'PKCS8': |
| 187 | + pem = pyrsa_pem.save_pem(PKCS8_RSA_HEADER + der, pem_marker='PRIVATE KEY') |
| 188 | + elif pem_format == 'PKCS1': |
| 189 | + pem = pyrsa_pem.save_pem(der, pem_marker='RSA PRIVATE KEY') |
| 190 | + else: |
| 191 | + raise ValueError("Invalid pem format specified: %r" % (pem_format,)) |
| 192 | + else: |
| 193 | + if pem_format == 'PKCS8': |
| 194 | + asn_key = AsnPubKey() |
| 195 | + asn_key.setComponentByName('modulus', self._prepared_key.n) |
| 196 | + asn_key.setComponentByName('publicExponent', self._prepared_key.e) |
| 197 | + der = encoder.encode(asn_key) |
| 198 | + |
| 199 | + header = PubKeyHeader() |
| 200 | + header['oid'] = univ.ObjectIdentifier('1.2.840.113549.1.1.1') |
| 201 | + pub_key = OpenSSLPubKey() |
| 202 | + pub_key['header'] = header |
| 203 | + pub_key['key'] = univ.BitString.fromOctetString(der) |
| 204 | + |
| 205 | + der = encoder.encode(pub_key) |
| 206 | + pem = pyrsa_pem.save_pem(der, pem_marker='PUBLIC KEY') |
| 207 | + elif pem_format == 'PKCS1': |
| 208 | + der = self._prepared_key.save_pkcs1(format='DER') |
| 209 | + pem = pyrsa_pem.save_pem(der, pem_marker='RSA PUBLIC KEY') |
| 210 | + else: |
| 211 | + raise ValueError("Invalid pem format specified: %r" % (pem_format,)) |
| 212 | + return pem |
| 213 | + |
| 214 | + def to_dict(self): |
| 215 | + if not self.is_public(): |
| 216 | + public_key = self.public_key()._prepared_key |
| 217 | + else: |
| 218 | + public_key = self._prepared_key |
| 219 | + |
| 220 | + data = { |
| 221 | + 'alg': self._algorithm, |
| 222 | + 'kty': 'RSA', |
| 223 | + 'n': long_to_base64(public_key.n), |
| 224 | + 'e': long_to_base64(public_key.e), |
| 225 | + } |
| 226 | + |
| 227 | + if not self.is_public(): |
| 228 | + data.update({ |
| 229 | + 'd': long_to_base64(self._prepared_key.d), |
| 230 | + 'p': long_to_base64(self._prepared_key.p), |
| 231 | + 'q': long_to_base64(self._prepared_key.q), |
| 232 | + 'dp': long_to_base64(self._prepared_key.exp1), |
| 233 | + 'dq': long_to_base64(self._prepared_key.exp2), |
| 234 | + 'qi': long_to_base64(self._prepared_key.coef), |
| 235 | + }) |
| 236 | + |
| 237 | + return data |
0 commit comments