diff --git a/pyeudiw/jwt/helper.py b/pyeudiw/jwt/helper.py index 51ffdd1f..9da027b2 100644 --- a/pyeudiw/jwt/helper.py +++ b/pyeudiw/jwt/helper.py @@ -1,3 +1,4 @@ +import re import json from typing import Literal, TypeAlias @@ -18,6 +19,8 @@ KeyLike: TypeAlias = ECKey | RSAKey | OKPKey | SYMKey SerializationFormat = Literal["compact", "json"] +JWT_REGEX = r'^[A-Za-z0-9-_]+\.[A-Za-z0-9-_]+\.[A-Za-z0-9-_]+$' + class JWHelperInterface: def __init__(self, jwks: list[KeyLike | dict] | KeyLike | dict) -> None: @@ -120,6 +123,21 @@ def is_payload_expired(token_payload: dict) -> bool: return True return False +def is_jwt(token: str) -> bool: + """ + Check if a string is a JWT. + + :param token: The string to check. + :type token: str + + :returns: True if the string is a JWT, False otherwise. + :rtype: bool + """ + if not isinstance(token, str): + return False + if re.match(JWT_REGEX, token): + return True + return False def is_jwt_expired(token: str) -> bool: """ diff --git a/pyeudiw/satosa/frontends/openid4vci/endpoints/pushed_authorization_request_endpoint.py b/pyeudiw/satosa/frontends/openid4vci/endpoints/pushed_authorization_request_endpoint.py index 907cb6c2..c38f34d8 100644 --- a/pyeudiw/satosa/frontends/openid4vci/endpoints/pushed_authorization_request_endpoint.py +++ b/pyeudiw/satosa/frontends/openid4vci/endpoints/pushed_authorization_request_endpoint.py @@ -31,6 +31,8 @@ FORM_URLENCODED ) +from pyeudiw.jwt.helper import is_jwt + CLASS_NAME = "ParHandler.pushed_authorization_request_endpoint" class ParHandler(VCIBaseEndpoint): @@ -109,8 +111,15 @@ def endpoint(self, context: Context): request = data.get("request", "").strip() - if request and (self.signed_par_request == "true" or self.signed_par_request == "both"): + if request and self.signed_par_request in ("true", "both"): try: + if not is_jwt(request): + self._log_error( + CLASS_NAME, + f"invalid request parameter for `par`, invalid JWS: {request}" + ) + return self._handle_400(context, "invalid request parameters") + payload = self.jws_helper.verify(request) if not isinstance(payload, dict): @@ -120,6 +129,8 @@ def endpoint(self, context: Context): ) return self._handle_400(context, "invalid request parameters") + payload["jws"] = request + par_request = SignedParRequest.model_validate( payload, context={ ENDPOINT_CTX: "par", @@ -134,7 +145,7 @@ def endpoint(self, context: Context): f"invalid request parameter for `par`, invalid JWS: {request}" ) return self._handle_400(context, "invalid request parameters") - elif (self.signed_par_request == "false" or self.signed_par_request == "both"): + elif self.signed_par_request in ("false", "both"): par_request = ParRequest.model_validate( data, context={ ENDPOINT_CTX: "par", diff --git a/pyeudiw/satosa/frontends/openid4vci/models/par_request.py b/pyeudiw/satosa/frontends/openid4vci/models/par_request.py index d97472f2..f6da56e5 100644 --- a/pyeudiw/satosa/frontends/openid4vci/models/par_request.py +++ b/pyeudiw/satosa/frontends/openid4vci/models/par_request.py @@ -131,6 +131,7 @@ class SignedParRequest(OpenId4VciBaseModel): redirect_uri: str = None jti: str = None issuer_state: str = None + jws: str = None @model_validator(mode='after') def check_par_request(self) -> "ParRequest":