Skip to content

Commit a8b1202

Browse files
blutrilpamelafox
andauthored
Replace python-jose with pyjwt (#1875)
* Replace python-jose with pyjwt. * Replace non-existent get_unverified_claims function * Change Exception to handle JWT-specific errors * Convert the public key to PEM format * Add pem format tests * Another test, plus autherror fixes --------- Co-authored-by: Pamela Fox <[email protected]>
1 parent 27816c1 commit a8b1202

File tree

4 files changed

+207
-46
lines changed

4 files changed

+207
-46
lines changed

app/backend/core/authentication.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
# Refactored from https://github.com/Azure-Samples/ms-identity-python-on-behalf-of
22

3+
import base64
34
import json
45
import logging
56
from typing import Any, Optional
67

78
import aiohttp
9+
import jwt
810
from azure.search.documents.aio import SearchClient
911
from azure.search.documents.indexes.models import SearchIndex
10-
from jose import jwt
11-
from jose.exceptions import ExpiredSignatureError, JWTClaimsError
12+
from cryptography.hazmat.primitives import serialization
13+
from cryptography.hazmat.primitives.asymmetric import rsa
1214
from msal import ConfidentialClientApplication
1315
from msal.token_cache import TokenCache
1416
from tenacity import (
@@ -282,6 +284,24 @@ async def check_path_auth(self, path: str, auth_claims: dict[str, Any], search_c
282284

283285
return allowed
284286

287+
async def create_pem_format(self, jwks, token):
288+
unverified_header = jwt.get_unverified_header(token)
289+
for key in jwks["keys"]:
290+
if key["kid"] == unverified_header["kid"]:
291+
# Construct the RSA public key
292+
public_numbers = rsa.RSAPublicNumbers(
293+
e=int.from_bytes(base64.urlsafe_b64decode(key["e"] + "=="), byteorder="big"),
294+
n=int.from_bytes(base64.urlsafe_b64decode(key["n"] + "=="), byteorder="big"),
295+
)
296+
public_key = public_numbers.public_key()
297+
298+
# Convert to PEM format
299+
pem_key = public_key.public_bytes(
300+
encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo
301+
)
302+
rsa_key = pem_key
303+
return rsa_key
304+
285305
# See https://github.com/Azure-Samples/ms-identity-python-on-behalf-of/blob/939be02b11f1604814532fdacc2c2eccd198b755/FlaskAPI/helpers/authorization.py#L44
286306
async def validate_access_token(self, token: str):
287307
"""
@@ -304,51 +324,38 @@ async def validate_access_token(self, token: str):
304324
jwks = await resp.json()
305325

306326
if not jwks or "keys" not in jwks:
307-
raise AuthError({"code": "invalid_keys", "description": "Unable to get keys to validate auth token."}, 401)
327+
raise AuthError("Unable to get keys to validate auth token.", 401)
308328

309329
rsa_key = None
310330
issuer = None
311331
audience = None
312332
try:
313-
unverified_header = jwt.get_unverified_header(token)
314-
unverified_claims = jwt.get_unverified_claims(token)
333+
unverified_claims = jwt.decode(token, options={"verify_signature": False})
315334
issuer = unverified_claims.get("iss")
316335
audience = unverified_claims.get("aud")
317-
for key in jwks["keys"]:
318-
if key["kid"] == unverified_header["kid"]:
319-
rsa_key = {"kty": key["kty"], "kid": key["kid"], "use": key["use"], "n": key["n"], "e": key["e"]}
320-
break
321-
except Exception as exc:
322-
raise AuthError(
323-
{"code": "invalid_header", "description": "Unable to parse authorization token."}, 401
324-
) from exc
336+
rsa_key = await self.create_pem_format(jwks, token)
337+
except jwt.PyJWTError as exc:
338+
raise AuthError("Unable to parse authorization token.", 401) from exc
325339
if not rsa_key:
326-
raise AuthError({"code": "invalid_header", "description": "Unable to find appropriate key"}, 401)
340+
raise AuthError("Unable to find appropriate key", 401)
327341

328342
if issuer not in self.valid_issuers:
329-
raise AuthError(
330-
{"code": "invalid_header", "description": f"Issuer {issuer} not in {','.join(self.valid_issuers)}"}, 401
331-
)
343+
raise AuthError(f"Issuer {issuer} not in {','.join(self.valid_issuers)}", 401)
332344

333345
if audience not in self.valid_audiences:
334346
raise AuthError(
335-
{
336-
"code": "invalid_header",
337-
"description": f"Audience {audience} not in {','.join(self.valid_audiences)}",
338-
},
347+
f"Audience {audience} not in {','.join(self.valid_audiences)}",
339348
401,
340349
)
341350

342351
try:
343352
jwt.decode(token, rsa_key, algorithms=["RS256"], audience=audience, issuer=issuer)
344-
except ExpiredSignatureError as jwt_expired_exc:
345-
raise AuthError({"code": "token_expired", "description": "token is expired"}, 401) from jwt_expired_exc
346-
except JWTClaimsError as jwt_claims_exc:
353+
except jwt.ExpiredSignatureError as jwt_expired_exc:
354+
raise AuthError("Token is expired", 401) from jwt_expired_exc
355+
except (jwt.InvalidAudienceError, jwt.InvalidIssuerError) as jwt_claims_exc:
347356
raise AuthError(
348-
{"code": "invalid_claims", "description": "incorrect claims," "please check the audience and issuer"},
357+
"Incorrect claims: please check the audience and issuer",
349358
401,
350359
) from jwt_claims_exc
351360
except Exception as exc:
352-
raise AuthError(
353-
{"code": "invalid_header", "description": "Unable to parse authorization token."}, 401
354-
) from exc
361+
raise AuthError("Unable to parse authorization token.", 401) from exc

app/backend/requirements.in

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ opentelemetry-instrumentation-aiohttp-client
2020
opentelemetry-instrumentation-openai
2121
msal
2222
cryptography
23-
python-jose[cryptography]
24-
types-python-jose
23+
PyJWT
2524
Pillow
2625
types-Pillow
2726
pypdf

app/backend/requirements.txt

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,10 @@ cryptography==43.0.0
8686
# azure-storage-blob
8787
# msal
8888
# pyjwt
89-
# python-jose
9089
deprecated==1.2.14
9190
# via opentelemetry-api
9291
distro==1.9.0
9392
# via openai
94-
ecdsa==0.19.0
95-
# via python-jose
9693
fixedint==0.1.6
9794
# via azure-monitor-opentelemetry-exporter
9895
flask==3.0.3
@@ -324,10 +321,6 @@ priority==2.0.0
324321
# via hypercorn
325322
psutil==5.9.8
326323
# via azure-monitor-opentelemetry-exporter
327-
pyasn1==0.6.0
328-
# via
329-
# python-jose
330-
# rsa
331324
pycparser==2.22
332325
# via cffi
333326
pydantic==2.8.2
@@ -349,8 +342,6 @@ python-dateutil==2.9.0.post0
349342
# microsoft-kiota-serialization-text
350343
# pendulum
351344
# time-machine
352-
python-jose[cryptography]==3.3.0
353-
# via -r requirements.in
354345
quart==0.19.6
355346
# via
356347
# -r requirements.in
@@ -368,8 +359,6 @@ requests==2.32.3
368359
# tiktoken
369360
requests-oauthlib==2.0.0
370361
# via msrest
371-
rsa==4.9
372-
# via python-jose
373362
six==1.16.0
374363
# via
375364
# azure-core
@@ -402,10 +391,6 @@ types-html5lib==1.1.11.20240228
402391
# via types-beautifulsoup4
403392
types-pillow==10.2.0.20240520
404393
# via -r requirements.in
405-
types-pyasn1==0.6.0.20240402
406-
# via types-python-jose
407-
types-python-jose==3.3.4.20240106
408-
# via -r requirements.in
409394
typing-extensions==4.12.2
410395
# via
411396
# azure-ai-documentintelligence

tests/test_authenticationhelper.py

Lines changed: 171 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
1+
import base64
12
import json
3+
import re
4+
from datetime import datetime, timedelta
25

6+
import aiohttp
7+
import jwt
38
import pytest
49
from azure.core.credentials import AzureKeyCredential
510
from azure.search.documents.aio import SearchClient
611
from azure.search.documents.indexes.models import SearchField, SearchIndex
12+
from cryptography.hazmat.primitives import serialization
13+
from cryptography.hazmat.primitives.asymmetric import rsa
714

815
from core.authentication import AuthenticationHelper, AuthError
916

10-
from .mocks import MockAsyncPageIterator
17+
from .mocks import MockAsyncPageIterator, MockResponse
1118

1219
MockSearchIndex = SearchIndex(
1320
name="test",
@@ -40,6 +47,36 @@ def create_search_client():
4047
return SearchClient(endpoint="", index_name="", credential=AzureKeyCredential(""))
4148

4249

50+
def create_mock_jwt(kid="mock_kid", oid="OID_X"):
51+
# Create a payload with necessary claims
52+
payload = {
53+
"iss": "https://login.microsoftonline.com/TENANT_ID/v2.0",
54+
"sub": "AaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaA",
55+
"aud": "SERVER_APP",
56+
"exp": int((datetime.utcnow() + timedelta(hours=1)).timestamp()),
57+
"iat": int(datetime.utcnow().timestamp()),
58+
"nbf": int(datetime.utcnow().timestamp()),
59+
"name": "John Doe",
60+
"oid": oid,
61+
"preferred_username": "[email protected]",
62+
"rh": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA.",
63+
"tid": "22222222-2222-2222-2222-222222222222",
64+
"uti": "AbCdEfGhIjKlMnOp-ABCDEFG",
65+
"ver": "2.0",
66+
}
67+
68+
# Create a header
69+
header = {"kid": kid, "alg": "RS256", "typ": "JWT"}
70+
71+
# Create a mock private key (for signing)
72+
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
73+
74+
# Create the JWT
75+
token = jwt.encode(payload, private_key, algorithm="RS256", headers=header)
76+
77+
return token, private_key.public_key(), payload
78+
79+
4380
@pytest.mark.asyncio
4481
async def test_get_auth_claims_success(mock_confidential_client_success, mock_validate_token_success):
4582
helper = create_authentication_helper()
@@ -479,3 +516,136 @@ async def mock_search(self, *args, **kwargs):
479516
)
480517
assert filter is None
481518
assert called_search is False
519+
520+
521+
@pytest.mark.asyncio
522+
async def test_create_pem_format(mock_confidential_client_success, mock_validate_token_success):
523+
helper = create_authentication_helper()
524+
mock_token, public_key, payload = create_mock_jwt(oid="OID_X")
525+
_, other_public_key, _ = create_mock_jwt(oid="OID_Y")
526+
mock_jwks = {
527+
"keys": [
528+
# Include a key with a different KID to ensure the correct key is selected
529+
{
530+
"kty": "RSA",
531+
"kid": "other_mock_kid",
532+
"use": "sig",
533+
"n": base64.urlsafe_b64encode(
534+
other_public_key.public_numbers().n.to_bytes(
535+
(other_public_key.public_numbers().n.bit_length() + 7) // 8, byteorder="big"
536+
)
537+
)
538+
.decode("utf-8")
539+
.rstrip("="),
540+
"e": base64.urlsafe_b64encode(
541+
other_public_key.public_numbers().e.to_bytes(
542+
(other_public_key.public_numbers().e.bit_length() + 7) // 8, byteorder="big"
543+
)
544+
)
545+
.decode("utf-8")
546+
.rstrip("="),
547+
},
548+
{
549+
"kty": "RSA",
550+
"kid": "mock_kid",
551+
"use": "sig",
552+
"n": base64.urlsafe_b64encode(
553+
public_key.public_numbers().n.to_bytes(
554+
(public_key.public_numbers().n.bit_length() + 7) // 8, byteorder="big"
555+
)
556+
)
557+
.decode("utf-8")
558+
.rstrip("="),
559+
"e": base64.urlsafe_b64encode(
560+
public_key.public_numbers().e.to_bytes(
561+
(public_key.public_numbers().e.bit_length() + 7) // 8, byteorder="big"
562+
)
563+
)
564+
.decode("utf-8")
565+
.rstrip("="),
566+
},
567+
]
568+
}
569+
570+
pem_key = await helper.create_pem_format(mock_jwks, mock_token)
571+
572+
# Assert that the result is bytes
573+
assert isinstance(pem_key, bytes), "create_pem_format should return bytes"
574+
575+
# Convert bytes to string for regex matching
576+
pem_str = pem_key.decode("utf-8")
577+
578+
# Assert that the key starts and ends with the correct markers
579+
assert pem_str.startswith("-----BEGIN PUBLIC KEY-----"), "PEM key should start with the correct marker"
580+
assert pem_str.endswith("-----END PUBLIC KEY-----\n"), "PEM key should end with the correct marker"
581+
582+
# Assert that the format matches the structure of a PEM key
583+
pem_regex = r"^-----BEGIN PUBLIC KEY-----\n([A-Za-z0-9+/\n]+={0,2})\n-----END PUBLIC KEY-----\n$"
584+
assert re.match(pem_regex, pem_str), "PEM key format is incorrect"
585+
586+
# Verify that the key can be used to decode the token
587+
try:
588+
decoded = jwt.decode(
589+
mock_token, key=pem_key, algorithms=["RS256"], audience=payload["aud"], issuer=payload["iss"]
590+
)
591+
assert decoded["oid"] == payload["oid"], "Decoded token should contain correct OID"
592+
except Exception as e:
593+
pytest.fail(f"jwt.decode raised an unexpected exception: {str(e)}")
594+
595+
# Try to load the key using cryptography library to ensure it's a valid PEM format
596+
try:
597+
loaded_public_key = serialization.load_pem_public_key(pem_key)
598+
assert isinstance(loaded_public_key, rsa.RSAPublicKey), "Loaded key should be an RSA public key"
599+
except Exception as e:
600+
pytest.fail(f"Failed to load PEM key: {str(e)}")
601+
602+
603+
@pytest.mark.asyncio
604+
async def test_validate_access_token(monkeypatch, mock_confidential_client_success):
605+
mock_token, public_key, payload = create_mock_jwt(oid="OID_X")
606+
607+
def mock_get(*args, **kwargs):
608+
return MockResponse(
609+
status=200,
610+
text=json.dumps(
611+
{
612+
"keys": [
613+
{
614+
"kty": "RSA",
615+
"use": "sig",
616+
"kid": "23nt",
617+
"x5t": "23nt",
618+
"n": "hu2SJ",
619+
"e": "AQAB",
620+
"x5c": ["MIIC/jCC"],
621+
"issuer": "https://login.microsoftonline.com/TENANT_ID/v2.0",
622+
},
623+
{
624+
"kty": "RSA",
625+
"use": "sig",
626+
"kid": "MGLq",
627+
"x5t": "MGLq",
628+
"n": "yfNcG8",
629+
"e": "AQAB",
630+
"x5c": ["MIIC/jCC"],
631+
"issuer": "https://login.microsoftonline.com/TENANT_ID/v2.0",
632+
},
633+
]
634+
}
635+
),
636+
)
637+
638+
monkeypatch.setattr(aiohttp.ClientSession, "get", mock_get)
639+
640+
def mock_decode(*args, **kwargs):
641+
return payload
642+
643+
monkeypatch.setattr(jwt, "decode", mock_decode)
644+
645+
async def mock_create_pem_format(*args, **kwargs):
646+
return public_key
647+
648+
monkeypatch.setattr(AuthenticationHelper, "create_pem_format", mock_create_pem_format)
649+
650+
helper = create_authentication_helper()
651+
await helper.validate_access_token(mock_token)

0 commit comments

Comments
 (0)