Skip to content

Commit ab06aad

Browse files
Add token validation (#1186)
* update requirements * add token validation * fix typechecker * fixing tests * fix comment link 404 --------- Co-authored-by: Matt Gotteiner <[email protected]>
1 parent 2e900be commit ab06aad

File tree

9 files changed

+164
-22
lines changed

9 files changed

+164
-22
lines changed

app/backend/core/authentication.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,15 @@
77
import aiohttp
88
from azure.search.documents.aio import SearchClient
99
from azure.search.documents.indexes.models import SearchIndex
10+
from jose import jwt
1011
from msal import ConfidentialClientApplication
1112
from msal.token_cache import TokenCache
13+
from tenacity import (
14+
AsyncRetrying,
15+
retry_if_exception_type,
16+
stop_after_attempt,
17+
wait_random_exponential,
18+
)
1219

1320

1421
# AuthError is raised when the authentication token sent by the client UI cannot be parsed or there is an authentication error accessing the graph API
@@ -40,6 +47,15 @@ def __init__(
4047
self.client_app_id = client_app_id
4148
self.tenant_id = tenant_id
4249
self.authority = f"https://login.microsoftonline.com/{tenant_id}"
50+
# Depending on if requestedAccessTokenVersion is 1 or 2, the issuer and audience of the token may be different
51+
# See https://learn.microsoft.com/graph/api/resources/apiapplication
52+
self.valid_issuers = [
53+
f"https://sts.windows.net/{tenant_id}/",
54+
f"https://login.microsoftonline.com/{tenant_id}/v2.0",
55+
]
56+
self.valid_audiences = [f"api://{server_app_id}", str(server_app_id)]
57+
# See https://learn.microsoft.com/entra/identity-platform/access-tokens#validate-the-issuer for more information on token validation
58+
self.key_url = f"{self.authority}/discovery/v2.0/keys"
4359

4460
if self.use_authentication:
4561
field_names = [field.name for field in search_index.fields] if search_index else []
@@ -182,6 +198,11 @@ async def get_auth_claims_if_enabled(self, headers: dict) -> dict[str, Any]:
182198
# The scope is set to the Microsoft Graph API, which may need to be called for more authorization information
183199
# https://learn.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-on-behalf-of-flow
184200
auth_token = AuthenticationHelper.get_token_auth_header(headers)
201+
# Validate the token before use
202+
await self.validate_access_token(auth_token)
203+
204+
# Use the on-behalf-of-flow to acquire another token for use with Microsoft Graph
205+
# See https://learn.microsoft.com/entra/identity-platform/v2-oauth2-on-behalf-of-flow for more information
185206
graph_resource_access_token = self.confidential_client.acquire_token_on_behalf_of(
186207
user_assertion=auth_token, scopes=["https://graph.microsoft.com/.default"]
187208
)
@@ -207,7 +228,6 @@ async def get_auth_claims_if_enabled(self, headers: dict) -> dict[str, Any]:
207228
auth_claims["groups"] = await AuthenticationHelper.list_groups(graph_resource_access_token)
208229
return auth_claims
209230
except AuthError as e:
210-
print(e.error)
211231
logging.exception("Exception getting authorization information - " + json.dumps(e.error))
212232
if self.require_access_control:
213233
raise
@@ -237,3 +257,74 @@ async def check_path_auth(self, path: str, auth_claims: dict[str, Any], search_c
237257
break
238258

239259
return allowed
260+
261+
# See https://github.com/Azure-Samples/ms-identity-python-on-behalf-of/blob/939be02b11f1604814532fdacc2c2eccd198b755/FlaskAPI/helpers/authorization.py#L44
262+
async def validate_access_token(self, token: str):
    """
    Validate that an access token was issued by Entra for this application.

    The signing keys are downloaded from ``self.key_url`` (retrying on 5xx responses),
    the key whose ``kid`` matches the token header is selected, the token's issuer and
    audience are checked against ``self.valid_issuers`` / ``self.valid_audiences``, and
    finally the RS256 signature and claims are verified with ``jwt.decode``.

    Raises:
        AuthError: if the keys cannot be fetched, no matching key exists, or the token is
            malformed, expired, or carries an unexpected issuer or audience.
    """
    jwks = None
    async for attempt in AsyncRetrying(
        retry=retry_if_exception_type(AuthError),
        wait=wait_random_exponential(min=15, max=60),
        stop=stop_after_attempt(5),
        # Re-raise the last AuthError instead of tenacity's RetryError once the attempts are
        # exhausted, so callers that catch AuthError still handle a failing keys endpoint.
        reraise=True,
    ):
        with attempt:
            async with aiohttp.ClientSession() as session:
                async with session.get(url=self.key_url) as resp:
                    resp_status = resp.status
                    # Only transient server-side failures are turned into a retryable AuthError
                    if resp_status in [500, 502, 503, 504]:
                        raise AuthError(
                            error=f"Failed to get keys info: {await resp.text()}", status_code=resp_status
                        )
                    jwks = await resp.json()

    if not jwks or "keys" not in jwks:
        raise AuthError({"code": "invalid_keys", "description": "Unable to get keys to validate auth token."}, 401)

    rsa_key = None
    issuer = None
    audience = None
    try:
        # The header and claims are read unverified only to pick the signing key and to
        # pre-check issuer/audience; the signature itself is verified by jwt.decode below.
        unverified_header = jwt.get_unverified_header(token)
        unverified_claims = jwt.get_unverified_claims(token)
        issuer = unverified_claims.get("iss")
        audience = unverified_claims.get("aud")
        for key in jwks["keys"]:
            if key["kid"] == unverified_header["kid"]:
                rsa_key = {"kty": key["kty"], "kid": key["kid"], "use": key["use"], "n": key["n"], "e": key["e"]}
                break
    except Exception as exc:
        raise AuthError(
            {"code": "invalid_header", "description": "Unable to parse authorization token."}, 401
        ) from exc
    if not rsa_key:
        raise AuthError({"code": "invalid_header", "description": "Unable to find appropriate key"}, 401)

    if issuer not in self.valid_issuers:
        raise AuthError(
            {"code": "invalid_header", "description": f"Issuer {issuer} not in {','.join(self.valid_issuers)}"}, 401
        )

    if audience not in self.valid_audiences:
        raise AuthError(
            {
                "code": "invalid_header",
                "description": f"Audience {audience} not in {','.join(self.valid_audiences)}",
            },
            401,
        )

    try:
        jwt.decode(token, rsa_key, algorithms=["RS256"], audience=audience, issuer=issuer)
    except jwt.ExpiredSignatureError as jwt_expired_exc:
        raise AuthError({"code": "token_expired", "description": "token is expired"}, 401) from jwt_expired_exc
    except jwt.JWTClaimsError as jwt_claims_exc:
        raise AuthError(
            {"code": "invalid_claims", "description": "incorrect claims, please check the audience and issuer"},
            401,
        ) from jwt_claims_exc
    except Exception as exc:
        raise AuthError(
            {"code": "invalid_header", "description": "Unable to parse authorization token."}, 401
        ) from exc

app/backend/requirements.in

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ quart
33
quart-cors
44
openai[datalib]>=1.3.7
55
tiktoken
6+
tenacity
67
azure-search-documents==11.4.0b11
78
azure-storage-blob
89
uvicorn
@@ -14,3 +15,5 @@ opentelemetry-instrumentation-requests
1415
opentelemetry-instrumentation-aiohttp-client
1516
msal
1617
azure-keyvault-secrets
18+
cryptography
19+
python-jose[cryptography]

app/backend/requirements.txt

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,20 @@ click==8.1.7
6767
# flask
6868
# quart
6969
# uvicorn
70-
cryptography==41.0.7
70+
cryptography==42.0.1
7171
# via
72+
# -r requirements.in
7273
# azure-identity
7374
# azure-storage-blob
7475
# msal
7576
# pyjwt
77+
# python-jose
7678
deprecated==1.2.14
7779
# via opentelemetry-api
7880
distro==1.8.0
7981
# via openai
82+
ecdsa==0.18.0
83+
# via python-jose
8084
fixedint==0.1.6
8185
# via azure-monitor-opentelemetry-exporter
8286
flask==3.0.0
@@ -257,18 +261,22 @@ portalocker==2.8.2
257261
# via msal-extensions
258262
priority==2.0.0
259263
# via hypercorn
264+
pyasn1==0.5.1
265+
# via
266+
# python-jose
267+
# rsa
260268
pycparser==2.21
261269
# via cffi
262270
pydantic==2.5.2
263271
# via openai
264272
pydantic-core==2.14.5
265273
# via pydantic
266274
pyjwt[crypto]==2.8.0
267-
# via
268-
# msal
269-
# pyjwt
275+
# via msal
270276
python-dateutil==2.8.2
271277
# via pandas
278+
python-jose[cryptography]==3.3.0
279+
# via -r requirements.in
272280
pytz==2023.3.post1
273281
# via pandas
274282
quart==0.19.4
@@ -288,16 +296,21 @@ requests==2.31.0
288296
# tiktoken
289297
requests-oauthlib==1.3.1
290298
# via msrest
299+
rsa==4.9
300+
# via python-jose
291301
six==1.16.0
292302
# via
293303
# azure-core
304+
# ecdsa
294305
# isodate
295306
# python-dateutil
296307
sniffio==1.3.0
297308
# via
298309
# anyio
299310
# httpx
300311
# openai
312+
tenacity==8.2.3
313+
# via -r requirements.in
301314
tiktoken==0.5.2
302315
# via -r requirements.in
303316
tqdm==4.66.1

app/mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,6 @@
22

33
[mypy-msal.*]
44
ignore_missing_imports = True
5+
6+
[mypy-jose.*]
7+
ignore_missing_imports = True

scripts/auth_init.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ def create_server_app_permission_setup_payload(server_app_id: str):
109109
"type": "User",
110110
}
111111
],
112+
# Required to match v2.0 OpenID configuration for token validation
113+
# Learn more at https://learn.microsoft.com/entra/identity-platform/v2-protocols-oidc#find-your-apps-openid-configuration-document-uri
114+
"requestedAccessTokenVersion": 2,
112115
},
113116
"requiredResourceAccess": [
114117
{

scripts/requirements.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@ azure-keyvault-secrets
1212
Pillow
1313
PyMuPDF
1414
types-Pillow
15+
cryptography
16+
python-jose[cryptography]

scripts/requirements.txt

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,18 @@ cffi==1.16.0
5454
# via cryptography
5555
charset-normalizer==3.3.2
5656
# via requests
57-
cryptography==41.0.7
57+
cryptography==42.0.1
5858
# via
59+
# -r requirements.in
5960
# azure-identity
6061
# azure-storage-blob
6162
# msal
6263
# pyjwt
64+
# python-jose
6365
distro==1.8.0
6466
# via openai
67+
ecdsa==0.18.0
68+
# via python-jose
6569
frozenlist==1.4.0
6670
# via
6771
# aiohttp
@@ -116,16 +120,18 @@ pillow==10.2.0
116120
# via -r requirements.in
117121
portalocker==2.8.2
118122
# via msal-extensions
123+
pyasn1==0.5.1
124+
# via
125+
# python-jose
126+
# rsa
119127
pycparser==2.21
120128
# via cffi
121129
pydantic==2.5.2
122130
# via openai
123131
pydantic-core==2.14.5
124132
# via pydantic
125133
pyjwt[crypto]==2.8.0
126-
# via
127-
# msal
128-
# pyjwt
134+
# via msal
129135
pymupdf==1.23.7
130136
# via -r requirements.in
131137
pymupdfb==1.23.7
@@ -134,6 +140,8 @@ pypdf==3.17.1
134140
# via -r requirements.in
135141
python-dateutil==2.8.2
136142
# via pandas
143+
python-jose[cryptography]==3.3.0
144+
# via -r requirements.in
137145
pytz==2023.3.post1
138146
# via pandas
139147
regex==2023.10.3
@@ -147,9 +155,12 @@ requests==2.31.0
147155
# tiktoken
148156
requests-oauthlib==1.3.1
149157
# via msrest
158+
rsa==4.9
159+
# via python-jose
150160
six==1.16.0
151161
# via
152162
# azure-core
163+
# ecdsa
153164
# isodate
154165
# python-dateutil
155166
sniffio==1.3.0

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from openai.types.create_embedding_response import Usage
2424

2525
import app
26+
import core
2627
from core.authentication import AuthenticationHelper
2728

2829
from .mocks import (
@@ -302,6 +303,7 @@ async def auth_client(
302303
mock_openai_chatcompletion,
303304
mock_openai_embedding,
304305
mock_confidential_client_success,
306+
mock_validate_token_success,
305307
mock_list_groups_success,
306308
mock_acs_search_filter,
307309
mock_get_secret,
@@ -329,6 +331,14 @@ async def auth_client(
329331
yield client
330332

331333

334+
@pytest.fixture
def mock_validate_token_success(monkeypatch):
    """Patch AuthenticationHelper.validate_access_token with a coroutine that accepts every token."""

    async def accept_any_token(self, token):
        return None

    monkeypatch.setattr(core.authentication.AuthenticationHelper, "validate_access_token", accept_any_token)
340+
341+
332342
@pytest.fixture
333343
def mock_confidential_client_success(monkeypatch):
334344
def mock_acquire_token_on_behalf_of(self, *args, **kwargs):

0 commit comments

Comments
 (0)