diff --git a/jose/backends/cryptography_backend.py b/jose/backends/cryptography_backend.py index abd24260..1117e106 100644 --- a/jose/backends/cryptography_backend.py +++ b/jose/backends/cryptography_backend.py @@ -15,7 +15,7 @@ from cryptography.x509 import load_pem_x509_certificate from ..constants import ALGORITHMS -from ..exceptions import JWEError, JWKError +from ..exceptions import JWEError, JWKError, JWKAlgMismatchError from ..utils import base64_to_long, base64url_decode, base64url_encode, ensure_binary, long_to_base64 from .base import Key @@ -52,7 +52,7 @@ class CryptographyECKey(Key): def __init__(self, key, algorithm, cryptography_backend=default_backend): if algorithm not in ALGORITHMS.EC: - raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) + raise JWKAlgMismatchError("%s is not a valid EC algorithm" % algorithm) self.hash_alg = { ALGORITHMS.ES256: self.SHA256, @@ -97,7 +97,7 @@ def __init__(self, key, algorithm, cryptography_backend=default_backend): def _process_jwk(self, jwk_dict): if not jwk_dict.get("kty") == "EC": - raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty")) + raise JWKAlgMismatchError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty")) if not all(k in jwk_dict for k in ["x", "y", "crv"]): raise JWKError("Mandatory parameters are missing") @@ -226,7 +226,7 @@ class CryptographyRSAKey(Key): def __init__(self, key, algorithm, cryptography_backend=default_backend): if algorithm not in ALGORITHMS.RSA: - raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) + raise JWKAlgMismatchError("%s is not a valid RSA algorithm" % algorithm) self.hash_alg = { ALGORITHMS.RS256: self.SHA256, @@ -273,7 +273,7 @@ def __init__(self, key, algorithm, cryptography_backend=default_backend): def _process_jwk(self, jwk_dict): if not jwk_dict.get("kty") == "RSA": - raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty")) + raise JWKAlgMismatchError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty")) e = base64_to_long(jwk_dict.get("e", 256)) n = base64_to_long(jwk_dict.get("n")) @@ -441,9 +441,9 @@ class CryptographyAESKey(Key): def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.AES: - raise JWKError("%s is not a valid AES algorithm" % algorithm) + raise JWKAlgMismatchError("%s is not a valid AES algorithm" % algorithm) if algorithm not in ALGORITHMS.SUPPORTED.union(ALGORITHMS.AES_PSEUDO): - raise JWKError("%s is not a supported algorithm" % algorithm) + raise JWKAlgMismatchError("%s is not a supported algorithm" % algorithm) self._algorithm = algorithm self._mode = self.MODES.get(self._algorithm) @@ -538,7 +538,7 @@ class CryptographyHMACKey(Key): def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.HMAC: - raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) + raise JWKAlgMismatchError("hash_alg: %s is not a valid hash algorithm" % algorithm) self._algorithm = algorithm self._hash_alg = self.ALG_MAP.get(algorithm) @@ -569,7 +569,7 @@ def __init__(self, key, algorithm): def _process_jwk(self, jwk_dict): if not jwk_dict.get("kty") == "oct": - raise JWKError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty")) + raise JWKAlgMismatchError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty")) k = jwk_dict.get("k") k = k.encode("utf-8") diff --git a/jose/backends/ecdsa_backend.py b/jose/backends/ecdsa_backend.py index 756c7ea8..ecb5aac6 100644 --- a/jose/backends/ecdsa_backend.py +++ b/jose/backends/ecdsa_backend.py @@ -4,7 +4,7 @@ from jose.backends.base import Key from jose.constants import ALGORITHMS -from jose.exceptions import JWKError +from jose.exceptions import JWKError, JWKAlgMismatchError from jose.utils import base64_to_long, long_to_base64 @@ -35,7 +35,7 @@ class ECDSAECKey(Key): def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.EC: - raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) + raise JWKAlgMismatchError("%s is not a valid EC algorithm" % algorithm) self.hash_alg = { ALGORITHMS.ES256: self.SHA256, @@ -75,7 +75,7 @@ def __init__(self, key, algorithm): def _process_jwk(self, jwk_dict): if not jwk_dict.get("kty") == "EC": - raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty")) + raise JWKAlgMismatchError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get("kty")) if not all(k in jwk_dict for k in ["x", "y", "crv"]): raise JWKError("Mandatory parameters are missing") diff --git a/jose/backends/native.py b/jose/backends/native.py index eb3a6ae3..7661e2c5 100644 --- a/jose/backends/native.py +++ b/jose/backends/native.py @@ -4,7 +4,7 @@ from jose.backends.base import Key from jose.constants import ALGORITHMS -from jose.exceptions import JWKError +from jose.exceptions import JWKError, JWKAlgMismatchError from jose.utils import base64url_decode, base64url_encode @@ -22,7 +22,7 @@ class HMACKey(Key): def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.HMAC: - raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) + raise JWKAlgMismatchError("hash_alg: %s is not a valid hash algorithm" % algorithm) self._algorithm = algorithm self._hash_alg = self.HASHES.get(algorithm) @@ -53,7 +53,7 @@ def __init__(self, key, algorithm): def _process_jwk(self, jwk_dict): if not jwk_dict.get("kty") == "oct": - raise JWKError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty")) + raise JWKAlgMismatchError("Incorrect key type. Expected: 'oct', Received: %s" % jwk_dict.get("kty")) k = jwk_dict.get("k") k = k.encode("utf-8") diff --git a/jose/backends/rsa_backend.py b/jose/backends/rsa_backend.py index 4e8ccf1c..c908e4b3 100644 --- a/jose/backends/rsa_backend.py +++ b/jose/backends/rsa_backend.py @@ -13,7 +13,7 @@ ) from jose.backends.base import Key from jose.constants import ALGORITHMS -from jose.exceptions import JWEError, JWKError +from jose.exceptions import JWEError, JWKError, JWKAlgMismatchError from jose.utils import base64_to_long, long_to_base64 ALGORITHMS.SUPPORTED.remove(ALGORITHMS.RSA_OAEP) # RSA OAEP not supported @@ -124,7 +124,7 @@ class RSAKey(Key): def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.RSA: - raise JWKError("hash_alg: %s is not a valid hash algorithm" % algorithm) + raise JWKAlgMismatchError("%s is not a valid RSA algorithm" % algorithm) if algorithm in ALGORITHMS.RSA_KW and algorithm != ALGORITHMS.RSA1_5: raise JWKError("alg: %s is not supported by the RSA backend" % algorithm) @@ -174,7 +174,7 @@ def __init__(self, key, algorithm): def _process_jwk(self, jwk_dict): if not jwk_dict.get("kty") == "RSA": - raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty")) + raise JWKAlgMismatchError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get("kty")) e = base64_to_long(jwk_dict.get("e")) n = base64_to_long(jwk_dict.get("n")) diff --git a/jose/exceptions.py b/jose/exceptions.py index e8edc3b6..2099208d 100644 --- a/jose/exceptions.py +++ b/jose/exceptions.py @@ -30,6 +30,11 @@ class JWKError(JOSEError): pass +class JWKAlgMismatchError(JWKError): + '''JWK Key type doesn't support the given algorithm.''' + pass + + class JWEError(JOSEError): """Base error for all JWE errors""" diff --git a/jose/jws.py b/jose/jws.py index bfaf6bd0..944130e1 100644 --- a/jose/jws.py +++ b/jose/jws.py @@ -5,7 +5,7 @@ from jose import jwk from jose.backends.base import Key from jose.constants import ALGORITHMS -from jose.exceptions import JWSError, JWSSignatureError +from jose.exceptions import JWSError, JWSSignatureError, JWKAlgMismatchError from jose.utils import base64url_decode, base64url_encode @@ -205,7 +205,10 @@ def _load(jwt): def _sig_matches_keys(keys, signing_input, signature, alg): for key in keys: if not isinstance(key, Key): - key = jwk.construct(key, alg) + try: + key = jwk.construct(key, alg) + except JWKAlgMismatchError: + continue try: if key.verify(signing_input, signature): return True diff --git a/tests/test_jws.py b/tests/test_jws.py index 01b5fd05..75a7398d 100644 --- a/tests/test_jws.py +++ b/tests/test_jws.py @@ -6,7 +6,7 @@ from jose import jwk, jws from jose.backends import RSAKey from jose.constants import ALGORITHMS -from jose.exceptions import JWSError +from jose.exceptions import JWSError, JWKAlgMismatchError try: from jose.backends.cryptography_backend import CryptographyRSAKey @@ -25,6 +25,16 @@ def test_unicode_token(self): token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8" jws.verify(token, "secret", ["HS256"]) + def test_hetero_keys(self): + class BadKey(jwk.Key): + def __init__(self, key, algorithm): + if key != "xyzw": + raise JWKAlgMismatchError("%s is not a valid XYZW algorithm" % algorithm) + + token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8" + jwk.register_key("XYZW", BadKey) + jws.verify(token, {"keys": [{"alg": "XYZW"}, "secret"]}, ["XYZW", "HS256"]) + def test_multiple_keys(self): old_jwk_verify = jwk.HMACKey.verify try: