diff --git a/jose/jws.py b/jose/jws.py index bfaf6bd0..f5427ffb 100644 --- a/jose/jws.py +++ b/jose/jws.py @@ -45,7 +45,56 @@ def sign(payload, key, headers=None, algorithm=ALGORITHMS.HS256): return signed_output -def verify(token, key, algorithms, verify=True): +def sign_detached(payload, key, headers=None, algorithm=ALGORITHMS.HS256): + """Signs a claims set and returns a JWS as a detached payload string, as per RFC7797 + + Args: + payload (str or dict): A string to sign + key (str or dict): The key to use for signing the claim set. Can be + individual JWK or JWK set. + headers (dict, optional): A set of headers that will be added to + the default headers. Any headers that are added as additional + headers will override the default headers. + if the signature needs to be generated on encoded payload, then + header has to contain {"b64":True} + algorithm (str, optional): The algorithm to use for signing the + the claims. Defaults to HS256. + + Returns: + str: The string representation of the header, and signature in detached jws format + payload: the payload as received in the request or encoed if {"b4":True} header is passed in the call + + Raises: + JWSError: If there is an error signing the token. + + Examples: + + >>> jws.sign_detached({'a': 'b'}, 'secret', algorithm='HS256') + 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8', {'a': 'b'} + + + >>> jws.sign_detached({'a': 'b'}, 'secret', {"b64": True}, algorithm='HS256') + 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8', eyJhIjoiYiJ9 + + """ + + if algorithm not in ALGORITHMS.SUPPORTED: + raise JWSError("Algorithm %s not supported." % algorithm) + + if headers: + if "b64" in headers and headers["b64"] is True: + payload = _encode_payload(payload) + headers.update({"crit": ["b64"]}) + else: + headers = {"b64": "false"} + + encoded_header = _encode_header(algorithm, additional_headers=headers) + signed_output = _sign_header_and_claims(encoded_header, payload, algorithm, key, True) + + return signed_output, payload + + +def verify(token, key, algorithms=None, verify=True, payload=None): """Verifies a JWS string's signature. Args: @@ -53,9 +102,11 @@ def verify(token, key, algorithms, verify=True): key (str or dict): A key to attempt to verify the payload with. Can be individual JWK or JWK set. algorithms (str or list): Valid algorithms that should be used to verify the JWS. + payload (str or dict): Unencoded payload if the token is a detached jws Returns: str: The str representation of the payload, assuming the signature is valid. + If the token is a detached jws with "b64" true in the header, the return value will be encoded payload Raises: JWSError: If there is an exception verifying a token. @@ -65,9 +116,12 @@ def verify(token, key, algorithms, verify=True): >>> token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8' >>> jws.verify(token, 'secret', algorithms='HS256') + >>> token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8' + >>> jws.verify(token, 'secret', algorithms='HS256', payload={"a":"b"}) + """ - header, payload, signing_input, signature = _load(token) + header, payload, signing_input, signature = _load(token, payload) if verify: _verify_signature(signing_input, header, signature, key, algorithms) @@ -126,7 +180,7 @@ def get_unverified_claims(token): def _encode_header(algorithm, additional_headers=None): - header = {"typ": "JWT", "alg": algorithm} + header = {"typ": "JOSE", "alg": algorithm} if additional_headers: header.update(additional_headers) @@ -153,7 +207,7 @@ def _encode_payload(payload): return base64url_encode(payload) -def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key): +def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key, is_detached=False): signing_input = b".".join([encoded_header, encoded_claims]) try: if not isinstance(key, Key): @@ -164,12 +218,15 @@ def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key): encoded_signature = base64url_encode(signature) - encoded_string = b".".join([encoded_header, encoded_claims, encoded_signature]) + if is_detached: + encoded_string = b"..".join([encoded_header, encoded_signature]) + else: + encoded_string = b".".join([encoded_header, encoded_claims, encoded_signature]) return encoded_string.decode("utf-8") -def _load(jwt): +def _load(jwt, payload=None): if isinstance(jwt, str): jwt = jwt.encode("utf-8") try: @@ -189,10 +246,15 @@ def _load(jwt): if not isinstance(header, Mapping): raise JWSError("Invalid header string: must be a json object") - try: - payload = base64url_decode(claims_segment) - except (TypeError, binascii.Error): - raise JWSError("Invalid payload padding") + if not payload: + try: + payload = base64url_decode(claims_segment) + except (TypeError, binascii.Error): + raise JWSError("Invalid payload padding") + else: + if "b64" in header and header["b64"] is True: + payload = _encode_payload(payload) + signing_input = b"".join([signing_input, payload]) try: signature = base64url_decode(crypto_segment) diff --git a/tests/test_jws.py b/tests/test_jws.py index 01b5fd05..1cf3f43c 100644 --- a/tests/test_jws.py +++ b/tests/test_jws.py @@ -7,6 +7,7 @@ from jose.backends import RSAKey from jose.constants import ALGORITHMS from jose.exceptions import JWSError +from jose.utils import base64url_decode, base64url_encode try: from jose.backends.cryptography_backend import CryptographyRSAKey @@ -132,7 +133,7 @@ def test_add_headers(self, payload): expected_headers = { "test": "header", "alg": "HS256", - "typ": "JWT", + "typ": "JOSE", } token = jws.sign(payload, "secret", headers=additional_headers) @@ -307,6 +308,14 @@ def test_jwk_set_failure(self, jwk_set): with pytest.raises(JWSError): payload = jws.verify(google_id_token, jwk_set, ALGORITHMS.RS256) # noqa: F841 + def test_RSA256_detached(self, payload): + token, payload = jws.sign_detached(payload, rsa_private_key, algorithm=ALGORITHMS.RS256) + assert jws.verify(token, rsa_public_key, payload=payload) == payload + + def test_RSA256_detached_encoded(self, payload): + token, encoded_payload = jws.sign_detached(payload, rsa_private_key, {"b64": True}, algorithm=ALGORITHMS.RS256) + assert jws.verify(token, rsa_public_key, payload=payload) == encoded_payload + def test_RSA256(self, payload): token = jws.sign(payload, rsa_private_key, algorithm=ALGORITHMS.RS256) assert jws.verify(token, rsa_public_key, ALGORITHMS.RS256) == payload