Skip to content

Commit 08f13d8

Browse files
authored
Merge pull request #437 from italia/feat/sign_header_x5c
feat: refactor of jwt key selection for x5c
2 parents 7fdbb7a + e5ab0ba commit 08f13d8

File tree

9 files changed

+218
-54
lines changed

9 files changed

+218
-54
lines changed

pyeudiw/jwk/parse.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import base64
21
from cryptojwt.jwk.ec import import_ec_key, ECKey
32
from cryptojwt.jwk.rsa import RSAKey, import_rsa_key
43
from ssl import DER_cert_to_PEM_cert
@@ -7,6 +6,8 @@
76
from pyeudiw.jwk.exceptions import InvalidJwk
87
from typing import Optional
98

9+
from pyeudiw.x509.verify import B64DER_cert_to_DER_cert
10+
1011
def _parse_rsa_key(pem: str) -> Optional[JWK]:
1112
try:
1213
public_key = import_rsa_key(pem)
@@ -71,7 +72,7 @@ def parse_b64der(b64der: str) -> JWK:
7172
"""
7273
Parse a (public) key from a Base64 encoded DER certificate.
7374
"""
74-
der = base64.b64decode(b64der)
75+
der = B64DER_cert_to_DER_cert(b64der)
7576
return parse_certificate(der)
7677

7778

pyeudiw/jwt/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ def find_self_contained_key(header: dict) -> tuple[set[str], JWK] | None:
8787
candidate_key: JWK | None = None
8888
try:
8989
candidate_key = parse_x5c_keys(header["x5c"])[0]
90+
return set(["5xc"]), candidate_key
9091
except Exception as e:
9192
logger.debug(
9293
f"failed to parse key from x5c chain {header['x5c']}", exc_info=e
9394
)
94-
return set(["5xc"]), candidate_key
9595
if "jwk" in header:
9696
candidate_key = JWK(header["jwk"])
9797
return set(["jwk"]), candidate_key

pyeudiw/jwt/jws_helper.py

