1010from pyeudiw .jwk import JWK
1111from pyeudiw .jwk .exceptions import KidError
1212from pyeudiw .jwk .jwks import find_jwk_by_kid , find_jwk_by_thumbprint
13+ from pyeudiw .jwk .parse import parse_b64der
1314from pyeudiw .jwt .exceptions import (
1415 JWEEncryptionError ,
1516 JWSSigningError ,
@@ -76,9 +77,12 @@ def sign(
7677 of available keys.
7778
7879 If the header already contains indication of a key, such as 'kid',
79- 'trust_chain' and 'x5c', there is no guarantee that the signing
80- key to be used will be aligned with those header. We assume that is
81- it responsibility of the class initiator to make those checks.
80+ 'trust_chain' and 'x5c', the method will attempt to match the signing
81+ key among the available keys based on such claims, but there is no
82+ guarantee that the correct key will be selected. We assume that is
83+ it responsibility of the class initiator to make those checks. To
84+ avoid any possible ambiguity, it is suggested to initilize the class
85+ with one (signing) key only.
8286
8387 :param plain_dict: The payload to be signed.
8488 :param protected: Protected header for the JWS.
@@ -106,13 +110,10 @@ def sign(
106110 if signing_key ["kty" ] == "oct" :
107111 raise JWSSigningError (f"Key { signing_key ['kid' ]} is a symmetric key" )
108112
109- # Ensure the key ID in the header matches the signing key
110- header_kid = protected .get ("kid" )
111- signer_kid = signing_key .get ("kid" )
112- if header_kid and signer_kid and (header_kid != signer_kid ):
113- raise JWSSigningError (
114- f"token header contains a kid { header_kid } that does not match the signing key kid { signer_kid } "
115- )
113+ try :
114+ _validate_key_with_jws_header (signing_key , protected , unprotected )
115+ except Exception as e :
116+ raise JWSSigningError (f"failed to validate signing key: it's content it not valid for current header claims: { e } " , e )
116117
117118 payload = serialize_payload (plain_dict )
118119
@@ -125,6 +126,8 @@ def sign(
125126 protected ["typ" ] = "sd-jwt" if self .is_sd_jwt (plain_dict ) else "JWT"
126127
127128 # Include the signing key's kid in the header if required
129+ header_kid = protected .get ("kid" )
130+ signer_kid = signing_key .get ("kid" )
128131 if kid_in_header and signer_kid :
129132 # note that is actually redundant as the underlying library auto-update the header with the kid
130133 protected ["kid" ] = signer_kid
@@ -177,9 +180,12 @@ def _select_signing_key(
177180 # Case 2: only one *singing* key
178181 if signing_key := self ._select_key_by_use (use = "sig" ):
179182 return signing_key
180- # Case 3: match key by kid: this goes beyond what promised on the method definition
183+ # Case 3: match key by kid
181184 if signing_key := self ._select_key_by_kid (headers ):
182185 return signing_key
186+ # Case 4: match key by x5c
187+ if signing_key := self ._select_key_by_x5c (headers ):
188+ return signing_key
183189 raise JWSSigningError (
184190 "signing error: not possible to uniquely determine the signing key"
185191 )
@@ -199,7 +205,7 @@ def _select_key_by_use(self, use: str) -> dict | None:
199205 return candidate_signing_keys [0 ]
200206 return None
201207
202- def _select_key_by_kid (self , headers : tuple [dict , dict ]) -> dict | None :
208+ def _select_key_by_kid (self , headers : tuple [dict [ str , Any ], dict [ str , Any ] ]) -> dict | None :
203209 if not headers :
204210 return None
205211 if "kid" in headers [0 ]:
@@ -210,6 +216,19 @@ def _select_key_by_kid(self, headers: tuple[dict, dict]) -> dict | None:
210216 return None
211217 return find_jwk_by_kid ([key .to_dict () for key in self .jwks ], kid )
212218
219+ def _select_key_by_x5c (self , headers : tuple [dict [str , Any ], dict [str , Any ]]) -> dict | None :
220+ if not headers :
221+ return None
222+ x5c : list [str ] | None = headers [0 ].get ("x5c" ) or headers [1 ].get ("x5c" )
223+ if not x5c :
224+ return None
225+ header_jwk = parse_b64der (x5c [0 ])
226+ for key in self .jwks :
227+ key_d = key .to_dict ()
228+ if JWK (key_d ).thumbprint == header_jwk .thumbprint :
229+ return key_d
230+ return None
231+
213232 def verify (
214233 self , jwt : str , tolerance_s : int = DEFAULT_TOKEN_TIME_TOLERANCE
215234 ) -> str | Any | bytes :
@@ -320,3 +339,60 @@ def is_sd_jwt(self, token: str) -> bool:
320339 # Log or handle errors (optional)
321340 logger .warning (f"Unable to determine if token is SD-JWT: { e } " )
322341 return False
342+
343+
344+ def _validate_key_with_header_kid (key : dict , header : dict ) -> None :
345+ """
346+ :raises Exception: if the key is not compatible with the header content kid (if any)
347+ """
348+ if (key_kid := key .get ("kid" )) and (header_kid := header .get ("kid" )) and (key_kid != header_kid ):
349+ raise Exception (
350+ f"token header contains a kid { header_kid } that does not match the signing key kid { key_kid } "
351+ )
352+ return
353+
354+
355+ def _validate_key_with_header_x5c (key : dict , header : dict ) -> None :
356+ """
357+ Validate that a key has a public component that matches what defined in
358+ the x5c leaf certificate in the header (if any).
359+ Note that this method DOES NOT validate the chain. Instead, it actually
360+ checks that the leaf of the chain has the same cryptographic material
361+ of the argument key.
362+
363+ :raises Exception: if the key is not compatible with the header content x5c (if any)
364+ """
365+ x5c : list [str ] | None = header .get ("x5c" )
366+ if not x5c :
367+ return
368+ leaf_cert : str = x5c [0 ]
369+
370+ # if the key has a certificate, check the cert, otherwise check the public material
371+ key_x5c : list [str ] | None = key .get ("x5c" )
372+ if key_x5c :
373+ if leaf_cert != (leaf_x5c_cert := key_x5c [0 ]):
374+ raise Exception (
375+ f"token header containes a chain whose leaf certificate { leaf_cert } does not match the signing key leaf certificate { leaf_x5c_cert } " \
376+ )
377+ return
378+ header_key = parse_b64der (leaf_cert )
379+ if header_key .thumbprint != JWK (key ).thumbprint :
380+ raise Exception (
381+ f"public material of the key does not matches the key in the leaf certificate { leaf_cert } "
382+ )
383+ return
384+
385+
386+ def _validate_key_with_jws_header (key : dict , protected_jws_header : dict , unprotected_jws_header : dict ) -> None :
387+ """
388+ Validate that a key used for some operations (sign, verify) on a token
389+ is compatible with the token header itself.
390+
391+ :raises Exception: if the key is not compatible with the token header
392+ """
393+ header = deepcopy (protected_jws_header )
394+ header .update (unprotected_jws_header )
395+ # NOTE: consistency with usage claims such as 'alg', 'kty' and 'use'
396+ # are done by the signer library and are not required here
397+ _validate_key_with_header_kid (key , header )
398+ _validate_key_with_header_x5c (key , header )
0 commit comments