Skip to content

Commit 8fb73f3

Browse files
authored
feat: create verify jwt (#3)
1 parent a6af4ec commit 8fb73f3

File tree

5 files changed

+300
-5
lines changed

5 files changed

+300
-5
lines changed

mcpauth/exceptions.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,15 +146,13 @@ def to_json(self, show_cause: bool = False) -> Dict[str, Optional[str]]:
146146
class MCPAuthJwtVerificationExceptionCode(str, Enum):
147147
INVALID_JWT = "invalid_jwt"
148148
JWT_VERIFICATION_FAILED = "jwt_verification_failed"
149-
JWT_EXPIRED = "jwt_expired"
150149

151150

152151
jwt_verification_exception_description: Dict[
153152
MCPAuthJwtVerificationExceptionCode, str
154153
] = {
155154
MCPAuthJwtVerificationExceptionCode.INVALID_JWT: "The provided JWT is invalid or malformed.",
156155
MCPAuthJwtVerificationExceptionCode.JWT_VERIFICATION_FAILED: "JWT verification failed. The token could not be verified.",
157-
MCPAuthJwtVerificationExceptionCode.JWT_EXPIRED: "The provided JWT has expired.",
158156
}
159157

160158

mcpauth/middleware/create_bearer_auth.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,6 @@ def get_bearer_token_from_headers(headers: Headers) -> str:
5151

5252
auth_header = headers.get("authorization") or headers.get("Authorization")
5353

54-
print(f"Authorization header: {auth_header}")
55-
5654
if not auth_header:
5755
raise MCPAuthBearerAuthException(BearerAuthExceptionCode.MISSING_AUTH_HEADER)
5856

mcpauth/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class AuthInfo(BaseModel):
7575
- https://datatracker.ietf.org/doc/html/rfc8707
7676
"""
7777

78-
claims: Optional[Dict[str, Any]]
78+
claims: Dict[str, Any]
7979
"""
8080
The raw claims from the token, which can include any additional information provided by the
8181
token issuer.

mcpauth/utils/create_verify_jwt.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from typing import Annotated, Any, List, Optional, Union
2+
from jwt import PyJWK, PyJWKClient, PyJWTError, decode
3+
from pydantic import BaseModel, StringConstraints, ValidationError
4+
from ..types import AuthInfo, VerifyAccessTokenFunction
5+
from ..exceptions import (
6+
MCPAuthJwtVerificationException,
7+
MCPAuthJwtVerificationExceptionCode,
8+
)
9+
10+
NonEmptyString = Annotated[str, StringConstraints(min_length=1)]
11+
12+
13+
class JwtBaseModel(BaseModel):
14+
aud: Optional[Union[NonEmptyString, List[NonEmptyString]]] = None
15+
iss: NonEmptyString
16+
client_id: NonEmptyString
17+
sub: NonEmptyString
18+
scope: Optional[Union[str, List[str]]] = None
19+
scopes: Optional[Union[str, List[str]]] = None
20+
exp: Optional[int] = None
21+
22+
23+
def create_verify_jwt(
24+
input: Union[str, PyJWKClient, PyJWK],
25+
algorithms: List[str] = ["RS256", "PS256", "ES256", "ES384", "ES512"],
26+
leeway: int = 60,
27+
options: dict[str, Any] = {},
28+
) -> VerifyAccessTokenFunction:
29+
"""
30+
Creates a JWT verification function using the provided JWKS URI.
31+
32+
:param input: Supports one of the following:
33+
- A JWKS URI (string) that points to a JSON Web Key Set.
34+
- An instance of `PyJWKClient` that has been initialized with the JWKS URI.
35+
- An instance of `PyJWK` that represents a single JWK.
36+
:param algorithms: A list of acceptable algorithms for verifying the JWT signature.
37+
:param leeway: The amount of leeway (in seconds) to allow when checking the expiration time of the JWT.
38+
:param options: Additional options to pass to the JWT decode function (`jwt.decode`).
39+
:return: A function that can be used to verify JWTs.
40+
"""
41+
42+
jwks = (
43+
input
44+
if isinstance(input, PyJWKClient)
45+
else (
46+
PyJWKClient(
47+
input, headers={"user-agent": "@mcp-auth/python", "accept": "*/*"}
48+
)
49+
if isinstance(input, str)
50+
else input
51+
)
52+
)
53+
54+
def verify_jwt(token: str) -> AuthInfo:
55+
try:
56+
signing_key = (
57+
jwks.get_signing_key_from_jwt(token)
58+
if isinstance(jwks, PyJWKClient)
59+
else jwks
60+
)
61+
decoded = decode(
62+
token,
63+
signing_key.key,
64+
algorithms=algorithms,
65+
leeway=leeway,
66+
options={
67+
"verify_aud": False,
68+
"verify_iss": False,
69+
}
70+
| options,
71+
)
72+
base_model = JwtBaseModel(**decoded)
73+
scopes = base_model.scope or base_model.scopes
74+
return AuthInfo(
75+
token=token,
76+
issuer=base_model.iss,
77+
client_id=base_model.client_id,
78+
subject=base_model.sub,
79+
audience=base_model.aud,
80+
scopes=(scopes.split(" ") if isinstance(scopes, str) else scopes) or [],
81+
expires_at=base_model.exp,
82+
claims=decoded,
83+
)
84+
except (PyJWTError, ValidationError) as e:
85+
raise MCPAuthJwtVerificationException(
86+
MCPAuthJwtVerificationExceptionCode.INVALID_JWT,
87+
cause=e,
88+
)
89+
except Exception as e:
90+
raise MCPAuthJwtVerificationException(
91+
MCPAuthJwtVerificationExceptionCode.JWT_VERIFICATION_FAILED,
92+
cause=e,
93+
)
94+
95+
return verify_jwt
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import pytest
2+
import time
3+
import jwt
4+
import base64
5+
from typing import Dict, Any
6+
from mcpauth.utils.create_verify_jwt import create_verify_jwt
7+
from mcpauth.types import AuthInfo
8+
9+
10+
from mcpauth.exceptions import (
11+
MCPAuthJwtVerificationException,
12+
MCPAuthJwtVerificationExceptionCode,
13+
)
14+
15+
_secret_key = b"super-secret-key-for-testing"
16+
_algorithm = "HS256"
17+
18+
19+
def create_jwk(key: bytes = _secret_key) -> jwt.PyJWK:
20+
"""Create a JWK for testing purposes"""
21+
return jwt.PyJWK(
22+
{
23+
"kty": "oct",
24+
"k": base64.urlsafe_b64encode(key).decode("utf-8"),
25+
"alg": _algorithm,
26+
}
27+
)
28+
29+
30+
def create_jwt(payload: Dict[str, Any]) -> str:
31+
"""Create a test JWT with the given payload"""
32+
return jwt.encode(
33+
{
34+
**payload,
35+
"iat": int(time.time()),
36+
"exp": int(time.time()) + 3600, # 1 hour
37+
},
38+
_secret_key,
39+
algorithm=_algorithm,
40+
)
41+
42+
43+
verify_jwt = create_verify_jwt(create_jwk(), algorithms=[_algorithm])
44+
45+
46+
class TestCreateVerifyJwtErrorHandling:
47+
def test_should_throw_error_if_signature_verification_fails(self):
48+
# Create JWT with correct secret
49+
jwt_token = create_jwt({"client_id": "client12345", "sub": "user12345"})
50+
verify_jwt = create_verify_jwt(
51+
create_jwk(b"wrong-secret-key-for-testing"), algorithms=[_algorithm]
52+
)
53+
54+
# Verify that the correct exception is raised
55+
with pytest.raises(MCPAuthJwtVerificationException) as exc_info:
56+
verify_jwt(jwt_token)
57+
58+
assert exc_info.value.code == MCPAuthJwtVerificationExceptionCode.INVALID_JWT
59+
assert isinstance(exc_info.value.cause, jwt.InvalidSignatureError)
60+
61+
def test_should_throw_error_if_jwt_payload_missing_iss(self):
62+
# Test different invalid JWT payloads
63+
jwt_missing_iss = create_jwt({"client_id": "client12345", "sub": "user12345"})
64+
jwt_invalid_iss_type = create_jwt(
65+
{"iss": 12345, "client_id": "client12345", "sub": "user12345"}
66+
)
67+
jwt_empty_iss = create_jwt(
68+
{"iss": "", "client_id": "client12345", "sub": "user12345"}
69+
)
70+
71+
for token in [jwt_missing_iss, jwt_invalid_iss_type, jwt_empty_iss]:
72+
with pytest.raises(MCPAuthJwtVerificationException) as exc_info:
73+
verify_jwt(token)
74+
assert (
75+
exc_info.value.code == MCPAuthJwtVerificationExceptionCode.INVALID_JWT
76+
)
77+
78+
def test_should_throw_error_if_jwt_payload_missing_client_id(self):
79+
# Test different invalid JWT payloads
80+
jwt_missing_client_id = create_jwt(
81+
{"iss": "https://logto.io/", "sub": "user12345"}
82+
)
83+
jwt_invalid_client_id_type = create_jwt(
84+
{"iss": "https://logto.io/", "client_id": 12345, "sub": "user12345"}
85+
)
86+
jwt_empty_client_id = create_jwt(
87+
{"iss": "https://logto.io/", "client_id": "", "sub": "user12345"}
88+
)
89+
90+
for token in [
91+
jwt_missing_client_id,
92+
jwt_invalid_client_id_type,
93+
jwt_empty_client_id,
94+
]:
95+
with pytest.raises(MCPAuthJwtVerificationException) as exc_info:
96+
verify_jwt(token)
97+
assert (
98+
exc_info.value.code == MCPAuthJwtVerificationExceptionCode.INVALID_JWT
99+
)
100+
101+
def test_should_throw_error_if_jwt_payload_missing_sub(self):
102+
# Test different invalid JWT payloads
103+
jwt_missing_sub = create_jwt(
104+
{"iss": "https://logto.io/", "client_id": "client12345"}
105+
)
106+
jwt_invalid_sub_type = create_jwt(
107+
{"iss": "https://logto.io/", "client_id": "client12345", "sub": 12345}
108+
)
109+
jwt_empty_sub = create_jwt(
110+
{"iss": "https://logto.io/", "client_id": "client12345", "sub": ""}
111+
)
112+
113+
for token in [jwt_missing_sub, jwt_invalid_sub_type, jwt_empty_sub]:
114+
with pytest.raises(MCPAuthJwtVerificationException) as exc_info:
115+
verify_jwt(token)
116+
assert (
117+
exc_info.value.code == MCPAuthJwtVerificationExceptionCode.INVALID_JWT
118+
)
119+
120+
121+
class TestCreateVerifyJwtNormalBehavior:
122+
def test_should_return_verified_jwt_payload_with_string_scope(self):
123+
# Create JWT with string scope
124+
claims = {
125+
"iss": "https://logto.io/",
126+
"client_id": "client12345",
127+
"sub": "user12345",
128+
"scope": "read write",
129+
"aud": "audience12345",
130+
}
131+
jwt_token = create_jwt(claims)
132+
133+
# Verify
134+
result = verify_jwt(jwt_token)
135+
136+
# Assertions
137+
assert isinstance(result, AuthInfo)
138+
assert result.token == jwt_token
139+
assert result.issuer == claims["iss"]
140+
assert result.client_id == claims["client_id"]
141+
assert result.subject == claims["sub"]
142+
assert result.audience == claims["aud"]
143+
assert result.scopes == ["read", "write"]
144+
assert "exp" in result.claims
145+
assert "iat" in result.claims
146+
assert result.expires_at is not None
147+
148+
def test_should_return_verified_jwt_payload_with_array_scope(self):
149+
# Create JWT with array scope
150+
claims: Dict[str, Any] = {
151+
"iss": "https://logto.io/",
152+
"client_id": "client12345",
153+
"sub": "user12345",
154+
"scope": ["read", "write"],
155+
}
156+
jwt_token = create_jwt(claims)
157+
158+
# Verify
159+
result = verify_jwt(jwt_token)
160+
161+
# Assertions
162+
assert result.issuer == claims["iss"]
163+
assert result.client_id == claims["client_id"]
164+
assert result.subject == claims["sub"]
165+
assert result.scopes == ["read", "write"]
166+
167+
def test_should_return_verified_jwt_payload_with_scopes_field(self):
168+
# Create JWT with scopes field
169+
claims: Dict[str, Any] = {
170+
"iss": "https://logto.io/",
171+
"client_id": "client12345",
172+
"sub": "user12345",
173+
"scopes": ["read", "write"],
174+
}
175+
jwt_token = create_jwt(claims)
176+
177+
# Verify
178+
result = verify_jwt(jwt_token)
179+
180+
# Assertions
181+
assert result.issuer == claims["iss"]
182+
assert result.client_id == claims["client_id"]
183+
assert result.subject == claims["sub"]
184+
assert result.scopes == ["read", "write"]
185+
186+
def test_should_return_verified_jwt_payload_without_scopes(self):
187+
# Create JWT without scope or scopes
188+
claims = {
189+
"iss": "https://logto.io/",
190+
"client_id": "client12345",
191+
"sub": "user12345",
192+
"aud": "audience12345",
193+
}
194+
jwt_token = create_jwt(claims)
195+
196+
# Verify
197+
result = verify_jwt(jwt_token)
198+
199+
# Assertions
200+
assert result.issuer == claims["iss"]
201+
assert result.client_id == claims["client_id"]
202+
assert result.subject == claims["sub"]
203+
assert result.audience == claims["aud"]
204+
assert result.scopes == []

0 commit comments

Comments
 (0)