Lines changed: 88 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pyeudiw.jwk import JWK
1111
from pyeudiw.jwk.exceptions import KidError
1212
from pyeudiw.jwk.jwks import find_jwk_by_kid, find_jwk_by_thumbprint
13+
from pyeudiw.jwk.parse import parse_b64der
1314
from pyeudiw.jwt.exceptions import (
1415
JWEEncryptionError,
1516
JWSSigningError,
@@ -76,9 +77,12 @@ def sign(
7677
of available keys.
7778
7879
If the header already contains indication of a key, such as 'kid',
79-
'trust_chain' and 'x5c', there is no guarantee that the signing
80-
key to be used will be aligned with those header. We assume that is
81-
it responsibility of the class initiator to make those checks.
80+
'trust_chain' and 'x5c', the method will attempt to match the signing
81+
key among the available keys based on such claims, but there is no
82+
guarantee that the correct key will be selected. We assume that is
83+
it responsibility of the class initiator to make those checks. To
84+
avoid any possible ambiguity, it is suggested to initilize the class
85+
with one (signing) key only.
8286
8387
:param plain_dict: The payload to be signed.
8488
:param protected: Protected header for the JWS.
@@ -106,13 +110,10 @@ def sign(
106110
if signing_key["kty"] == "oct":
107111
raise JWSSigningError(f"Key {signing_key['kid']} is a symmetric key")
108112

109-
# Ensure the key ID in the header matches the signing key
110-
header_kid = protected.get("kid")
111-
signer_kid = signing_key.get("kid")
112-
if header_kid and signer_kid and (header_kid != signer_kid):
113-
raise JWSSigningError(
114-
f"token header contains a kid {header_kid} that does not match the signing key kid {signer_kid}"
115-
)
113+
try:
114+
_validate_key_with_jws_header(signing_key, protected, unprotected)
115+
except Exception as e:
116+
raise JWSSigningError(f"failed to validate signing key: it's content it not valid for current header claims: {e}", e)
116117

117118
payload = serialize_payload(plain_dict)
118119

@@ -125,6 +126,8 @@ def sign(
125126
protected["typ"] = "sd-jwt" if self.is_sd_jwt(plain_dict) else "JWT"
126127

127128
# Include the signing key's kid in the header if required
129+
header_kid = protected.get("kid")
130+
signer_kid = signing_key.get("kid")
128131
if kid_in_header and signer_kid:
129132
# note that is actually redundant as the underlying library auto-update the header with the kid
130133
protected["kid"] = signer_kid
@@ -177,9 +180,12 @@ def _select_signing_key(
177180
# Case 2: only one *singing* key
178181
if signing_key := self._select_key_by_use(use="sig"):
179182
return signing_key
180-
# Case 3: match key by kid: this goes beyond what promised on the method definition
183+
# Case 3: match key by kid
181184
if signing_key := self._select_key_by_kid(headers):
182185
return signing_key
186+
# Case 4: match key by x5c
187+
if signing_key := self._select_key_by_x5c(headers):
188+
return signing_key
183189
raise JWSSigningError(
184190
"signing error: not possible to uniquely determine the signing key"
185191
)
@@ -199,7 +205,7 @@ def _select_key_by_use(self, use: str) -> dict | None:
199205
return candidate_signing_keys[0]
200206
return None
201207

202-
def _select_key_by_kid(self, headers: tuple[dict, dict]) -> dict | None:
208+
def _select_key_by_kid(self, headers: tuple[dict[str, Any], dict[str, Any]]) -> dict | None:
203209
if not headers:
204210
return None
205211
if "kid" in headers[0]:
@@ -210,6 +216,19 @@ def _select_key_by_kid(self, headers: tuple[dict, dict]) -> dict | None:
210216
return None
211217
return find_jwk_by_kid([key.to_dict() for key in self.jwks], kid)
212218

219+
def _select_key_by_x5c(self, headers: tuple[dict[str, Any], dict[str, Any]]) -> dict | None:
220+
if not headers:
221+
return None
222+
x5c: list[str] | None = headers[0].get("x5c") or headers[1].get("x5c")
223+
if not x5c:
224+
return None
225+
header_jwk = parse_b64der(x5c[0])
226+
for key in self.jwks:
227+
key_d = key.to_dict()
228+
if JWK(key_d).thumbprint == header_jwk.thumbprint:
229+
return key_d
230+
return None
231+
213232
def verify(
214233
self, jwt: str, tolerance_s: int = DEFAULT_TOKEN_TIME_TOLERANCE
215234
) -> str | Any | bytes:
@@ -320,3 +339,60 @@ def is_sd_jwt(self, token: str) -> bool:
320339
# Log or handle errors (optional)
321340
logger.warning(f"Unable to determine if token is SD-JWT: {e}")
322341
return False
342+
343+
344+
def _validate_key_with_header_kid(key: dict, header: dict) -> None:
345+
"""
346+
:raises Exception: if the key is not compatible with the header content kid (if any)
347+
"""
348+
if (key_kid := key.get("kid")) and (header_kid := header.get("kid")) and (key_kid != header_kid):
349+
raise Exception(
350+
f"token header contains a kid {header_kid} that does not match the signing key kid {key_kid}"
351+
)
352+
return
353+
354+
355+
def _validate_key_with_header_x5c(key: dict, header: dict) -> None:
356+
"""
357+
Validate that a key has a public component that matches what defined in
358+
the x5c leaf certificate in the header (if any).
359+
Note that this method DOES NOT validate the chain. Instead, it actually
360+
checks that the leaf of the chain has the same cryptographic material
361+
of the argument key.
362+
363+
:raises Exception: if the key is not compatible with the header content x5c (if any)
364+
"""
365+
x5c: list[str] | None = header.get("x5c")
366+
if not x5c:
367+
return
368+
leaf_cert: str = x5c[0]
369+
370+
# if the key has a certificate, check the cert, otherwise check the public material
371+
key_x5c: list[str] | None = key.get("x5c")
372+
if key_x5c:
373+
if leaf_cert != (leaf_x5c_cert := key_x5c[0]):
374+
raise Exception(
375+
f"token header containes a chain whose leaf certificate {leaf_cert} does not match the signing key leaf certificate {leaf_x5c_cert}"\
376+
)
377+
return
378+
header_key = parse_b64der(leaf_cert)
379+
if header_key.thumbprint != JWK(key).thumbprint:
380+
raise Exception(
381+
f"public material of the key does not matches the key in the leaf certificate {leaf_cert}"
382+
)
383+
return
384+
385+
386+
def _validate_key_with_jws_header(key: dict, protected_jws_header: dict, unprotected_jws_header: dict) -> None:
387+
"""
388+
Validate that a key used for some operations (sign, verify) on a token
389+
is compatible with the token header itself.
390+
391+
:raises Exception: if the key is not compatible with the token header
392+
"""
393+
header = deepcopy(protected_jws_header)
394+
header.update(unprotected_jws_header)
395+
# NOTE: consistency with usage claims such as 'alg', 'kty' and 'use'
396+
# are done by the signer library and are not required here
397+
_validate_key_with_header_kid(key, header)
398+
_validate_key_with_header_x5c(key, header)

pyeudiw/openid4vp/presentation_submission/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def validate(
137137

138138
for descriptor in validated_submission.descriptor_map:
139139
handler = self.handlers.get(descriptor.format)
140-
141140
if not handler:
142141
raise MissingHandler(f"Handler for format '{descriptor.format}' not found.")
143142

pyeudiw/satosa/default/request_handler.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -75,27 +75,11 @@ def request_endpoint(self, context: Context, *args) -> Response:
7575
trust_params = self.trust_evaluator.get_jwt_header_trust_parameters(issuer=self.client_id)
7676
_protected_jwt_headers.update(trust_params)
7777

78-
metadata_key = None
79-
80-
if "x5c" in _protected_jwt_headers:
81-
# TODO: move this logic in the JWS signer...
82-
jwk = parse_b64der(_protected_jwt_headers["x5c"][0])
83-
84-
for key in self.config["metadata_jwks"]:
85-
if JWK(key).thumbprint == jwk.thumbprint:
86-
metadata_key = key
87-
break
88-
89-
if not metadata_key:
90-
return self._handle_500(
91-
context,
92-
"internal error: unable to find the key in the metadata",
93-
ValueError("unable to find the key in the metadata"),
94-
)
78+
if ("x5c" in _protected_jwt_headers) or ("kid" in _protected_jwt_headers):
79+
# let helper decide which key best fit the given header, otherise use default hich is the first confgiured key
80+
helper = JWSHelper(self.config["metadata_jwks"])
9581
else:
96-
metadata_key = self.default_metadata_private_jwk
97-
98-
helper = JWSHelper(metadata_key)
82+
helper = JWSHelper(self.default_metadata_private_jwk)
9983

10084
try:
10185
request_object_jwt = helper.sign(

pyeudiw/tests/jwt/test_helper.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
1+
from cryptography.hazmat.primitives.asymmetric import ec
2+
from cryptojwt.jwk.ec import ECKey
3+
4+
from pyeudiw.jwk import JWK
15
from pyeudiw.jwt.helper import validate_jwt_timestamps_claims
6+
from pyeudiw.jwt.jws_helper import _validate_key_with_jws_header
27
from pyeudiw.tools.utils import iat_now
8+
import pyeudiw.tests.x509.test_x509 as test_x509
9+
from pyeudiw.x509.verify import DER_cert_to_B64DER_cert
310

411

512
def test_validate_jwt_timestamps_claims_ok():
@@ -71,3 +78,63 @@ def test_test_validate_jwt_timestamps_claims_tolerance_window():
7178
assert (
7279
False
7380
), f"encountered unexpeted error when validating the lifetime of a token payload with a tolerance window (for exp): {e}"
81+
82+
83+
def test_validate_key_with_jws_header_x5c_ok():
84+
private_ec_key = ec.generate_private_key(ec.SECP256R1())
85+
x509_der_chain = test_x509.gen_chain(leaf_private_key=private_ec_key)
86+
x5c = [DER_cert_to_B64DER_cert(der) for der in x509_der_chain]
87+
88+
ec_jwk = ECKey()
89+
ec_jwk.load_key(private_ec_key)
90+
key = ec_jwk.serialize(private=True)
91+
92+
try:
93+
_validate_key_with_jws_header(key, {"x5c": x5c}, {})
94+
assert True
95+
except Exception as e:
96+
assert False, f"unexpected exception when validating header for correct key: {e}"
97+
98+
99+
def test_validate_key_with_jws_header_kid_ok():
100+
key = JWK().as_dict()
101+
kid = "1234567890"
102+
key["kid"] = kid
103+
104+
try:
105+
_validate_key_with_jws_header(key, {"kid": kid}, {})
106+
assert True
107+
except Exception as e:
108+
assert False, f"unexpected exception when validating header for correct key: {e}"
109+
110+
111+
def test_validate_key_with_jws_header_expect_x5c_fail():
112+
private_ec_key = ec.generate_private_key(ec.SECP256R1())
113+
x509_der_chain = test_x509.gen_chain(leaf_private_key=private_ec_key)
114+
x5c = [DER_cert_to_B64DER_cert(der) for der in x509_der_chain]
115+
116+
wrong_ec_key = ec.generate_private_key(ec.SECP256R1())
117+
wrong_ec_jwk = ECKey()
118+
wrong_ec_jwk.load_key(wrong_ec_key)
119+
wrong_key = wrong_ec_jwk.serialize(private=True)
120+
121+
try:
122+
_validate_key_with_jws_header(wrong_key, {"x5c": x5c}, {})
123+
assert False, f"should have encountered exception when validating header 'x5c' for wrong key"
124+
except Exception as _:
125+
assert True
126+
127+
def test_validate_key_with_jws_header_expect_kid_fail():
128+
wrong_key = JWK().as_dict()
129+
wrong_kid = "1234567890"
130+
wrong_key["kid"] = wrong_kid
131+
132+
key = JWK().as_dict()
133+
kid = "qwertyuiop"
134+
key["kid"] = kid
135+
136+
try:
137+
_validate_key_with_jws_header(key, {"kid": "1234567890"}, {})
138+
assert False, f"should have encountered exception when validating header 'kid' for wrong key"
139+
except Exception as _:
140+
assert True

pyeudiw/tests/jwt/test_sign_verify.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
from cryptography.hazmat.primitives.asymmetric import ec
2+
from cryptojwt.jwk.ec import ECKey
13
import pytest
24

35
from pyeudiw.jwt.jws_helper import DEFAULT_TOKEN_TIME_TOLERANCE, JWSHelper
46
from pyeudiw.jwt.utils import decode_jwt_header
7+
import pyeudiw.tests.x509.test_x509 as test_x509
58
from pyeudiw.tools.utils import iat_now
9+
from pyeudiw.x509.verify import DER_cert_to_B64DER_cert
610

711

812
class TestJWSHeperSelectSigningKey:
@@ -49,6 +53,19 @@ def test_JWSHelper_select_signing_key_infer_kid(self, sign_jwks):
4953
k = signer._select_signing_key(({"kid": exp_k["kid"]}, {}))
5054
assert k == exp_k
5155

56+
def test_JWSHelper_select_signing_key_infer_kid(self, sign_jwks: list[dict]):
57+
new_private_ec_key = ec.generate_private_key(ec.SECP256R1())
58+
x509_der_chain = test_x509.gen_chain(leaf_private_key=new_private_ec_key)
59+
x5c = [DER_cert_to_B64DER_cert(der) for der in x509_der_chain]
60+
new_ec_jwk = ECKey()
61+
new_ec_jwk.load_key(new_private_ec_key)
62+
exp_key: dict = new_ec_jwk.serialize(private=True)
63+
sign_jwks.append(exp_key)
64+
65+
signer = JWSHelper(sign_jwks)
66+
obt_key = signer._select_signing_key(({"x5c": x5c}, {}))
67+
assert exp_key == obt_key
68+
5269
def test_JWSHelper_select_signing_key_unique(self, sign_jwks):
5370
signer = JWSHelper(sign_jwks[0])
5471
exp_k = sign_jwks[0]

pyeudiw/tests/satosa/test_backend.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from pyeudiw.jwt.jwe_helper import JWEHelper
5555
from pyeudiw.satosa.utils.response import JsonResponse
5656
from pyeudiw.tests.x509.test_x509 import gen_chain
57-
from pyeudiw.x509.verify import to_pem_list
57+
from pyeudiw.x509.verify import PEM_cert_to_B64DER_cert, to_pem_list
5858
from pyeudiw.jwk.parse import parse_pem
5959

6060
PKEY = {
@@ -192,7 +192,7 @@ def _generate_payload(self, issuer_jwk, holder_jwk, nonce, state, aud, x509=Fals
192192

193193
if x509:
194194
additional_headers = {
195-
"x5c": self.chain
195+
"x5c": [PEM_cert_to_B64DER_cert(pem) for pem in self.chain]
196196
}
197197
else:
198198
additional_headers = {
@@ -622,7 +622,6 @@ def test_response_endpoint_x5c_chain(self, context):
622622
}
623623
context.request_method = "POST"
624624
context.http_headers = {"HTTP_CONTENT_TYPE": "application/x-www-form-urlencoded"}
625-
626625
response_endpoint = self.backend.response_endpoint(context)
627626
assert response_endpoint.status == "200"
628627
assert "redirect_uri" in response_endpoint.message

0 commit comments

Comments
 (0)